Commit 6ee390d

andreanr authored and thomasjpfan committed
DOC Change default dataset for plot_johnson_lindenstrauss_bound.py (scikit-learn#14787)
1 parent b730bef commit 6ee390d

1 file changed: +15 -13 lines changed
examples/plot_johnson_lindenstrauss_bound.py

Lines changed: 15 additions & 13 deletions
@@ -102,27 +102,26 @@
 # Empirical validation
 # ====================
 #
-# We validate the above bounds on the digits dataset or on the 20 newsgroups
-# text document (TF-IDF word frequencies) dataset:
-#
-# - for the digits dataset, some 8x8 gray level pixels data for 500
-#   handwritten digits pictures are randomly projected to spaces for various
-#   larger number of dimensions ``n_components``.
+# We validate the above bounds on the 20 newsgroups text document
+# (TF-IDF word frequencies) dataset or on the digits dataset:
 #
 # - for the 20 newsgroups dataset some 500 documents with 100k
 #   features in total are projected using a sparse random matrix to smaller
 #   euclidean spaces with various values for the target number of dimensions
 #   ``n_components``.
 #
-# The default dataset is the digits dataset. To run the example on the twenty
-# newsgroups dataset, pass the --twenty-newsgroups command line argument to
+# - for the digits dataset, some 8x8 gray level pixels data for 500
+#   handwritten digits pictures are randomly projected to spaces for various
+#   larger number of dimensions ``n_components``.
+#
+# The default dataset is the 20 newsgroups dataset. To run the example on the
+# digits dataset, pass the ``--use-digits-dataset`` command line argument to
 # this script.
 
-if '--twenty-newsgroups' in sys.argv:
-    # Need an internet connection hence not enabled by default
-    data = fetch_20newsgroups_vectorized().data[:500]
-else:
+if '--use-digits-dataset' in sys.argv:
     data = load_digits().data[:500]
+else:
+    data = fetch_20newsgroups_vectorized().data[:500]
 
 ##########################################################
 # For each value of ``n_components``, we plot:
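
The flag check changed in this hunk can be tried in isolation. The following is a minimal sketch, not code from the commit: `select_dataset` and `load_data` are hypothetical helpers introduced here for illustration, and scikit-learn is imported lazily so that the flag logic runs without it.

```python
import sys


def select_dataset(argv):
    """Return the name of the dataset the updated example would use.

    Hypothetical helper for illustration; the example script inlines
    this check rather than defining a function.
    """
    # After this commit, 20 newsgroups is the default and the digits
    # dataset is opt-in through the --use-digits-dataset flag.
    if '--use-digits-dataset' in argv:
        return 'digits'
    return '20newsgroups'


def load_data(argv, n_samples=500):
    """Load the first n_samples rows of the selected dataset."""
    # Imported here so select_dataset stays usable without scikit-learn.
    from sklearn.datasets import fetch_20newsgroups_vectorized, load_digits

    if select_dataset(argv) == 'digits':
        return load_digits().data[:n_samples]
    # Fetches the data over the network when it is not cached locally.
    return fetch_20newsgroups_vectorized().data[:n_samples]


if __name__ == '__main__':
    print(select_dataset(sys.argv))
```

Note that the old `--twenty-newsgroups` flag is no longer checked, so passing it now has no effect and the script still uses the 20 newsgroups default.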
@@ -158,7 +157,10 @@
         projected_data, squared=True).ravel()[nonzero]
 
     plt.figure()
-    plt.hexbin(dists, projected_dists, gridsize=100, cmap=plt.cm.PuBu)
+    min_dist = min(projected_dists.min(), dists.min())
+    max_dist = max(projected_dists.max(), dists.max())
+    plt.hexbin(dists, projected_dists, gridsize=100, cmap=plt.cm.PuBu,
+               extent=[min_dist, max_dist, min_dist, max_dist])
     plt.xlabel("Pairwise squared distances in original space")
     plt.ylabel("Pairwise squared distances in projected space")
     plt.title("Pairwise distances distribution for n_components=%d" %
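
The `extent` change in this hunk gives the hexbin plot the same numeric range on both axes. The commit message does not state the motivation; one plausible reading is that a shared range makes it easier to compare original and projected distances against the diagonal. Below is a minimal sketch of the idea, where `shared_extent` and `plot_distances` are hypothetical helpers that do not appear in the example script, and matplotlib is imported lazily so the range computation can be checked without it.

```python
def shared_extent(x, y):
    """Return [lo, hi, lo, hi] covering both 1-D sequences.

    Hypothetical helper; the commit computes min_dist and max_dist
    inline with the arrays' .min() and .max() methods instead.
    """
    lo = min(min(x), min(y))
    hi = max(max(x), max(y))
    # matplotlib's hexbin expects extent as (xmin, xmax, ymin, ymax).
    return [lo, hi, lo, hi]


def plot_distances(dists, projected_dists):
    """Draw a hexbin plot whose x and y axes span the same range."""
    import matplotlib.pyplot as plt  # lazy import, only needed for plotting

    fig = plt.figure()
    plt.hexbin(dists, projected_dists, gridsize=100, cmap=plt.cm.PuBu,
               extent=shared_extent(dists, projected_dists))
    plt.xlabel("Pairwise squared distances in original space")
    plt.ylabel("Pairwise squared distances in projected space")
    return fig
```

Calling `plot_distances(dists, projected_dists)` with two equal-length 1-D arrays of pairwise squared distances should produce a figure comparable to the one drawn in the patched loop, minus the title.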
