Skip to content

Commit da66111

Browse files
adrinjalali
authored and glemaitre committed
FIX introduce a refresh_cache param to fetch_... functions (scikit-learn#14197)
1 parent b27a37a commit da66111

File tree

9 files changed

+126
-11
lines changed

9 files changed

+126
-11
lines changed

doc/whats_new/v0.21.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,18 @@ Version 0.21.3
1212
Changelog
1313
---------
1414

15+
:mod:`sklearn.datasets`
16+
.......................
17+
18+
- |Fix| :func:`datasets.fetch_california_housing`,
19+
:func:`datasets.fetch_covtype`,
20+
:func:`datasets.fetch_kddcup99`, :func:`datasets.fetch_olivetti_faces`,
21+
:func:`datasets.fetch_rcv1`, and :func:`datasets.fetch_species_distributions`
22+
try to persist the previously cached data using the new ``joblib`` if the cached
23+
data was persisted using the deprecated ``sklearn.externals.joblib``. This
24+
behavior is set to be deprecated and removed in v0.23.
25+
:pr:`14197` by `Adrin Jalali`_.
26+
1527
:mod:`sklearn.impute`
1628
.....................
1729

sklearn/datasets/base.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import csv
1111
import sys
1212
import shutil
13+
import warnings
1314
from collections import namedtuple
1415
from os import environ, listdir, makedirs
1516
from os.path import dirname, exists, expanduser, isdir, join, splitext
@@ -919,3 +920,31 @@ def _fetch_remote(remote, dirname=None):
919920
"file may be corrupted.".format(file_path, checksum,
920921
remote.checksum))
921922
return file_path
923+
924+
925+
def _refresh_cache(files, compress):
926+
# TODO: REMOVE in v0.23
927+
import joblib
928+
msg = "sklearn.externals.joblib is deprecated in 0.21"
929+
with warnings.catch_warnings(record=True) as warns:
930+
data = tuple([joblib.load(f) for f in files])
931+
932+
refresh_needed = any([str(x.message).startswith(msg) for x in warns])
933+
934+
other_warns = [w for w in warns if not str(w.message).startswith(msg)]
935+
for w in other_warns:
936+
warnings.warn(message=w.message, category=w.category)
937+
938+
if refresh_needed:
939+
try:
940+
for value, path in zip(data, files):
941+
joblib.dump(value, path, compress=compress)
942+
except IOError:
943+
message = ("This dataset will stop being loadable in scikit-learn "
944+
"version 0.23 because it references a deprecated "
945+
"import path. Consider removing the following files "
946+
"and allowing it to be cached anew:\n%s"
947+
% ("\n".join(files)))
948+
warnings.warn(message=message, category=DeprecationWarning)
949+
950+
return data[0] if len(data) == 1 else data

sklearn/datasets/california_housing.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .base import _fetch_remote
3535
from .base import _pkl_filepath
3636
from .base import RemoteFileMetadata
37+
from .base import _refresh_cache
3738
from ..utils import Bunch
3839

3940
# The original data can be found at:
@@ -129,7 +130,9 @@ def fetch_california_housing(data_home=None, download_if_missing=True,
129130
remove(archive_path)
130131

131132
else:
132-
cal_housing = joblib.load(filepath)
133+
cal_housing = _refresh_cache([filepath], 6)
134+
# TODO: Revert to the following line in v0.23
135+
# cal_housing = joblib.load(filepath)
133136

134137
feature_names = ["MedInc", "HouseAge", "AveRooms", "AveBedrms",
135138
"Population", "AveOccup", "Latitude", "Longitude"]

sklearn/datasets/covtype.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .base import get_data_home
2626
from .base import _fetch_remote
2727
from .base import RemoteFileMetadata
28+
from .base import _refresh_cache
2829
from ..utils import Bunch
2930
from .base import _pkl_filepath
3031
from ..utils import check_random_state
@@ -125,8 +126,10 @@ def fetch_covtype(data_home=None, download_if_missing=True,
125126
try:
126127
X, y
127128
except NameError:
128-
X = joblib.load(samples_path)
129-
y = joblib.load(targets_path)
129+
X, y = _refresh_cache([samples_path, targets_path], 9)
130+
# TODO: Revert to the following two lines in v0.23
131+
# X = joblib.load(samples_path)
132+
# y = joblib.load(targets_path)
130133

131134
if shuffle:
132135
ind = np.arange(X.shape[0])

sklearn/datasets/kddcup99.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from .base import _fetch_remote
2121
from .base import get_data_home
2222
from .base import RemoteFileMetadata
23+
from .base import _refresh_cache
2324
from ..utils import Bunch
2425
from ..utils import check_random_state
2526
from ..utils import shuffle as shuffle_method
@@ -292,8 +293,10 @@ def _fetch_brute_kddcup99(data_home=None,
292293
try:
293294
X, y
294295
except NameError:
295-
X = joblib.load(samples_path)
296-
y = joblib.load(targets_path)
296+
X, y = _refresh_cache([samples_path, targets_path], 0)
297+
# TODO: Revert to the following two lines in v0.23
298+
# X = joblib.load(samples_path)
299+
# y = joblib.load(targets_path)
297300

298301
return Bunch(data=X, target=y)
299302

sklearn/datasets/olivetti_faces.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .base import _fetch_remote
2525
from .base import RemoteFileMetadata
2626
from .base import _pkl_filepath
27+
from .base import _refresh_cache
2728
from ..utils import check_random_state, Bunch
2829

2930
# The original data can be found at:
@@ -107,7 +108,9 @@ def fetch_olivetti_faces(data_home=None, shuffle=False, random_state=0,
107108
joblib.dump(faces, filepath, compress=6)
108109
del mfile
109110
else:
110-
faces = joblib.load(filepath)
111+
faces = _refresh_cache([filepath], 6)
112+
# TODO: Revert to the following line in v0.23
113+
# faces = joblib.load(filepath)
111114

112115
# We want floating point data, but float32 is enough (there is only
113116
# one byte of precision in the original uint8s anyway)

sklearn/datasets/rcv1.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .base import _pkl_filepath
2323
from .base import _fetch_remote
2424
from .base import RemoteFileMetadata
25+
from .base import _refresh_cache
2526
from .svmlight_format import load_svmlight_files
2627
from ..utils import shuffle as shuffle_
2728
from ..utils import Bunch
@@ -189,8 +190,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
189190
f.close()
190191
remove(f.name)
191192
else:
192-
X = joblib.load(samples_path)
193-
sample_id = joblib.load(sample_id_path)
193+
X, sample_id = _refresh_cache([samples_path, sample_id_path], 9)
194+
# TODO: Revert to the following two lines in v0.23
195+
# X = joblib.load(samples_path)
196+
# sample_id = joblib.load(sample_id_path)
194197

195198
# load target (y), categories, and sample_id_bis
196199
if download_if_missing and (not exists(sample_topics_path) or
@@ -243,8 +246,10 @@ def fetch_rcv1(data_home=None, subset='all', download_if_missing=True,
243246
joblib.dump(y, sample_topics_path, compress=9)
244247
joblib.dump(categories, topics_path, compress=9)
245248
else:
246-
y = joblib.load(sample_topics_path)
247-
categories = joblib.load(topics_path)
249+
y, categories = _refresh_cache([sample_topics_path, topics_path], 9)
250+
# TODO: Revert to the following two lines in v0.23
251+
# y = joblib.load(sample_topics_path)
252+
# categories = joblib.load(topics_path)
248253

249254
if subset == 'all':
250255
pass

sklearn/datasets/species_distributions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from .base import RemoteFileMetadata
5252
from ..utils import Bunch
5353
from .base import _pkl_filepath
54+
from .base import _refresh_cache
5455

5556
# The original data can be found at:
5657
# https://biodiversityinformatics.amnh.org/open_source/maxent/samples.zip
@@ -259,6 +260,8 @@ def fetch_species_distributions(data_home=None,
259260
**extra_params)
260261
joblib.dump(bunch, archive_path, compress=9)
261262
else:
262-
bunch = joblib.load(archive_path)
263+
bunch = _refresh_cache([archive_path], 9)
264+
# TODO: Revert to the following line in v0.23
265+
# bunch = joblib.load(archive_path)
263266

264267
return bunch

sklearn/datasets/tests/test_base.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from functools import partial
99

1010
import pytest
11+
import joblib
1112

1213
import numpy as np
1314
from sklearn.datasets import get_data_home
@@ -23,6 +24,7 @@
2324
from sklearn.datasets import load_boston
2425
from sklearn.datasets import load_wine
2526
from sklearn.datasets.base import Bunch
27+
from sklearn.datasets.base import _refresh_cache
2628
from sklearn.datasets.tests.test_common import check_return_X_y
2729

2830
from sklearn.externals._pilutil import pillow_installed
@@ -276,3 +278,55 @@ def test_bunch_dir():
276278
# check that dir (important for autocomplete) shows attributes
277279
data = load_iris()
278280
assert "data" in dir(data)
281+
282+
283+
def test_refresh_cache(monkeypatch):
    # Tests for datasets.base._refresh_cache, using pytest's monkeypatch
    # fixture (https://docs.pytest.org/en/latest/monkeypatch.html) to stub
    # out joblib.load / joblib.dump so no real files are touched.
    # NOTE(review): 'test' is passed below where _refresh_cache expects a
    # list of file paths; this only works because joblib.load is patched
    # and never reaches the filesystem — confirm if the signature changes.

    def _load_warn(*args, **kwargs):
        # raise the warning from "externals.joblib.__init__.py"
        # this is raised when a file persisted by the old joblib is loaded now
        msg = ("sklearn.externals.joblib is deprecated in 0.21 and will be "
               "removed in 0.23. Please import this functionality directly "
               "from joblib, which can be installed with: pip install joblib. "
               "If this warning is raised when loading pickled models, you "
               "may need to re-serialize those models with scikit-learn "
               "0.21+.")
        warnings.warn(msg, DeprecationWarning)
        return 0

    def _load_warn_unrelated(*args, **kwargs):
        # emit a warning that has nothing to do with the joblib deprecation
        warnings.warn("unrelated warning", DeprecationWarning)
        return 0

    def _dump_safe(*args, **kwargs):
        # pretend re-dumping with the new joblib succeeded
        pass

    def _dump_raise(*args, **kwargs):
        # this happens if the file is read-only and joblib.dump fails to write
        # on it.
        raise IOError()

    # test if the dataset specific warning is raised if load raises the joblib
    # warning, and dump fails to dump with new joblib
    monkeypatch.setattr(joblib, "load", _load_warn)
    monkeypatch.setattr(joblib, "dump", _dump_raise)
    msg = "This dataset will stop being loadable in scikit-learn"
    with pytest.warns(DeprecationWarning, match=msg):
        _refresh_cache('test', 0)

    # make sure no warning is raised if load raises the warning, but dump
    # manages to dump the new data
    monkeypatch.setattr(joblib, "load", _load_warn)
    monkeypatch.setattr(joblib, "dump", _dump_safe)
    with pytest.warns(None) as warns:
        _refresh_cache('test', 0)
    assert len(warns) == 0

    # test if an unrelated warning is still passed through and not suppressed
    # by _refresh_cache
    monkeypatch.setattr(joblib, "load", _load_warn_unrelated)
    monkeypatch.setattr(joblib, "dump", _dump_safe)
    with pytest.warns(DeprecationWarning, match="unrelated warning"):
        _refresh_cache('test', 0)

0 commit comments

Comments
 (0)