1111from sklearn .utils .testing import assert_less
1212from sklearn .utils .testing import assert_raises
1313from sklearn .utils .testing import assert_almost_equal
14+ from sklearn .utils .testing import assert_array_almost_equal
1415from sklearn .utils import ConvergenceWarning
15-
1616from sklearn .decomposition import FactorAnalysis
1717
1818
@@ -62,6 +62,7 @@ def test_factor_analysis():
6262 noise_variance_init = np .ones (n_features ))
6363 assert_raises (ValueError , fa .fit , X [:, :2 ])
6464
65+
6566 f = lambda x , y : np .abs (getattr (x , y )) # sign will not be equal
6667 fa1 , fa2 = fas
6768 for attr in ['loglike_' , 'components_' , 'noise_variance_' ]:
@@ -76,3 +77,26 @@ def test_factor_analysis():
7677 warnings .simplefilter ('always' , DeprecationWarning )
7778 FactorAnalysis (verbose = 1 )
7879 assert_true (w [- 1 ].category == DeprecationWarning )
80+
81+ fa2 = FactorAnalysis (n_components = n_components ,
82+ noise_variance_init = np .ones (n_features ))
83+ assert_raises (ValueError , fa2 .fit , X [:, :2 ])
84+
85+ # Test get_covariance and get_precision with n_components < n_features
86+ cov = fa .get_covariance ()
87+ precision = fa .get_precision ()
88+ assert_array_almost_equal (np .dot (cov , precision ), np .eye (X .shape [1 ]), 12 )
89+
90+ # Test get_covariance and get_precision with n_components == n_features
91+ fa .n_components = n_features
92+ fa .fit (X )
93+ cov = fa .get_covariance ()
94+ precision = fa .get_precision ()
95+ assert_array_almost_equal (np .dot (cov , precision ), np .eye (X .shape [1 ]), 12 )
96+
97+ # Test get_covariance and get_precision with n_components == 0
98+ fa .n_components = 0
99+ fa .fit (X )
100+ cov = fa .get_covariance ()
101+ precision = fa .get_precision ()
102+ assert_array_almost_equal (np .dot (cov , precision ), np .eye (X .shape [1 ]), 12 )
0 commit comments