Commit 7700b5a

jotasi authored and glemaitre committed

[MRG+1] FIX: Use ConvergenceWarning whenever it applies (scikit-learn#10306)

1 parent 40e1536, commit 7700b5a

17 files changed (+99, -17 lines)

doc/whats_new/v0.20.rst (11 additions & 0 deletions)

@@ -383,6 +383,17 @@ Preprocessing
   :issue:`10558` by :user:`Baze Petrushev <petrushev>` and
   :user:`Hanmin Qin <qinhanmin2014>`.
 
+Misc
+
+- Changed warning type from UserWarning to ConvergenceWarning for failing
+  convergence in :func:`linear_model.logistic_regression_path`,
+  :class:`linear_model.RANSACRegressor`, :func:`linear_model.ridge_regression`,
+  :class:`gaussian_process.GaussianProcessRegressor`,
+  :class:`gaussian_process.GaussianProcessClassifier`,
+  :func:`decomposition.fastica`, :class:`cross_decomposition.PLSCanonical`,
+  :class:`cluster.AffinityPropagation`, and :class:`cluster.Birch`.
+  :issue:`#10306` by :user:`Jonathan Siebert <jotasi>`.
+
 Changes to estimator checks
 ---------------------------
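The practical effect of this reclassification is that callers can target convergence noise specifically with the standard warnings machinery. A minimal sketch of what the dedicated category enables, using only the public sklearn.exceptions import:

import warnings
from sklearn.exceptions import ConvergenceWarning

# Silence only convergence chatter, leaving other warnings visible:
warnings.filterwarnings("ignore", category=ConvergenceWarning)

# Or, while debugging, escalate convergence failures to hard errors:
warnings.filterwarnings("error", category=ConvergenceWarning)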

sklearn/cluster/affinity_propagation_.py (1 addition & 1 deletion)

@@ -390,5 +390,5 @@ def predict(self, X):
         else:
             warnings.warn("This model does not have any cluster centers "
                           "because affinity propagation did not converge. "
-                          "Labeling every sample as '-1'.")
+                          "Labeling every sample as '-1'.", ConvergenceWarning)
             return np.array([-1] * X.shape[0])
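With this change the non-convergence path in predict can be observed programmatically. A small sketch of triggering it, reusing the preference/max_iter values from the test below (behaviour assumed to match the post-commit code):

import warnings
import numpy as np
from sklearn.cluster import AffinityPropagation
from sklearn.exceptions import ConvergenceWarning

X = np.array([[0, 0], [1, 1], [-2, -2]])
with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    # A single iteration should not be enough to converge.
    af = AffinityPropagation(preference=-10, max_iter=1).fit(X)
    labels = af.predict(np.array([[2, 2]]))  # expected: array([-1])
assert any(issubclass(r.category, ConvergenceWarning) for r in records)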

sklearn/cluster/birch.py (2 additions & 2 deletions)

@@ -15,7 +15,7 @@
 from ..utils import check_array
 from ..utils.extmath import row_norms, safe_sparse_dot
 from ..utils.validation import check_is_fitted
-from ..exceptions import NotFittedError
+from ..exceptions import NotFittedError, ConvergenceWarning
 from .hierarchical import AgglomerativeClustering
 
 
@@ -626,7 +626,7 @@ def _global_clustering(self, X=None):
            warnings.warn(
                "Number of subclusters found (%d) by Birch is less "
                "than (%d). Decrease the threshold."
-               % (len(centroids), self.n_clusters))
+               % (len(centroids), self.n_clusters), ConvergenceWarning)
        else:
            # The global clustering step that clusters the subclusters of
            # the leaves. It assumes the centroids of the subclusters as

sklearn/cluster/tests/test_affinity_propagation.py (5 additions & 3 deletions)

@@ -133,12 +133,14 @@ def test_affinity_propagation_predict_non_convergence():
     X = np.array([[0, 0], [1, 1], [-2, -2]])
 
     # Force non-convergence by allowing only a single iteration
-    af = AffinityPropagation(preference=-10, max_iter=1).fit(X)
+    af = assert_warns(ConvergenceWarning,
+                      AffinityPropagation(preference=-10, max_iter=1).fit, X)
 
     # At prediction time, consider new samples as noise since there are no
     # clusters
-    assert_array_equal(np.array([-1, -1, -1]),
-                       af.predict(np.array([[2, 2], [3, 3], [4, 4]])))
+    to_predict = np.array([[2, 2], [3, 3], [4, 4]])
+    y = assert_warns(ConvergenceWarning, af.predict, to_predict)
+    assert_array_equal(np.array([-1, -1, -1]), y)
 
 
 def test_equal_similarities_and_preferences():
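The rewritten test leans on the fact that sklearn.utils.testing.assert_warns returns whatever the wrapped callable returns, which is why `af` and `y` can be captured directly. A simplified sketch of that helper (the real one performs extra bookkeeping):

import warnings

def assert_warns(warning_class, func, *args, **kw):
    # Run func while recording warnings, check that at least one
    # warning of the expected category was emitted, then hand back
    # func's return value so the caller can keep using it.
    with warnings.catch_warnings(record=True) as records:
        warnings.simplefilter("always")
        result = func(*args, **kw)
    if not any(issubclass(r.category, warning_class) for r in records):
        raise AssertionError("%s not triggered" % warning_class.__name__)
    return result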

sklearn/cluster/tests/test_birch.py (2 additions & 1 deletion)

@@ -9,6 +9,7 @@
 from sklearn.cluster.birch import Birch
 from sklearn.cluster.hierarchical import AgglomerativeClustering
 from sklearn.datasets import make_blobs
+from sklearn.exceptions import ConvergenceWarning
 from sklearn.linear_model import ElasticNet
 from sklearn.metrics import pairwise_distances_argmin, v_measure_score
 
@@ -93,7 +94,7 @@ def test_n_clusters():
 
     # Test that a small number of clusters raises a warning.
     brc4 = Birch(threshold=10000.)
-    assert_warns(UserWarning, brc4.fit, X)
+    assert_warns(ConvergenceWarning, brc4.fit, X)
 
 
 def test_sparse_X():
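The test can switch its expectation from UserWarning to the more specific ConvergenceWarning without breaking older callers because, in sklearn.exceptions, ConvergenceWarning is a subclass of UserWarning, so filters written against the old category still match:

from sklearn.exceptions import ConvergenceWarning

# ConvergenceWarning inherits from UserWarning, so pre-existing
# UserWarning filters continue to catch these warnings.
assert issubclass(ConvergenceWarning, UserWarning)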

sklearn/cross_decomposition/pls_.py (3 additions & 1 deletion)

@@ -16,6 +16,7 @@
 from ..utils import check_array, check_consistent_length
 from ..utils.extmath import svd_flip
 from ..utils.validation import check_is_fitted, FLOAT_DTYPES
+from ..exceptions import ConvergenceWarning
 from ..externals import six
 
 __all__ = ['PLSCanonical', 'PLSRegression', 'PLSSVD']

@@ -74,7 +75,8 @@ def _nipals_twoblocks_inner_loop(X, Y, mode="A", max_iter=500, tol=1e-06,
        if np.dot(x_weights_diff.T, x_weights_diff) < tol or Y.shape[1] == 1:
            break
        if ite == max_iter:
-           warnings.warn('Maximum number of iterations reached')
+           warnings.warn('Maximum number of iterations reached',
+                         ConvergenceWarning)
            break
        x_weights_old = x_weights
        ite += 1

sklearn/cross_decomposition/tests/test_pls.py (11 additions & 1 deletion)

@@ -3,11 +3,12 @@
 
 from sklearn.utils.testing import (assert_equal, assert_array_almost_equal,
                                    assert_array_equal, assert_true,
-                                   assert_raise_message)
+                                   assert_raise_message, assert_warns)
 from sklearn.datasets import load_linnerud
 from sklearn.cross_decomposition import pls_, CCA
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils import check_random_state
+from sklearn.exceptions import ConvergenceWarning
 
 
 def test_pls():

@@ -260,6 +261,15 @@ def check_ortho(M, err_msg):
     check_ortho(pls_ca.y_scores_, "y scores are not orthogonal")
 
 
+def test_convergence_fail():
+    d = load_linnerud()
+    X = d.data
+    Y = d.target
+    pls_bynipals = pls_.PLSCanonical(n_components=X.shape[1],
+                                     max_iter=2, tol=1e-10)
+    assert_warns(ConvergenceWarning, pls_bynipals.fit, X, Y)
+
+
 def test_PLSSVD():
     # Let's check the PLSSVD doesn't return all possible component but just
     # the specified number

sklearn/decomposition/fastica_.py (3 additions & 1 deletion)

@@ -15,6 +15,7 @@
 from scipy import linalg
 
 from ..base import BaseEstimator, TransformerMixin
+from ..exceptions import ConvergenceWarning
 from ..externals import six
 from ..externals.six import moves
 from ..externals.six import string_types

@@ -116,7 +117,8 @@ def _ica_par(X, tol, g, fun_args, max_iter, w_init):
            break
    else:
        warnings.warn('FastICA did not converge. Consider increasing '
-                     'tolerance or the maximum number of iterations.')
+                     'tolerance or the maximum number of iterations.',
+                     ConvergenceWarning)
 
    return W, ii + 1
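The warning in _ica_par sits in the `else` clause of a for-loop, which Python executes only when the loop finishes without hitting `break`, i.e. when no iteration reached the tolerance. A hypothetical skeleton of that control flow (the names are invented for illustration):

import warnings
from sklearn.exceptions import ConvergenceWarning

def run_until_converged(step, max_iter, tol):
    # `step` is a hypothetical callable returning the current update size.
    for _ in range(max_iter):
        if step() < tol:
            break  # converged: the else clause below is skipped
    else:
        # Reached only when the loop exhausted max_iter without a break.
        warnings.warn('Did not converge. Consider increasing tolerance '
                      'or the maximum number of iterations.',
                      ConvergenceWarning)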

sklearn/decomposition/tests/test_fastica.py (26 additions & 0 deletions)

@@ -18,6 +18,7 @@
 from sklearn.decomposition import FastICA, fastica, PCA
 from sklearn.decomposition.fastica_ import _gs_decorrelation
 from sklearn.externals.six import moves
+from sklearn.exceptions import ConvergenceWarning
 
 
 def center_and_norm(x, axis=-1):

@@ -141,6 +142,31 @@ def test_fastica_nowhiten():
     assert_true(hasattr(ica, 'mixing_'))
 
 
+def test_fastica_convergence_fail():
+    # Test the FastICA algorithm on very simple data
+    # (see test_non_square_fastica).
+    # Ensure a ConvergenceWarning is raised if the tolerance is sufficiently low.
+    rng = np.random.RandomState(0)
+
+    n_samples = 1000
+    # Generate two sources:
+    t = np.linspace(0, 100, n_samples)
+    s1 = np.sin(t)
+    s2 = np.ceil(np.sin(np.pi * t))
+    s = np.c_[s1, s2].T
+    center_and_norm(s)
+    s1, s2 = s
+
+    # Mixing matrix
+    mixing = rng.randn(6, 2)
+    m = np.dot(mixing, s)
+
+    # Do fastICA with tolerance 0. to ensure failing convergence
+    ica = FastICA(algorithm="parallel", n_components=2, random_state=rng,
+                  max_iter=2, tol=0.)
+    assert_warns(ConvergenceWarning, ica.fit, m.T)
+
+
 def test_non_square_fastica(add_noise=False):
     # Test the FastICA algorithm on very simple data.
     rng = np.random.RandomState(0)

sklearn/gaussian_process/gpc.py (3 additions & 1 deletion)

@@ -19,6 +19,7 @@
 from sklearn.utils import check_random_state
 from sklearn.preprocessing import LabelEncoder
 from sklearn.multiclass import OneVsRestClassifier, OneVsOneClassifier
+from sklearn.exceptions import ConvergenceWarning
 
 
 # Values required for approximating the logistic sigmoid by

@@ -428,7 +429,8 @@ def _constrained_optimization(self, obj_func, initial_theta, bounds):
                fmin_l_bfgs_b(obj_func, initial_theta, bounds=bounds)
            if convergence_dict["warnflag"] != 0:
                warnings.warn("fmin_l_bfgs_b terminated abnormally with the "
-                             " state: %s" % convergence_dict)
+                             " state: %s" % convergence_dict,
+                             ConvergenceWarning)
        elif callable(self.optimizer):
            theta_opt, func_min = \
                self.optimizer(obj_func, initial_theta, bounds=bounds)
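The `warnflag` checked here comes from scipy.optimize.fmin_l_bfgs_b, whose third return value is an info dict: 0 means convergence, 1 means the iteration or function-evaluation budget was exhausted, 2 means the routine stopped for another reason. A small sketch forcing the abnormal path, assuming maxiter=0 exhausts the budget immediately:

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

def f(x):
    # Return objective value and gradient, as fmin_l_bfgs_b expects
    # when approx_grad is left at its default of 0.
    return float((x ** 2).sum()), 2 * x

x_opt, f_min, info = fmin_l_bfgs_b(f, np.array([10.0]), maxiter=0)
print(info["warnflag"])  # expected: 1 (iteration limit reached)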
