Skip to content

Commit c7d6f9f

Browse files
committed
Test row_norms for float32 data
1 parent 65e91f7 commit c7d6f9f

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

sklearn/utils/tests/test_extmath.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,23 @@ def test_norm_squared_norm():
148148

149149
def test_row_norms():
150150
X = np.random.RandomState(42).randn(100, 100)
151-
sq_norm = (X ** 2).sum(axis=1)
152-
153-
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 5)
154-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X))
155-
156-
Xcsr = sparse.csr_matrix(X, dtype=np.float32)
157-
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True), 5)
158-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr))
151+
for dtype in (np.float32, np.float64):
152+
if dtype is np.float32:
153+
precision = 4
154+
else:
155+
precision = 5
156+
157+
X = X.astype(dtype)
158+
sq_norm = (X ** 2).sum(axis=1)
159+
160+
assert_array_almost_equal(sq_norm, row_norms(X, squared=True),
161+
precision)
162+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X), precision)
163+
164+
Xcsr = sparse.csr_matrix(X, dtype=dtype)
165+
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True),
166+
precision)
167+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr), precision)
159168

160169

161170
def test_randomized_svd_low_rank_with_noise():

0 commit comments

Comments
 (0)