Skip to content

Commit 4aaf45b

Browse files
wdevazelhes authored and glemaitre committed
[MRG+2] FIX trustworthiness accepts custom metric (scikit-learn#9775)
1 parent 2c4ef26 commit 4aaf45b

File tree

3 files changed

+72
-6
lines changed

3 files changed

+72
-6
lines changed

doc/whats_new/v0.20.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ Classifiers and regressors
162162
in :class:`discriminant_analysis`.
163163
:issue:`10898` by :user:`Nanxin Chen <bobchennan>`.
164164

165+
- :func:`manifold.t_sne.trustworthiness` accepts metrics other than
166+
Euclidean. :issue:`9775` by :user:`William de Vazelhes <wdevazelhes>`.
167+
165168
Cluster
166169

167170
- :class:`cluster.KMeans`, :class:`cluster.MiniBatchKMeans` and
@@ -236,6 +239,15 @@ Linear, kernelized and related models
236239
underlying implementation is not random.
237240
:issue:`9497` by :user:`Albert Thomas <albertcthomas>`.
238241

242+
Decomposition, manifold learning and clustering
243+
244+
- Deprecate ``precomputed`` parameter in function
245+
:func:`manifold.t_sne.trustworthiness`. Instead, the new parameter
246+
``metric`` should be used with any compatible metric including
247+
'precomputed', in which case the input matrix ``X`` should be a matrix of
248+
pairwise distances or squared distances. :issue:`9775` by
249+
:user:`William de Vazelhes <wdevazelhes>`.
250+
239251
Utils
240252

241253
- Avoid copying the data in :func:`utils.check_array` when the input data is a
@@ -478,6 +490,15 @@ Linear, kernelized and related models
478490
:class:`linear_model.logistic.LogisticRegression` when ``verbose`` is set to 0.
479491
:issue:`10881` by :user:`Alexandre Sevin <AlexandreSev>`.
480492

493+
Decomposition, manifold learning and clustering
494+
495+
- Deprecate ``precomputed`` parameter in function
496+
:func:`manifold.t_sne.trustworthiness`. Instead, the new parameter
497+
``metric`` should be used with any compatible metric including
498+
'precomputed', in which case the input matrix ``X`` should be a matrix of
499+
pairwise distances or squared distances. :issue:`9775` by
500+
:user:`William de Vazelhes <wdevazelhes>`.
501+
481502
Metrics
482503

483504
- Deprecate ``reorder`` parameter in :func:`metrics.auc` as it's no longer required

sklearn/manifold/t_sne.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# http://cseweb.ucsd.edu/~lvdmaaten/workshops/nips2010/papers/vandermaaten.pdf
1010
from __future__ import division
1111

12+
import warnings
1213
from time import time
1314
import numpy as np
1415
from scipy import linalg
@@ -394,7 +395,8 @@ def _gradient_descent(objective, p0, it, n_iter,
394395
return p, error, i
395396

396397

397-
def trustworthiness(X, X_embedded, n_neighbors=5, precomputed=False):
398+
def trustworthiness(X, X_embedded, n_neighbors=5,
399+
precomputed=False, metric='euclidean'):
398400
r"""Expresses to what extent the local structure is retained.
399401
400402
The trustworthiness is within [0, 1]. It is defined as
@@ -431,15 +433,28 @@ def trustworthiness(X, X_embedded, n_neighbors=5, precomputed=False):
431433
precomputed : bool, optional (default: False)
432434
Set this flag if X is a precomputed square distance matrix.
433435
436+
.. deprecated:: 0.20
437+
``precomputed`` has been deprecated in version 0.20 and will be
438+
removed in version 0.22. Use ``metric`` instead.
439+
440+
metric : string, or callable, optional, default 'euclidean'
441+
Which metric to use for computing pairwise distances between samples
442+
from the original input space. If metric is 'precomputed', X must be a
443+
matrix of pairwise distances or squared distances. Otherwise, see the
444+
documentation of argument metric in sklearn.metrics.pairwise.pairwise_distances
445+
for a list of available metrics.
446+
434447
Returns
435448
-------
436449
trustworthiness : float
437450
Trustworthiness of the low-dimensional embedding.
438451
"""
439452
if precomputed:
440-
dist_X = X
441-
else:
442-
dist_X = pairwise_distances(X, squared=True)
453+
warnings.warn("The flag 'precomputed' has been deprecated in version "
454+
"0.20 and will be removed in 0.22. See 'metric' "
455+
"parameter instead.", DeprecationWarning)
456+
metric = 'precomputed'
457+
dist_X = pairwise_distances(X, metric=metric)
443458
ind_X = np.argsort(dist_X, axis=1)
444459
ind_X_embedded = NearestNeighbors(n_neighbors).fit(X_embedded).kneighbors(
445460
return_distance=False)

sklearn/manifold/tests/test_t_sne.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from sklearn.utils.testing import assert_greater
1515
from sklearn.utils.testing import assert_raises_regexp
1616
from sklearn.utils.testing import assert_in
17+
from sklearn.utils.testing import assert_warns
18+
from sklearn.utils.testing import assert_raises
1719
from sklearn.utils.testing import skip_if_32bit
1820
from sklearn.utils import check_random_state
1921
from sklearn.manifold.t_sne import _joint_probabilities
@@ -288,11 +290,39 @@ def test_preserve_trustworthiness_approximately_with_precomputed_distances():
288290
early_exaggeration=2.0, metric="precomputed",
289291
random_state=i, verbose=0)
290292
X_embedded = tsne.fit_transform(D)
291-
t = trustworthiness(D, X_embedded, n_neighbors=1,
292-
precomputed=True)
293+
t = trustworthiness(D, X_embedded, n_neighbors=1, metric="precomputed")
293294
assert t > .95
294295

295296

297+
def test_trustworthiness_precomputed_deprecation():
298+
# FIXME: Remove this test in v0.22
299+
300+
# Use of the flag `precomputed` in trustworthiness parameters has been
301+
# deprecated, but will still work until its removal in v0.22.
302+
random_state = check_random_state(0)
303+
X = random_state.randn(100, 2)
304+
assert_equal(assert_warns(DeprecationWarning, trustworthiness,
305+
pairwise_distances(X), X, precomputed=True), 1.)
306+
assert_equal(assert_warns(DeprecationWarning, trustworthiness,
307+
pairwise_distances(X), X, metric='precomputed',
308+
precomputed=True), 1.)
309+
assert_raises(ValueError, assert_warns, DeprecationWarning,
310+
trustworthiness, X, X, metric='euclidean', precomputed=True)
311+
assert_equal(assert_warns(DeprecationWarning, trustworthiness,
312+
pairwise_distances(X), X, metric='euclidean',
313+
precomputed=True), 1.)
314+
315+
316+
def test_trustworthiness_not_euclidean_metric():
317+
# Test trustworthiness with a metric different from 'euclidean' and
318+
# 'precomputed'
319+
random_state = check_random_state(0)
320+
X = random_state.randn(100, 2)
321+
assert_equal(trustworthiness(X, X, metric='cosine'),
322+
trustworthiness(pairwise_distances(X, metric='cosine'), X,
323+
metric='precomputed'))
324+
325+
296326
def test_early_exaggeration_too_small():
297327
# Early exaggeration factor must be >= 1.
298328
tsne = TSNE(early_exaggeration=0.99)

0 commit comments

Comments
 (0)