Skip to content

Commit 15d6f98

Browse files
committed
Merge pull request scikit-learn#4686 from bnaul/graph_lasso_tests
[MRG+1] Add enet_tol parameter to GraphLasso class/methods
2 parents ca07b2a + 61efaea commit 15d6f98

File tree

1 file changed

+34
-12
lines changed

1 file changed

+34
-12
lines changed

sklearn/covariance/graph_lasso_.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def graph_lasso(emp_cov, alpha, cov_init=None, mode='cd', tol=1e-4,
107107
108108
enet_tol : positive float, optional
109109
The tolerance for the elastic net solver used to calculate the descent
110-
direction. Only used for mode='cd'.
110+
direction. This parameter controls the accuracy of the search direction
111+
for a given column update, not of the overall parameter estimate. Only
112+
used for mode='cd'.
111113
112114
max_iter : integer, optional
113115
The maximum number of iterations.
@@ -280,6 +282,12 @@ class GraphLasso(EmpiricalCovariance):
280282
The tolerance to declare convergence: if the dual gap goes below
281283
this value, iterations are stopped.
282284
285+
enet_tol : positive float, optional
286+
The tolerance for the elastic net solver used to calculate the descent
287+
direction. This parameter controls the accuracy of the search direction
288+
for a given column update, not of the overall parameter estimate. Only
289+
used for mode='cd'.
290+
283291
max_iter : integer, default 100
284292
The maximum number of iterations.
285293
@@ -309,11 +317,12 @@ class GraphLasso(EmpiricalCovariance):
309317
graph_lasso, GraphLassoCV
310318
"""
311319

312-
def __init__(self, alpha=.01, mode='cd', tol=1e-4, max_iter=100,
313-
verbose=False, assume_centered=False):
320+
def __init__(self, alpha=.01, mode='cd', tol=1e-4, enet_tol=1e-4,
321+
max_iter=100, verbose=False, assume_centered=False):
314322
self.alpha = alpha
315323
self.mode = mode
316324
self.tol = tol
325+
self.enet_tol = enet_tol
317326
self.max_iter = max_iter
318327
self.verbose = verbose
319328
self.assume_centered = assume_centered
@@ -330,14 +339,14 @@ def fit(self, X, y=None):
330339
X, assume_centered=self.assume_centered)
331340
self.covariance_, self.precision_, self.n_iter_ = graph_lasso(
332341
emp_cov, alpha=self.alpha, mode=self.mode, tol=self.tol,
333-
max_iter=self.max_iter, verbose=self.verbose,
334-
return_n_iter=True)
342+
enet_tol=self.enet_tol, max_iter=self.max_iter,
343+
verbose=self.verbose, return_n_iter=True)
335344
return self
336345

337346

338347
# Cross-validation with GraphLasso
339348
def graph_lasso_path(X, alphas, cov_init=None, X_test=None, mode='cd',
340-
tol=1e-4, max_iter=100, verbose=False):
349+
tol=1e-4, enet_tol=1e-4, max_iter=100, verbose=False):
341350
"""l1-penalized covariance estimator along a path of decreasing alphas
342351
343352
Parameters
@@ -360,6 +369,12 @@ def graph_lasso_path(X, alphas, cov_init=None, X_test=None, mode='cd',
360369
The tolerance to declare convergence: if the dual gap goes below
361370
this value, iterations are stopped.
362371
372+
enet_tol : positive float, optional
373+
The tolerance for the elastic net solver used to calculate the descent
374+
direction. This parameter controls the accuracy of the search direction
375+
for a given column update, not of the overall parameter estimate. Only
376+
used for mode='cd'.
377+
363378
max_iter : integer, optional
364379
The maximum number of iterations.
365380
@@ -396,7 +411,7 @@ def graph_lasso_path(X, alphas, cov_init=None, X_test=None, mode='cd',
396411
# Capture the errors, and move on
397412
covariance_, precision_ = graph_lasso(
398413
emp_cov, alpha=alpha, cov_init=covariance_, mode=mode, tol=tol,
399-
max_iter=max_iter, verbose=inner_verbose)
414+
enet_tol=enet_tol, max_iter=max_iter, verbose=inner_verbose)
400415
covariances_.append(covariance_)
401416
precisions_.append(precision_)
402417
if X_test is not None:
@@ -445,6 +460,12 @@ class GraphLassoCV(GraphLasso):
445460
The tolerance to declare convergence: if the dual gap goes below
446461
this value, iterations are stopped.
447462
463+
enet_tol : positive float, optional
464+
The tolerance for the elastic net solver used to calculate the descent
465+
direction. This parameter controls the accuracy of the search direction
466+
for a given column update, not of the overall parameter estimate. Only
467+
used for mode='cd'.
468+
448469
max_iter: integer, optional
449470
Maximum number of iterations.
450471
@@ -506,12 +527,13 @@ class GraphLassoCV(GraphLasso):
506527
"""
507528

508529
def __init__(self, alphas=4, n_refinements=4, cv=None, tol=1e-4,
509-
max_iter=100, mode='cd', n_jobs=1, verbose=False,
510-
assume_centered=False):
530+
enet_tol=1e-4, max_iter=100, mode='cd', n_jobs=1,
531+
verbose=False, assume_centered=False):
511532
self.alphas = alphas
512533
self.n_refinements = n_refinements
513534
self.mode = mode
514535
self.tol = tol
536+
self.enet_tol = enet_tol
515537
self.max_iter = max_iter
516538
self.verbose = verbose
517539
self.cv = cv
@@ -572,7 +594,7 @@ def fit(self, X, y=None):
572594
delayed(graph_lasso_path)(
573595
X[train], alphas=alphas,
574596
X_test=X[test], mode=self.mode,
575-
tol=self.tol,
597+
tol=self.tol, enet_tol=self.enet_tol,
576598
max_iter=int(.1 * self.max_iter),
577599
verbose=inner_verbose)
578600
for train, test in cv)
@@ -644,6 +666,6 @@ def fit(self, X, y=None):
644666
# Finally fit the model with the selected alpha
645667
self.covariance_, self.precision_, self.n_iter_ = graph_lasso(
646668
emp_cov, alpha=best_alpha, mode=self.mode, tol=self.tol,
647-
max_iter=self.max_iter, verbose=inner_verbose,
648-
return_n_iter=True)
669+
enet_tol=self.enet_tol, max_iter=self.max_iter,
670+
verbose=inner_verbose, return_n_iter=True)
649671
return self

0 commit comments

Comments
 (0)