Commit 48c2c15

Improve the GMM PDF example.
1 parent 8a0b317 commit 48c2c15

File tree

1 file changed: 22 additions, 10 deletions

examples/mixture/plot_gmm_pdf.py

Lines changed: 22 additions & 10 deletions
@@ -9,30 +9,42 @@
 """
 
 import numpy as np
-import pylab as pl
+import matplotlib.pyplot as plt
+from matplotlib.colors import LogNorm
 from sklearn import mixture
 
 n_samples = 300
 
 # generate random sample, two components
 np.random.seed(0)
+
+# generate spherical data centered on (20, 20)
+shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])
+
+# generate zero centered stretched gaussian data
 C = np.array([[0., -0.7], [3.5, .7]])
-X_train = np.r_[np.dot(np.random.randn(n_samples, 2), C),
-                np.random.randn(n_samples, 2) + np.array([20, 20])]
+stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)
+
+# concatenate the two datasets into the final training set
+X_train = np.vstack([shifted_gaussian, stretched_gaussian])
 
+# fit a Gaussian Mixture Model with two components
 clf = mixture.GMM(n_components=2, covariance_type='full')
 clf.fit(X_train)
 
+# display predicted scores by the model as a contour plot
 x = np.linspace(-20.0, 30.0)
 y = np.linspace(-20.0, 40.0)
 X, Y = np.meshgrid(x, y)
-XX = np.c_[X.ravel(), Y.ravel()]
-Z = np.log(-clf.score_samples(XX)[0])
+XX = np.array([X.ravel(), Y.ravel()]).T
+Z = -clf.score_samples(XX)[0]
 Z = Z.reshape(X.shape)
 
-CS = pl.contour(X, Y, Z)
-CB = pl.colorbar(CS, shrink=0.8, extend='both')
-pl.scatter(X_train[:, 0], X_train[:, 1], .8)
+CS = plt.contour(X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0),
+                 levels=np.logspace(0, 3, 10))
+CB = plt.colorbar(CS, shrink=0.8, extend='both')
+plt.scatter(X_train[:, 0], X_train[:, 1], .8)
 
-pl.axis('tight')
-pl.show()
+plt.title('Predicted negative log-likelihood by a GMM')
+plt.axis('tight')
+plt.show()
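Aside (not part of the commit): the diff above targets the mixture.GMM class of that era, whose score_samples returned a (log-likelihood, responsibilities) tuple, hence the [0] indexing. On a later scikit-learn release where GMM has been replaced by mixture.GaussianMixture and score_samples returns the per-sample log-likelihood array directly, a minimal sketch of the same plot might look like this (the version assumption is mine):

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from sklearn import mixture

n_samples = 300
np.random.seed(0)

# spherical blob centered on (20, 20) plus a zero-centered, stretched blob
shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])
C = np.array([[0., -0.7], [3.5, .7]])
stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)
X_train = np.vstack([shifted_gaussian, stretched_gaussian])

# sketch assumes a scikit-learn version where GaussianMixture replaces the
# deprecated GMM class used in this commit
clf = mixture.GaussianMixture(n_components=2, covariance_type='full')
clf.fit(X_train)

# evaluate the negative log-likelihood on a grid and draw log-spaced contours
x = np.linspace(-20.0, 30.0)
y = np.linspace(-20.0, 40.0)
X, Y = np.meshgrid(x, y)
XX = np.array([X.ravel(), Y.ravel()]).T
Z = -clf.score_samples(XX).reshape(X.shape)

CS = plt.contour(X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0),
                 levels=np.logspace(0, 3, 10))
plt.colorbar(CS, shrink=0.8, extend='both')
plt.scatter(X_train[:, 0], X_train[:, 1], .8)
plt.title('Predicted negative log-likelihood by a GMM')
plt.axis('tight')
plt.show()

The only substantive change from the committed code is dropping the [0] indexing, since GaussianMixture.score_samples no longer returns a tuple.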

0 commit comments
