|
9 | 9 | """ |
10 | 10 |
|
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

from sklearn import mixture

# Number of samples drawn from EACH of the two components below.
n_samples = 300

# Fix the RNG so the example is reproducible run-to-run.
np.random.seed(0)

# generate spherical data centered on (20, 20)
shifted_gaussian = np.random.randn(n_samples, 2) + np.array([20, 20])

# generate zero-centered stretched gaussian data: right-multiplying by C
# applies a linear transform, giving the sample covariance C.T @ C
C = np.array([[0., -0.7], [3.5, .7]])
stretched_gaussian = np.dot(np.random.randn(n_samples, 2), C)

# concatenate the two datasets into the final training set
X_train = np.vstack([shifted_gaussian, stretched_gaussian])

# Fit a Gaussian mixture with two components and full (unconstrained)
# covariance matrices.
# NOTE: mixture.GMM was deprecated in scikit-learn 0.18 and removed in
# 0.20; GaussianMixture is the supported replacement.
clf = mixture.GaussianMixture(n_components=2, covariance_type='full')
clf.fit(X_train)

# display predicted scores by the model as a contour plot
x = np.linspace(-20.0, 30.0)
y = np.linspace(-20.0, 40.0)
X, Y = np.meshgrid(x, y)
# Flatten the grid into an (n_points, 2) array of evaluation locations.
XX = np.array([X.ravel(), Y.ravel()]).T
# GaussianMixture.score_samples returns the per-sample log-likelihood
# array directly; the old GMM API returned a (logprob, responsibilities)
# tuple, which is why legacy code indexed with [0]. Negate to get the
# negative log-likelihood surface for contouring.
Z = -clf.score_samples(XX)
Z = Z.reshape(X.shape)

# Log-spaced contour levels with a log color scale make both the tight
# basin around each component and the far tails visible in one plot.
CS = plt.contour(X, Y, Z, norm=LogNorm(vmin=1.0, vmax=1000.0),
                 levels=np.logspace(0, 3, 10))
CB = plt.colorbar(CS, shrink=0.8, extend='both')
plt.scatter(X_train[:, 0], X_train[:, 1], .8)

plt.title('Predicted negative log-likelihood by a GMM')
plt.axis('tight')
plt.show()