
Commit b7b91ef

Merge branch 'master' into gbrt-interactions
2 parents: dd0deae + 9603b07

14 files changed: +343, -53 lines

doc/modules/manifold.rst

Lines changed: 3 additions & 2 deletions

@@ -359,7 +359,7 @@ tangent spaces to learn the embedding. LTSA can be performed with function
     :target: ../auto_examples/manifold/plot_lle_digits.html
     :align: center
     :scale: 50
-
+
 Complexity
 ----------

@@ -393,7 +393,8 @@ The overall complexity of standard LTSA is
 Multi-dimensional Scaling (MDS)
 ===============================

-Multidimensional scaling (:class:`MDS`) seeks a low-dimensional
+`Multidimensional scaling <http://en.wikipedia.org/wiki/Multidimensional_scaling>`_
+(:class:`MDS`) seeks a low-dimensional
 representation of the data in which the distances respect well the
 distances in the original high-dimensional space.
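The reworded sentence describes MDS as seeking a low-dimensional representation whose distances respect the original high-dimensional distances. A minimal sketch of that usage with :class:`MDS`, following the call pattern used in the new sphere example added in this commit (the toy data and parameter values here are illustrative, not part of the diff):

import numpy as np
from sklearn import manifold
from sklearn.metrics import euclidean_distances

# Toy high-dimensional data: 20 points in 5 dimensions.
rng = np.random.RandomState(0)
X = rng.rand(20, 5)

# MDS looks for a 2D configuration whose pairwise distances
# approximate the pairwise distances of the original data; the
# distance matrix is passed explicitly, as in plot_manifold_sphere.py.
mds = manifold.MDS(2, max_iter=100, n_init=1)
X_2d = mds.fit_transform(euclidean_distances(X))
print(X_2d.shape)  # (20, 2)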

examples/manifold/plot_compare_methods.py

Lines changed: 3 additions & 0 deletions

@@ -9,6 +9,9 @@
 For a discussion and comparison of these algorithms, see the
 :ref:`manifold module page <manifold>`

+For a similar example, where the methods are applied to a
+sphere dataset, see :ref:`example_manifold_plot_manifold_sphere.py`
+
 Note that the purpose of the MDS is to find a low-dimensional
 representation of the data (here 2D) in which the distances respect well
 the distances in the original high-dimensional space, unlike other
examples/manifold/plot_manifold_sphere.py

Lines changed: 127 additions & 0 deletions

@@ -0,0 +1,127 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

"""
=============================================
Manifold Learning methods on a severed sphere
=============================================

An application of the different :ref:`manifold` techniques
on a spherical data-set. Here one can see the use of
dimensionality reduction in order to gain some intuition
regarding the manifold learning methods. Regarding the dataset,
the poles are cut from the sphere, as well as a thin slice down its
side. This enables the manifold learning techniques to
'spread it open' whilst projecting it onto two dimensions.

For a similar example, where the methods are applied to the
S-curve dataset, see :ref:`example_manifold_plot_compare_methods.py`

Note that the purpose of the :ref:`MDS <multidimensional_scaling>` is
to find a low-dimensional representation of the data (here 2D) in
which the distances respect well the distances in the original
high-dimensional space; unlike other manifold-learning algorithms,
it does not seek an isotropic representation of the data in
the low-dimensional space. Here the manifold problem matches fairly
well that of representing a flat map of the Earth, as with a
`map projection <http://en.wikipedia.org/wiki/Map_projection>`_.
"""

# Author: Jaques Grobler <[email protected]>
# License: BSD

print __doc__

from time import time

import numpy as np
import pylab as pl
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import NullFormatter

from sklearn import manifold
from sklearn.metrics import euclidean_distances
from sklearn.utils import check_random_state

# Next line to silence pyflakes.
Axes3D

# Variables for manifold learning.
n_neighbors = 10
n_samples = 1000

# Create our sphere.
random_state = check_random_state(0)
p = random_state.rand(n_samples) * (2 * np.pi - 0.55)
t = random_state.rand(n_samples) * np.pi

# Sever the poles from the sphere.
indices = ((t < (np.pi - (np.pi / 8))) & (t > ((np.pi / 8))))
colors = p[indices]
x, y, z = np.sin(t[indices]) * np.cos(p[indices]), \
    np.sin(t[indices]) * np.sin(p[indices]), \
    np.cos(t[indices])

# Plot our dataset.
fig = pl.figure(figsize=(15, 8))
pl.suptitle("Manifold Learning with %i points, %i neighbors"
            % (1000, n_neighbors), fontsize=14)

ax = fig.add_subplot(241, projection='3d')
ax.scatter(x, y, z, c=p[indices], cmap=pl.cm.rainbow)
try:
    # compatibility matplotlib < 1.0
    ax.view_init(40, -10)
except:
    pass

sphere_data = np.array([x, y, z]).T

# Perform Locally Linear Embedding Manifold learning.
methods = ['standard', 'ltsa', 'hessian', 'modified']
labels = ['LLE', 'LTSA', 'Hessian LLE', 'Modified LLE']

for i, method in enumerate(methods):
    t0 = time()
    trans_data = manifold\
        .LocallyLinearEmbedding(n_neighbors, 2,
                                method=method).fit_transform(sphere_data).T
    t1 = time()
    print "%s: %.2g sec" % (methods[i], t1 - t0)

    ax = fig.add_subplot(242 + i)
    pl.scatter(trans_data[0], trans_data[1], c=colors, cmap=pl.cm.rainbow)
    pl.title("%s (%.2g sec)" % (labels[i], t1 - t0))
    ax.xaxis.set_major_formatter(NullFormatter())
    ax.yaxis.set_major_formatter(NullFormatter())
    pl.axis('tight')

# Perform Isomap Manifold learning.
t0 = time()
trans_data = manifold.Isomap(n_neighbors, n_components=2)\
    .fit_transform(sphere_data).T
t1 = time()
print "%s: %.2g sec" % ('ISO', t1 - t0)

ax = fig.add_subplot(246)
pl.scatter(trans_data[0], trans_data[1], c=colors, cmap=pl.cm.rainbow)
pl.title("%s (%.2g sec)" % ('Isomap', t1 - t0))
ax.xaxis.set_major_formatter(NullFormatter())
ax.yaxis.set_major_formatter(NullFormatter())
pl.axis('tight')

# Perform Multi-dimensional scaling.
t0 = time()
mds = manifold.MDS(2, max_iter=100, n_init=1)
trans_data = mds.fit_transform(euclidean_distances(sphere_data)).T
t1 = time()
print "MDS: %.2g sec" % (t1 - t0)

ax = fig.add_subplot(247)
pl.scatter(trans_data[0], trans_data[1], c=colors, cmap=pl.cm.rainbow)
pl.title("MDS (%.2g sec)" % (t1 - t0))
ax.xaxis.set_major_formatter(NullFormatter())
ax.yaxis.set_major_formatter(NullFormatter())
pl.axis('tight')

pl.show()

sklearn/cluster/tests/test_k_means.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def _get_mac_os_version():


 def test_k_means_plus_plus_init_2_jobs():
-    if _get_mac_os_version() == '10.7':
+    if _get_mac_os_version() >= '10.7':
         raise SkipTest('Multi-process bug in Mac OS X Lion (see issue #636)')
     k_means = KMeans(init="k-means++", n_clusters=n_clusters, n_jobs=2,
                      random_state=42).fit(X)
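The widened comparison skips the multi-process test on OS X 10.7 and on later releases, not only on 10.7 itself. A small sketch of that kind of check (the helper body below is an assumption; the real _get_mac_os_version lives earlier in test_k_means.py):

import platform

def _get_mac_os_version():
    # Assumed implementation: returns e.g. '10.7' on OS X Lion,
    # or an empty string on other platforms.
    version = platform.mac_ver()[0]
    return '.'.join(version.split('.')[:2]) if version else ''

version = _get_mac_os_version()
if version and version >= '10.7':   # lexicographic string comparison, as in the test
    print("skip: multi-process bug on OS X >= 10.7 (issue #636)")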

sklearn/feature_extraction/hashing.py

Lines changed: 1 addition & 2 deletions

@@ -1,7 +1,6 @@
 # Author: Lars Buitinck <[email protected]>
 # License: 3-clause BSD.

-import itertools
 import numbers

 import numpy as np

@@ -76,7 +75,7 @@ def _validate_params(n_features, input_type):
     if not isinstance(n_features, (numbers.Integral, np.integer)):
         raise TypeError("n_features must be integral, got %r (%s)."
                         % (n_features, type(n_features)))
-    elif n_features < 1 or n_features >= 2**31:
+    elif n_features < 1 or n_features >= 2 ** 31:
         raise ValueError("Invalid number of features (%d)." % n_features)

     if input_type not in ("dict", "pair", "string"):
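The reformatted bound keeps n_features to an integer in [1, 2**31). A standalone sketch mirroring the validation shown in the hunk (the helper name here is made up for illustration):

import numbers
import numpy as np

def validate_n_features(n_features):
    # Mirrors _validate_params in sklearn/feature_extraction/hashing.py.
    if not isinstance(n_features, (numbers.Integral, np.integer)):
        raise TypeError("n_features must be integral, got %r (%s)."
                        % (n_features, type(n_features)))
    elif n_features < 1 or n_features >= 2 ** 31:
        raise ValueError("Invalid number of features (%d)." % n_features)

validate_n_features(2 ** 20)       # fine: a typical hashing-trick table size
try:
    validate_n_features(2 ** 31)   # one past the signed 32-bit limit
except ValueError as exc:
    print(exc)                     # Invalid number of features (2147483648).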

sklearn/manifold/spectral_embedding.py

Lines changed: 4 additions & 1 deletion

@@ -264,7 +264,10 @@ def spectral_embedding(adjacency, n_components=8, eigen_solver=None,
     if eigen_solver == 'amg':
         # Use AMG to get a preconditioner and speed up the eigenvalue
         # problem.
+        if not sparse.issparse(laplacian):
+            warnings.warn("AMG works better for sparse matrices")
         laplacian = laplacian.astype(np.float)  # lobpcg needs native floats
+        laplacian = _set_diag(laplacian, 1)
         ml = smoothed_aggregation_solver(atleast2d_or_csr(laplacian))
         M = ml.aspreconditioner()
         X = random_state.rand(laplacian.shape[0], n_components + 1)

@@ -446,7 +449,7 @@ def fit(self, X, y=None):
         self.random_state = check_random_state(self.random_state)
         if isinstance(self.affinity, basestring):
             if self.affinity not in set(("nearest_neighbors", "rbf",
-                                         "precomputed")):
+                                         "precomputed")):
                 raise ValueError(("%s is not a valid affinity. Expected "
                                   "'precomputed', 'rbf', 'nearest_neighbors' "
                                   "or a callable.") % self.affinity)

sklearn/manifold/tests/test_spectral_embedding.py

Lines changed: 16 additions & 36 deletions

@@ -1,21 +1,19 @@
 from nose.tools import assert_true
 from nose.tools import assert_equal

-from scipy import sparse
 from scipy.sparse import csr_matrix
 from scipy.sparse import csc_matrix
 import numpy as np
-from numpy.testing import assert_almost_equal, assert_array_almost_equal
+from numpy.testing import assert_array_almost_equal

 from nose.tools import assert_raises
 from nose.plugins.skip import SkipTest

 from sklearn.manifold.spectral_embedding import SpectralEmbedding
 from sklearn.manifold.spectral_embedding import _graph_is_connected
 from sklearn.metrics.pairwise import rbf_kernel
-from sklearn.pipeline import Pipeline
 from sklearn.metrics import normalized_mutual_info_score
-from sklearn.cluster import KMeans, SpectralClustering
+from sklearn.cluster import KMeans
 from sklearn.datasets.samples_generator import make_blobs


@@ -84,7 +82,7 @@ def test_spectral_embedding_precomputed_affinity(seed=36):
     embed_rbf = se_rbf.fit_transform(S)
     assert_array_almost_equal(
         se_precomp.affinity_matrix_, se_rbf.affinity_matrix_)
-    assert_true(_check_with_col_sign_flipping(embed_precomp, embed_rbf, 0.01))
+    assert_true(_check_with_col_sign_flipping(embed_precomp, embed_rbf, 0.02))


 def test_spectral_embedding_callable_affinity(seed=36):

@@ -105,8 +103,9 @@ def test_spectral_embedding_callable_affinity(seed=36):
     embed_callable = se_callable.fit_transform(S)
     assert_array_almost_equal(
         se_callable.affinity_matrix_, se_rbf.affinity_matrix_)
+    assert_array_almost_equal(kern, se_rbf.affinity_matrix_)
     assert_true(
-        _check_with_col_sign_flipping(embed_rbf, embed_callable, 0.01))
+        _check_with_col_sign_flipping(embed_rbf, embed_callable, 0.02))


 def test_spectral_embedding_amg_solver(seed=36):

@@ -116,17 +115,14 @@ def test_spectral_embedding_amg_solver(seed=36):
     except ImportError:
         raise SkipTest

-    gamma = 0.9
-    se_amg = SpectralEmbedding(n_components=3, affinity="rbf",
-                               gamma=gamma, eigen_solver="amg",
+    se_amg = SpectralEmbedding(n_components=3, affinity="nearest_neighbors",
+                               eigen_solver="amg", n_neighbors=5,
                                random_state=np.random.RandomState(seed))
-    se_arpack = SpectralEmbedding(n_components=3, affinity="rbf",
-                                  gamma=gamma, eigen_solver="arpack",
+    se_arpack = SpectralEmbedding(n_components=3, affinity="nearest_neighbors",
+                                  eigen_solver="arpack", n_neighbors=5,
                                   random_state=np.random.RandomState(seed))
     embed_amg = se_amg.fit_transform(S)
     embed_arpack = se_arpack.fit_transform(S)
-    assert_array_almost_equal(
-        se_amg.affinity_matrix_, se_arpack.affinity_matrix_)
     assert_true(_check_with_col_sign_flipping(embed_amg, embed_arpack, 0.01))


@@ -151,33 +147,17 @@ def test_pipline_spectral_clustering(seed=36):

 def test_spectral_embedding_unknown_eigensolver(seed=36):
     """Test that SpectralClustering fails with an unknown eigensolver"""
-    centers = np.array([
-        [0., 0., 0.],
-        [10., 10., 10.],
-        [20., 20., 20.],
-    ])
-    X, true_labels = make_blobs(n_samples=100, centers=centers,
-                                cluster_std=1., random_state=42)
-
-    se_precomp = SpectralEmbedding(n_components=1, affinity="precomputed",
-                                   random_state=np.random.RandomState(seed),
-                                   eigen_solver="<unknown>")
-    assert_raises(ValueError, se_precomp.fit, S)
+    se = SpectralEmbedding(n_components=1, affinity="precomputed",
+                           random_state=np.random.RandomState(seed),
+                           eigen_solver="<unknown>")
+    assert_raises(ValueError, se.fit, S)


 def test_spectral_embedding_unknown_affinity(seed=36):
     """Test that SpectralClustering fails with an unknown affinity type"""
-    centers = np.array([
-        [0., 0., 0.],
-        [10., 10., 10.],
-        [20., 20., 20.],
-    ])
-    X, true_labels = make_blobs(n_samples=100, centers=centers,
-                                cluster_std=1., random_state=42)
-
-    se_precomp = SpectralEmbedding(n_components=1, affinity="<unknown>",
-                                   random_state=np.random.RandomState(seed))
-    assert_raises(ValueError, se_precomp.fit, S)
+    se = SpectralEmbedding(n_components=1, affinity="<unknown>",
+                           random_state=np.random.RandomState(seed))
+    assert_raises(ValueError, se.fit, S)


 def test_connectivity(seed=36):
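Several of the assertions above go through _check_with_col_sign_flipping, because eigenvector-based embeddings are only defined up to a per-column sign. A sketch of what such a comparison does (this body is an assumption; the real helper is defined near the top of the test file, outside this diff):

import numpy as np

def check_with_col_sign_flipping(A, B, tol=0.0):
    # Two embeddings agree if every column of A matches the corresponding
    # column of B up to a global sign flip, within the given tolerance.
    for k in range(A.shape[1]):
        same = ((A[:, k] - B[:, k]) ** 2).mean() <= tol ** 2
        flipped = ((A[:, k] + B[:, k]) ** 2).mean() <= tol ** 2
        if not (same or flipped):
            return False
    return True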
