Skip to content

Commit 0872e9a

Browse files
0xs1djeremiedbb
andauthored
TST use global_random_seed in sklearn/covariance/tests/test_graphical… (scikit-learn#31692)
Co-authored-by: Jérémie du Boisberranger <[email protected]>
1 parent e971134 commit 0872e9a

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sklearn/covariance/tests/test_graphical_lasso.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def test_graphical_lassos(random_state=1):
7474
assert_array_almost_equal(precs[0], precs[1])
7575

7676

77-
def test_graphical_lasso_when_alpha_equals_0():
77+
def test_graphical_lasso_when_alpha_equals_0(global_random_seed):
7878
"""Test graphical_lasso's early return condition when alpha=0."""
79-
X = np.random.randn(100, 10)
79+
X = np.random.RandomState(global_random_seed).randn(100, 10)
8080
emp_cov = empirical_covariance(X, assume_centered=True)
8181

8282
model = GraphicalLasso(alpha=0, covariance="precomputed").fit(emp_cov)
@@ -170,11 +170,11 @@ def test_graphical_lasso_iris_singular():
170170
assert_array_almost_equal(icov, icov_R, decimal=5)
171171

172172

173-
def test_graphical_lasso_cv(random_state=1):
173+
def test_graphical_lasso_cv(global_random_seed):
174174
# Sample data from a sparse multivariate normal
175175
dim = 5
176176
n_samples = 6
177-
random_state = check_random_state(random_state)
177+
random_state = np.random.RandomState(global_random_seed)
178178
prec = make_sparse_spd_matrix(dim, alpha=0.96, random_state=random_state)
179179
cov = linalg.inv(prec)
180180
X = random_state.multivariate_normal(np.zeros(dim), cov, size=n_samples)
@@ -237,7 +237,7 @@ def test_graphical_lasso_cv_alphas_invalid_array(alphas, err_type, err_msg):
237237
GraphicalLassoCV(alphas=alphas, tol=1e-1, n_jobs=1).fit(X)
238238

239239

240-
def test_graphical_lasso_cv_scores():
240+
def test_graphical_lasso_cv_scores(global_random_seed):
241241
splits = 4
242242
n_alphas = 5
243243
n_refinements = 3
@@ -249,7 +249,7 @@ def test_graphical_lasso_cv_scores():
249249
[0.0, 0.0, 0.1, 0.7],
250250
]
251251
)
252-
rng = np.random.RandomState(0)
252+
rng = np.random.RandomState(global_random_seed)
253253
X = rng.multivariate_normal(mean=[0, 0, 0, 0], cov=true_cov, size=200)
254254
cov = GraphicalLassoCV(cv=splits, alphas=n_alphas, n_refinements=n_refinements).fit(
255255
X

0 commit comments

Comments
 (0)