
Commit fa82eee

dalmia authored and jnothman committed
[MRG+1] Better error message for GPR (scikit-learn#8386)
* FIX: better error message for GPR
* TST: added tests for error message
* FIX: correct error message and simplified tests
* DOC: updated docstring for GPR
* FIX: pass float as value for alpha
* Raise original error after modification to preserve traceback; amend test accordingly
1 parent 6579220 commit fa82eee

2 files changed: +32 / -10 lines changed

sklearn/gaussian_process/gpr.py

Lines changed: 17 additions & 9 deletions
@@ -47,13 +47,14 @@ class GaussianProcessRegressor(BaseEstimator, RegressorMixin):
 
     alpha : float or array-like, optional (default: 1e-10)
         Value added to the diagonal of the kernel matrix during fitting.
-        Larger values correspond to increased noise level in the observations
-        and reduce potential numerical issue during fitting. If an array is
-        passed, it must have the same number of entries as the data used for
-        fitting and is used as datapoint-dependent noise level. Note that this
-        is equivalent to adding a WhiteKernel with c=alpha. Allowing to specify
-        the noise level directly as a parameter is mainly for convenience and
-        for consistency with Ridge.
+        Larger values correspond to increased noise level in the observations.
+        This can also prevent a potential numerical issue during fitting, by
+        ensuring that the calculated values form a positive definite matrix.
+        If an array is passed, it must have the same number of entries as the
+        data used for fitting and is used as datapoint-dependent noise level.
+        Note that this is equivalent to adding a WhiteKernel with c=alpha.
+        Allowing to specify the noise level directly as a parameter is mainly
+        for convenience and for consistency with Ridge.
 
     optimizer : string or callable, optional (default: "fmin_l_bfgs_b")
         Can either be one of the internally supported optimizers for optimizing
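
The rewritten docstring keeps the note that a scalar alpha is equivalent to adding a WhiteKernel to the kernel. As a quick numerical check of that claim on the training kernel matrix (an illustrative sketch, not part of the commit; WhiteKernel's actual parameter is noise_level rather than the docstring's c):

import numpy as np
from sklearn.gaussian_process.kernels import RBF, WhiteKernel

X = np.random.RandomState(0).rand(5, 2)
alpha = 1e-2

# alpha added to the diagonal of the kernel matrix ...
K_alpha = RBF()(X) + alpha * np.eye(len(X))
# ... matches the kernel matrix of RBF plus a WhiteKernel with noise_level=alpha
K_white = (RBF() + WhiteKernel(noise_level=alpha))(X)

print(np.allclose(K_alpha, K_white))  # True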
@@ -242,9 +243,16 @@ def obj_func(theta, eval_gradient=True):
         # of actual query points
         K = self.kernel_(self.X_train_)
         K[np.diag_indices_from(K)] += self.alpha
-        self.L_ = cholesky(K, lower=True)  # Line 2
+        try:
+            self.L_ = cholesky(K, lower=True)  # Line 2
+        except np.linalg.LinAlgError as exc:
+            exc.args = ("The kernel, %s, is not returning a "
+                        "positive definite matrix. Try gradually "
+                        "increasing the 'alpha' parameter of your "
+                        "GaussianProcessRegressor estimator."
+                        % self.kernel_,) + exc.args
+            raise
         self.alpha_ = cho_solve((self.L_, True), self.y_train_)  # Line 3
-
         return self
 
     def predict(self, X, return_std=False, return_cov=False):
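
The error handling added above extends the caught exception's args and re-raises it bare, which (per the commit message) preserves the original traceback instead of replacing it with a new exception. A minimal standalone sketch of the same pattern, assuming a plain SciPy Cholesky call outside the estimator:

import numpy as np
from scipy.linalg import cholesky


def factorize(K):
    try:
        return cholesky(K, lower=True)
    except np.linalg.LinAlgError as exc:
        # Prepend a hint to the existing message instead of raising a new
        # exception, so the original traceback stays intact.
        exc.args = ("Matrix is not positive definite; try adding jitter "
                    "to its diagonal.",) + exc.args
        raise


try:
    factorize(np.zeros((3, 3)))  # singular, so the factorization fails
except np.linalg.LinAlgError as exc:
    print(exc.args[0])  # the prepended hint
    print(exc.args[1])  # the original LAPACK error message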

sklearn/gaussian_process/tests/test_gpr.py

Lines changed: 15 additions & 1 deletion
@@ -10,10 +10,11 @@
 from sklearn.gaussian_process import GaussianProcessRegressor
 from sklearn.gaussian_process.kernels \
     import RBF, ConstantKernel as C, WhiteKernel
+from sklearn.gaussian_process.kernels import DotProduct
 
 from sklearn.utils.testing \
     import (assert_true, assert_greater, assert_array_less,
-            assert_almost_equal, assert_equal)
+            assert_almost_equal, assert_equal, assert_raise_message)
 
 
 def f(x):
@@ -290,6 +291,19 @@ def optimizer(obj_func, initial_theta, bounds):
                        gpr.log_marginal_likelihood(gpr.kernel.theta))
 
 
+def test_gpr_correct_error_message():
+    X = np.arange(12).reshape(6, -1)
+    y = np.ones(6)
+    kernel = DotProduct()
+    gpr = GaussianProcessRegressor(kernel=kernel, alpha=0.0)
+    assert_raise_message(np.linalg.LinAlgError,
+                         "The kernel, %s, is not returning a "
+                         "positive definite matrix. Try gradually increasing "
+                         "the 'alpha' parameter of your "
+                         "GaussianProcessRegressor estimator."
+                         % kernel, gpr.fit, X, y)
+
+
 def test_duplicate_input():
     # Test GPR can handle two different output-values for the same input.
     for kernel in kernels:
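
From a user's perspective, the new test also doubles as a recipe for the situation the message targets: a rank-deficient kernel matrix with alpha=0.0 makes fit fail, and following the message's advice to increase alpha makes it succeed. A hedged sketch (alpha=1e-2 below is an arbitrary illustrative value, not taken from the commit):

import numpy as np
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct

X = np.arange(12).reshape(6, -1)  # six collinear points, so DotProduct gives a singular kernel matrix
y = np.ones(6)

try:
    GaussianProcessRegressor(kernel=DotProduct(), alpha=0.0).fit(X, y)
except np.linalg.LinAlgError as exc:
    print(exc.args[0])  # the augmented message pointing at the 'alpha' parameter

# Following the message's advice and increasing alpha restores a valid fit.
gpr = GaussianProcessRegressor(kernel=DotProduct(), alpha=1e-2).fit(X, y)
print(gpr.predict(X[:2]))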
