Skip to content

Commit 60e7e15

Browse files
qinhanmin2014ogrisel
authored andcommitted
[MRG+1] BUG Avoid unexpected error in PCA when n_components='mle' (scikit-learn#9886)
* n_components mle * update doc * improve * update what's new * update what's new
1 parent a0dfa30 commit 60e7e15

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

doc/whats_new/v0.20.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ Decomposition, manifold learning and clustering
107107
- Fixed a bug in :func:`datasets.fetch_kddcup99`, where data were not properly
108108
shuffled. :issue:`9731` by `Nicolas Goix`_.
109109

110+
- Fixed a bug in :class:`decomposition.PCA` where users will get unexpected error
111+
with large datasets when ``n_components='mle'`` on Python 3 versions.
112+
:issue:`9886` by :user:`Hanmin Qin <qinhanmin2014>`.
113+
110114
Metrics
111115

112116
- Fixed a bug due to floating point error in :func:`metrics.roc_auc_score` with

sklearn/decomposition/pca.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,18 @@ class PCA(_BasePCA):
130130
131131
n_components == min(n_samples, n_features)
132132
133-
if n_components == 'mle' and svd_solver == 'full', Minka\'s MLE is used
134-
to guess the dimension
135-
if ``0 < n_components < 1`` and svd_solver == 'full', select the number
136-
of components such that the amount of variance that needs to be
133+
If ``n_components == 'mle'`` and ``svd_solver == 'full'``, Minka\'s
134+
MLE is used to guess the dimension. Use of ``n_components == 'mle'``
135+
will interpret ``svd_solver == 'auto'`` as ``svd_solver == 'full'``.
136+
137+
If ``0 < n_components < 1`` and ``svd_solver == 'full'``, select the
138+
number of components such that the amount of variance that needs to be
137139
explained is greater than the percentage specified by n_components.
138-
If svd_solver == 'arpack', the number of components must be strictly
139-
less than the minimum of n_features and n_samples.
140-
Hence, the None case results in:
140+
141+
If ``svd_solver == 'arpack'``, the number of components must be
142+
strictly less than the minimum of n_features and n_samples.
143+
144+
Hence, the None case results in::
141145
142146
n_components == min(n_samples, n_features) - 1
143147
@@ -386,8 +390,8 @@ def _fit(self, X):
386390
# Handle svd_solver
387391
svd_solver = self.svd_solver
388392
if svd_solver == 'auto':
389-
# Small problem, just call full PCA
390-
if max(X.shape) <= 500:
393+
# Small problem or n_components == 'mle', just call full PCA
394+
if max(X.shape) <= 500 or n_components == 'mle':
391395
svd_solver = 'full'
392396
elif n_components >= 1 and n_components < .8 * min(X.shape):
393397
svd_solver = 'randomized'

sklearn/decomposition/tests/test_pca.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sklearn.utils.testing import assert_true
88
from sklearn.utils.testing import assert_equal
99
from sklearn.utils.testing import assert_greater
10+
from sklearn.utils.testing import assert_raise_message
1011
from sklearn.utils.testing import assert_raises
1112
from sklearn.utils.testing import assert_raises_regex
1213
from sklearn.utils.testing import assert_no_warnings
@@ -453,6 +454,26 @@ def test_randomized_pca_inverse():
453454
assert_less(relative_max_delta, 1e-5)
454455

455456

457+
def test_n_components_mle():
458+
# Ensure that n_components == 'mle' doesn't raise error for auto/full
459+
# svd_solver and raises error for arpack/randomized svd_solver
460+
rng = np.random.RandomState(0)
461+
n_samples = 600
462+
n_features = 10
463+
X = rng.randn(n_samples, n_features)
464+
n_components_dict = {}
465+
for solver in solver_list:
466+
pca = PCA(n_components='mle', svd_solver=solver)
467+
if solver in ['auto', 'full']:
468+
pca.fit(X)
469+
n_components_dict[solver] = pca.n_components_
470+
else: # arpack/randomized solver
471+
error_message = ("n_components='mle' cannot be a string with "
472+
"svd_solver='{}'".format(solver))
473+
assert_raise_message(ValueError, error_message, pca.fit, X)
474+
assert_equal(n_components_dict['auto'], n_components_dict['full'])
475+
476+
456477
def test_pca_dim():
457478
# Check automated dimensionality setting
458479
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)