|
28 | 28 | from ..base import BaseEstimator, TransformerMixin |
29 | 29 | from ..externals.six.moves import xrange |
30 | 30 | from ..utils import array2d, check_arrays, check_random_state |
31 | | -from ..utils.extmath import fast_logdet, fast_dot, randomized_svd |
| 31 | +from ..utils.extmath import fast_logdet, fast_dot, randomized_svd, squared_norm |
32 | 32 | from ..utils import ConvergenceWarning |
33 | 33 |
|
34 | 34 |
|
@@ -188,15 +188,15 @@ def fit(self, X, y=None): |
188 | 188 | def my_svd(X): |
189 | 189 | _, s, V = linalg.svd(X, full_matrices=False) |
190 | 190 | return (s[:n_components], V[:n_components], |
191 | | - np.dot(s[n_components:].flat, s[n_components:].flat)) |
| 191 | + squared_norm(s[n_components:])) |
192 | 192 | elif self.svd_method == 'randomized': |
193 | 193 | random_state = check_random_state(self.random_state) |
194 | 194 |
|
195 | 195 | def my_svd(X): |
196 | 196 | _, s, V = randomized_svd(X, n_components, |
197 | 197 | random_state=random_state, |
198 | 198 | n_iter=self.iterated_power) |
199 | | - return s, V, np.dot(X.flat, X.flat) - np.dot(s, s) |
| 199 | + return s, V, squared_norm(X) - squared_norm(s) |
200 | 200 | else: |
201 | 201 | raise ValueError('SVD method %s is not supported. Please consider' |
202 | 202 | ' the documentation' % self.svd_method) |
|
0 commit comments