Skip to content

Commit 1be9335

Browse files
committed
ENH: add get_precision method using the matrix inversion lemma to FactorAnalysis and use the precision in score
1 parent daca399 commit 1be9335

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

sklearn/decomposition/factor_analysis.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,33 @@ def get_covariance(self):
267267
cov.flat[::len(cov) + 1] += self.noise_variance_ # modify diag inplace
268268
return cov
269269

270+
def get_precision(self):
    """Compute data precision matrix with the FactorAnalysis model.

    The precision is the inverse of the covariance returned by
    ``get_covariance``.  For the generic low-rank case it is obtained
    via the matrix inversion lemma (Woodbury identity) rather than a
    dense inverse of the full covariance matrix.

    Returns
    -------
    precision : array, shape=(n_features, n_features)
        Estimated precision of data.
    """
    n_features = self.components_.shape[1]

    # Degenerate models first: with no factors the covariance is purely
    # diagonal; with a full-rank factor loading there is nothing to gain
    # from the lemma, so invert the covariance directly.
    if self.n_components == 0:
        return np.diag(1. / self.noise_variance_)
    if self.n_components == n_features:
        return linalg.inv(self.get_covariance())

    # Woodbury identity, with W = components_ and Psi = noise_variance_:
    # (W^T W + Psi)^-1
    #     = Psi^-1 - Psi^-1 W^T (I + W Psi^-1 W^T)^-1 W Psi^-1
    W = self.components_
    small = np.dot(W / self.noise_variance_, W.T)
    small.flat[::len(small) + 1] += 1.  # add I on the diagonal in place
    precision = np.dot(W.T, np.dot(linalg.inv(small), W))
    precision /= self.noise_variance_[:, np.newaxis]
    precision /= -self.noise_variance_[np.newaxis, :]
    precision.flat[::len(precision) + 1] += 1. / self.noise_variance_
    return precision
296+
270297
def score(self, X, y=None):
    """Compute score of X under FactorAnalysis model.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        The data.

    y : ignored
        Present for API consistency.

    Returns
    -------
    ll : array, shape (n_samples,)
        Log-likelihood of each row of X under the current model.
    """
    Xr = X - self.mean_
    precision = self.get_precision()
    n_features = X.shape[1]
    # Stored for backward compatibility with callers reading this
    # attribute; NOTE(review): consider deprecating the stored value.
    self.precision_ = precision
    # Gaussian log-likelihood per sample:
    # -0.5 * (x^T P x - log|P| + n_features * log(2*pi))
    log_like = -.5 * (Xr * (np.dot(Xr, precision))).sum(axis=1)
    log_like -= .5 * (-fast_logdet(precision) + n_features * log(2. * np.pi))
    return log_like

sklearn/decomposition/tests/test_factor_analysis.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from sklearn.utils.testing import assert_less
1212
from sklearn.utils.testing import assert_raises
1313
from sklearn.utils.testing import assert_almost_equal
14+
from sklearn.utils.testing import assert_array_almost_equal
1415
from sklearn.utils import ConvergenceWarning
15-
1616
from sklearn.decomposition import FactorAnalysis
1717

1818

@@ -62,6 +62,7 @@ def test_factor_analysis():
6262
noise_variance_init=np.ones(n_features))
6363
assert_raises(ValueError, fa.fit, X[:, :2])
6464

65+
6566
f = lambda x, y: np.abs(getattr(x, y)) # sign will not be equal
6667
fa1, fa2 = fas
6768
for attr in ['loglike_', 'components_', 'noise_variance_']:
@@ -76,3 +77,26 @@ def test_factor_analysis():
7677
warnings.simplefilter('always', DeprecationWarning)
7778
FactorAnalysis(verbose=1)
7879
assert_true(w[-1].category == DeprecationWarning)
80+
81+
fa2 = FactorAnalysis(n_components=n_components,
82+
noise_variance_init=np.ones(n_features))
83+
assert_raises(ValueError, fa2.fit, X[:, :2])
84+
85+
# Test get_covariance and get_precision with n_components < n_features
86+
cov = fa.get_covariance()
87+
precision = fa.get_precision()
88+
assert_array_almost_equal(np.dot(cov, precision), np.eye(X.shape[1]), 12)
89+
90+
# Test get_covariance and get_precision with n_components == n_features
91+
fa.n_components = n_features
92+
fa.fit(X)
93+
cov = fa.get_covariance()
94+
precision = fa.get_precision()
95+
assert_array_almost_equal(np.dot(cov, precision), np.eye(X.shape[1]), 12)
96+
97+
# Test get_covariance and get_precision with n_components == 0
98+
fa.n_components = 0
99+
fa.fit(X)
100+
cov = fa.get_covariance()
101+
precision = fa.get_precision()
102+
assert_array_almost_equal(np.dot(cov, precision), np.eye(X.shape[1]), 12)

0 commit comments

Comments
 (0)