Skip to content

Commit f17ebac

Browse files
Vincent SchutVincent Schut
authored andcommitted
added a converged_ attribute to GMM to indicate whether fit() returned because of convergence or because max_iter was reached.
1 parent 8ef9ce8 commit f17ebac

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

scikits/learn/mixture.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ def __init__(self, n_states=1, cvtype='diag'):
235235

236236
self.weights = np.ones(self._n_states) / self._n_states
237237

238+
self.converged_ = False
239+
238240
# Read-only properties.
239241
@property
240242
def cvtype(self):
@@ -415,7 +417,7 @@ def rvs(self, n_samples=1):
415417
# occurrences of current component in obs
416418
comp_in_obs = (comp==comps)
417419
# number of those occurrences
418-
num_comp_in_obs = comp_in_obs.sum()
420+
num_comp_in_obs = comp_in_obs.sum()
419421
if num_comp_in_obs > 0:
420422
if self._cvtype == 'tied':
421423
cv = self._covars
@@ -496,6 +498,7 @@ def fit(self, X, n_iter=10, min_covar=1e-3, thresh=1e-2, params='wmc',
496498

497499
# Check for convergence.
498500
if i > 0 and abs(logprob[-1] - logprob[-2]) < thresh:
501+
self.converged_ = True
499502
break
500503

501504
# Maximization step

0 commit comments

Comments
 (0)