Skip to content

Commit 84bd4e2

Browse files
thomasjpfan and ogrisel
authored
FIX Fixes test_scale_and_stability (scikit-learn#18746)
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 9cd19ba commit 84bd4e2

File tree

3 files changed

+52
-47
lines changed

3 files changed

+52
-47
lines changed

doc/whats_new/v0.24.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ Changelog
132132
predictions for `est.transform(Y)` when the training data is single-target.
133133
:pr:`17095` by `Nicolas Hug`_.
134134

135+
- |Fix| Increases the stability of :class:`cross_decomposition.CCA` :pr:`18746`
136+
by `Thomas Fan`_.
137+
135138
- |API| For :class:`cross_decomposition.NMF`,
136139
the `init` value, when 'init=None' and
137140
n_components <= min(n_samples, n_features) will be changed from

sklearn/cross_decomposition/_pls.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from scipy.linalg import pinv2, svd
1313

1414
from ..base import BaseEstimator, RegressorMixin, TransformerMixin
15-
from ..base import _UnstableArchMixin
1615
from ..base import MultiOutputMixin
1716
from ..utils import check_array, check_consistent_length
1817
from ..utils.extmath import svd_flip
@@ -45,8 +44,8 @@ def _get_first_singular_vectors_power_method(X, Y, mode="A", max_iter=500,
4544
# As a result, and as detailed in the Wegelin's review, CCA (i.e. mode
4645
# B) will be unstable if n_features > n_samples or n_targets >
4746
# n_samples
48-
X_pinv = pinv2(X, check_finite=False)
49-
Y_pinv = pinv2(Y, check_finite=False)
47+
X_pinv = pinv2(X, check_finite=False, cond=10*eps)
48+
Y_pinv = pinv2(Y, check_finite=False, cond=10*eps)
5049

5150
for i in range(max_iter):
5251
if mode == "B":
@@ -683,7 +682,7 @@ def __init__(self, n_components=2, *, scale=True, algorithm="nipals",
683682
max_iter=max_iter, tol=tol, copy=copy)
684683

685684

686-
class CCA(_UnstableArchMixin, _PLS):
685+
class CCA(_PLS):
687686
"""Canonical Correlation Analysis, also known as "Mode B" PLS.
688687
689688
Read more in the :ref:`User Guide <cross_decomposition>`.

sklearn/cross_decomposition/tests/test_pls.py

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -382,54 +382,57 @@ def test_copy(Est):
382382
pls.predict(X.copy(), copy=False))
383383

384384

385-
@pytest.mark.xfail
386-
@pytest.mark.parametrize('Est', (CCA, PLSCanonical, PLSRegression, PLSSVD))
387-
def test_scale_and_stability(Est):
388-
# scale=True is equivalent to scale=False on centered/scaled data
389-
# This allows to check numerical stability over platforms as well
390-
385+
def _generate_test_scale_and_stability_datasets():
386+
"""Generate dataset for test_scale_and_stability"""
387+
# dataset for non-regression 7818
391388
rng = np.random.RandomState(0)
392-
393-
d = load_linnerud()
394-
X1 = d.data
395-
Y1 = d.target
396-
# causes X[:, -1].std() to be zero
397-
X1[:, -1] = 1.0
398-
399-
# From bug #2821
400-
# Test with X2, Y2 s.t. clf.x_score[:, 1] == 0, clf.y_score[:, 1] == 0
401-
# This test robustness of algorithm when dealing with value close to 0
402-
X2 = np.array([[0., 0., 1.],
403-
[1., 0., 0.],
404-
[2., 2., 2.],
405-
[3., 5., 4.]])
406-
Y2 = np.array([[0.1, -0.2],
407-
[0.9, 1.1],
408-
[6.2, 5.9],
409-
[11.9, 12.3]])
410-
411-
# Non-regression for https://github.com/scikit-learn/scikit-learn/pull/7819
412389
n_samples = 1000
413390
n_targets = 5
414391
n_features = 10
415392
Q = rng.randn(n_targets, n_features)
416-
Y3 = rng.randn(n_samples, n_targets)
417-
X3 = np.dot(Y3, Q) + 2 * rng.randn(n_samples, n_features) + 1
418-
X3 *= 1000
419-
420-
for (X, Y) in [(X1, Y1), (X2, Y2), (X3, Y3)]:
421-
X_std = X.std(axis=0, ddof=1)
422-
X_std[X_std == 0] = 1
423-
Y_std = Y.std(axis=0, ddof=1)
424-
Y_std[Y_std == 0] = 1
425-
X_s = (X - X.mean(axis=0)) / X_std
426-
Y_s = (Y - Y.mean(axis=0)) / Y_std
427-
428-
X_score, Y_score = Est(scale=True).fit_transform(X, Y)
429-
X_s_score, Y_s_score = Est(scale=False).fit_transform(X_s, Y_s)
430-
431-
assert_array_almost_equal(X_s_score, X_score)
432-
assert_array_almost_equal(Y_s_score, Y_score)
393+
Y = rng.randn(n_samples, n_targets)
394+
X = np.dot(Y, Q) + 2 * rng.randn(n_samples, n_features) + 1
395+
X *= 1000
396+
yield X, Y
397+
398+
# Data set where one of the features is constant
399+
X, Y = load_linnerud(return_X_y=True)
400+
# causes X[:, -1].std() to be zero
401+
X[:, -1] = 1.0
402+
yield X, Y
403+
404+
X = np.array([[0., 0., 1.],
405+
[1., 0., 0.],
406+
[2., 2., 2.],
407+
[3., 5., 4.]])
408+
Y = np.array([[0.1, -0.2],
409+
[0.9, 1.1],
410+
[6.2, 5.9],
411+
[11.9, 12.3]])
412+
yield X, Y
413+
414+
# Seeds that provide a non-regression test for #18746, where CCA fails
415+
seeds = [530, 741]
416+
for seed in seeds:
417+
rng = np.random.RandomState(seed)
418+
X = rng.randn(4, 3)
419+
Y = rng.randn(4, 2)
420+
yield X, Y
421+
422+
423+
@pytest.mark.parametrize('Est', (CCA, PLSCanonical, PLSRegression, PLSSVD))
424+
@pytest.mark.parametrize('X, Y', _generate_test_scale_and_stability_datasets())
425+
def test_scale_and_stability(Est, X, Y):
426+
"""scale=True is equivalent to scale=False on centered/scaled data
427+
This allows to check numerical stability over platforms as well"""
428+
429+
X_s, Y_s, *_ = _center_scale_xy(X, Y)
430+
431+
X_score, Y_score = Est(scale=True).fit_transform(X, Y)
432+
X_s_score, Y_s_score = Est(scale=False).fit_transform(X_s, Y_s)
433+
434+
assert_allclose(X_s_score, X_score, atol=1e-4)
435+
assert_allclose(Y_s_score, Y_score, atol=1e-4)
433436

434437

435438
@pytest.mark.parametrize('Est', (PLSSVD, PLSCanonical, CCA))

0 commit comments

Comments
 (0)