Skip to content

Commit 908ded8

Browse files
tomMoralNicolasHug
authored andcommitted
ENH Parallelize gradient computation in t-SNE (scikit-learn#13264)
1 parent 132ad99 commit 908ded8

File tree

7 files changed

+173
-100
lines changed

7 files changed

+173
-100
lines changed

benchmarks/bench_tsne_mnist.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sklearn.decomposition import PCA
2222
from sklearn.utils import check_array
2323
from sklearn.utils import shuffle as _shuffle
24-
24+
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
2525

2626
LOG_DIR = "mnist_tsne_output"
2727
if not os.path.exists(LOG_DIR):
@@ -86,6 +86,7 @@ def sanitize(filename):
8686
"preprocessing.")
8787
args = parser.parse_args()
8888

89+
print("Used number of threads: {}".format(_openmp_effective_n_threads()))
8990
X, y = load_data(order=args.order)
9091

9192
if args.pca_components > 0:
@@ -141,7 +142,7 @@ def bhtsne(X):
141142
data_size.append(70000)
142143

143144
results = []
144-
basename, _ = os.path.splitext(__file__)
145+
basename = os.path.basename(os.path.splitext(__file__)[0])
145146
log_filename = os.path.join(LOG_DIR, basename + '.json')
146147
for n in data_size:
147148
X_train = X[:n]

doc/whats_new/v0.22.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,10 @@ Changelog
430430
impact when ``metric="precomputed"`` or (``metric="euclidean"`` and
431431
``method="exact"``). :issue:`15082` by `Roman Yurchak`_.
432432

433+
- |Efficiency| Improved efficiency of :class:`manifold.TSNE` when
434+
``method="barnes-hut"`` by computing the gradient in parallel.
435+
:pr:`13213` by :user:`Thomas Moreau <tommoral>`
436+
433437
- |API| Deprecate ``training_data_`` unused attribute in
434438
:class:`manifold.Isomap`. :issue:`10482` by `Tom Dupre la Tour`_.
435439

sklearn/manifold/_barnes_hut_tsne.pyx

Lines changed: 118 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
# implementations and papers describing the technique
1010

1111

12-
from libc.stdlib cimport malloc, free
13-
from libc.stdio cimport printf
14-
from libc.math cimport sqrt, log
1512
import numpy as np
1613
cimport numpy as np
14+
from libc.stdio cimport printf
15+
from libc.math cimport sqrt, log
16+
from libc.stdlib cimport malloc, free
17+
from cython.parallel cimport prange, parallel
1718

1819
from ..neighbors._quad_tree cimport _QuadTree
1920

21+
2022
cdef char* EMPTY_STRING = ""
2123

2224
cdef extern from "math.h":
@@ -53,17 +55,18 @@ cdef float compute_gradient(float[:] val_P,
5355
int dof,
5456
long start,
5557
long stop,
56-
bint compute_error) nogil:
58+
bint compute_error,
59+
int num_threads) nogil:
5760
# Having created the tree, calculate the gradient
5861
# in two components, the positive and negative forces
5962
cdef:
6063
long i, coord
6164
int ax
6265
long n_samples = pos_reference.shape[0]
6366
int n_dimensions = qt.n_dimensions
64-
double[1] sum_Q
6567
clock_t t1 = 0, t2 = 0
66-
float sQ, error
68+
double sQ
69+
float error
6770
int take_timing = 1 if qt.verbose > 15 else 0
6871

6972
if qt.verbose > 11:
@@ -72,25 +75,25 @@ cdef float compute_gradient(float[:] val_P,
7275
cdef float* neg_f = <float*> malloc(sizeof(float) * n_samples * n_dimensions)
7376
cdef float* pos_f = <float*> malloc(sizeof(float) * n_samples * n_dimensions)
7477

75-
sum_Q[0] = 0.0
7678
if take_timing:
7779
t1 = clock()
78-
compute_gradient_negative(pos_reference, neg_f, qt, sum_Q,
79-
dof, theta, start, stop)
80+
sQ = compute_gradient_negative(pos_reference, neg_f, qt, dof, theta, start,
81+
stop, num_threads)
8082
if take_timing:
8183
t2 = clock()
8284
printf("[t-SNE] Computing negative gradient: %e ticks\n", ((float) (t2 - t1)))
83-
sQ = sum_Q[0]
8485

8586
if take_timing:
8687
t1 = clock()
8788
error = compute_gradient_positive(val_P, pos_reference, neighbors, indptr,
8889
pos_f, n_dimensions, dof, sQ, start,
89-
qt.verbose, compute_error)
90+
qt.verbose, compute_error, num_threads)
9091
if take_timing:
9192
t2 = clock()
92-
printf("[t-SNE] Computing positive gradient: %e ticks\n", ((float) (t2 - t1)))
93-
for i in range(start, n_samples):
93+
printf("[t-SNE] Computing positive gradient: %e ticks\n",
94+
((float) (t2 - t1)))
95+
for i in prange(start, n_samples, nogil=True, num_threads=num_threads,
96+
schedule='static'):
9497
for ax in range(n_dimensions):
9598
coord = i * n_dimensions + ax
9699
tot_force[i, ax] = pos_f[coord] - (neg_f[coord] / sQ)
@@ -110,7 +113,8 @@ cdef float compute_gradient_positive(float[:] val_P,
110113
double sum_Q,
111114
np.int64_t start,
112115
int verbose,
113-
bint compute_error) nogil:
116+
bint compute_error,
117+
int num_threads) nogil:
114118
# Sum over the following expression for i not equal to j
115119
# grad_i = p_ij (1 + ||y_i - y_j||^2)^-1 (y_i - y_j)
116120
# This is equivalent to compute_edge_forces in the authors' code
@@ -120,118 +124,138 @@ cdef float compute_gradient_positive(float[:] val_P,
120124
int ax
121125
long i, j, k
122126
long n_samples = indptr.shape[0] - 1
123-
float dij, qij, pij
124127
float C = 0.0
128+
float dij, qij, pij
125129
float exponent = (dof + 1.0) / 2.0
126130
float float_dof = (float) (dof)
127-
float[3] buff
131+
float* buff
128132
clock_t t1 = 0, t2 = 0
129133
float dt
130134

131135
if verbose > 10:
132136
t1 = clock()
133-
for i in range(start, n_samples):
134-
# Init the gradient vector
135-
for ax in range(n_dimensions):
136-
pos_f[i * n_dimensions + ax] = 0.0
137-
# Compute the positive interaction for the nearest neighbors
138-
for k in range(indptr[i], indptr[i+1]):
139-
j = neighbors[k]
140-
dij = 0.0
141-
pij = val_P[k]
142-
for ax in range(n_dimensions):
143-
buff[ax] = pos_reference[i, ax] - pos_reference[j, ax]
144-
dij += buff[ax] * buff[ax]
145-
qij = float_dof / (float_dof + dij)
146-
if dof != 1: # i.e. exponent != 1
147-
qij **= exponent
148-
dij = pij * qij
149-
150-
# only compute the error when needed
151-
if compute_error:
152-
qij /= sum_Q
153-
C += pij * log(max(pij, FLOAT32_TINY) / max(qij, FLOAT32_TINY))
137+
138+
with nogil, parallel(num_threads=num_threads):
139+
# Define private buffer variables
140+
buff = <float *> malloc(sizeof(float) * n_dimensions)
141+
142+
for i in prange(start, n_samples, schedule='static'):
143+
# Init the gradient vector
154144
for ax in range(n_dimensions):
155-
pos_f[i * n_dimensions + ax] += dij * buff[ax]
145+
pos_f[i * n_dimensions + ax] = 0.0
146+
# Compute the positive interaction for the nearest neighbors
147+
for k in range(indptr[i], indptr[i+1]):
148+
j = neighbors[k]
149+
dij = 0.0
150+
pij = val_P[k]
151+
for ax in range(n_dimensions):
152+
buff[ax] = pos_reference[i, ax] - pos_reference[j, ax]
153+
dij += buff[ax] * buff[ax]
154+
qij = float_dof / (float_dof + dij)
155+
if dof != 1: # i.e. exponent != 1
156+
qij = qij ** exponent
157+
dij = pij * qij
158+
159+
# only compute the error when needed
160+
if compute_error:
161+
qij = qij / sum_Q
162+
C += pij * log(max(pij, FLOAT32_TINY) \
163+
/ max(qij, FLOAT32_TINY))
164+
for ax in range(n_dimensions):
165+
pos_f[i * n_dimensions + ax] += dij * buff[ax]
166+
167+
free(buff)
156168
if verbose > 10:
157169
t2 = clock()
158170
dt = ((float) (t2 - t1))
159171
printf("[t-SNE] Computed error=%1.4f in %1.1e ticks\n", C, dt)
160172
return C
161173

162174

163-
cdef void compute_gradient_negative(float[:, :] pos_reference,
164-
float* neg_f,
165-
_QuadTree qt,
166-
double* sum_Q,
167-
int dof,
168-
float theta,
169-
long start,
170-
long stop) nogil:
175+
cdef double compute_gradient_negative(float[:, :] pos_reference,
176+
float* neg_f,
177+
_QuadTree qt,
178+
int dof,
179+
float theta,
180+
long start,
181+
long stop,
182+
int num_threads) nogil:
171183
if stop == -1:
172184
stop = pos_reference.shape[0]
173185
cdef:
174186
int ax
175187
int n_dimensions = qt.n_dimensions
188+
int offset = n_dimensions + 2
176189
long i, j, idx
177190
long n = stop - start
178191
long dta = 0
179192
long dtb = 0
180-
long offset = n_dimensions + 2
181193
float size, dist2s, mult
182194
float exponent = (dof + 1.0) / 2.0
183195
float float_dof = (float) (dof)
184-
double qijZ
185-
float[1] iQ
186-
float[3] force, neg_force, pos
196+
double qijZ, sum_Q = 0.0
197+
float* force
198+
float* neg_force
199+
float* pos
187200
clock_t t1 = 0, t2 = 0, t3 = 0
188201
int take_timing = 1 if qt.verbose > 20 else 0
189202

190-
summary = <float*> malloc(sizeof(float) * n * offset)
191203

192-
for i in range(start, stop):
193-
# Clear the arrays
194-
for ax in range(n_dimensions):
195-
force[ax] = 0.0
196-
neg_force[ax] = 0.0
197-
pos[ax] = pos_reference[i, ax]
198-
iQ[0] = 0.0
199-
# Find which nodes are summarizing and collect their centers of mass
200-
# deltas, and sizes, into vectorized arrays
201-
if take_timing:
202-
t1 = clock()
203-
idx = qt.summarize(pos, summary, theta*theta)
204-
if take_timing:
205-
t2 = clock()
206-
# Compute the t-SNE negative force
207-
# for the digits dataset, walking the tree
208-
# is about 10-15x more expensive than the
209-
# following for loop
210-
for j in range(idx // offset):
211-
212-
dist2s = summary[j * offset + n_dimensions]
213-
size = summary[j * offset + n_dimensions + 1]
214-
qijZ = float_dof / (float_dof + dist2s) # 1/(1+dist)
215-
if dof != 1: # i.e. exponent != 1
216-
qijZ **= exponent
217-
sum_Q[0] += size * qijZ # size of the node * q
218-
mult = size * qijZ * qijZ
204+
with nogil, parallel(num_threads=num_threads):
205+
# Define thread-local buffers
206+
summary = <float*> malloc(sizeof(float) * n * offset)
207+
pos = <float *> malloc(sizeof(float) * n_dimensions)
208+
force = <float *> malloc(sizeof(float) * n_dimensions)
209+
neg_force = <float *> malloc(sizeof(float) * n_dimensions)
210+
211+
for i in prange(start, stop, schedule='static'):
212+
# Clear the arrays
219213
for ax in range(n_dimensions):
220-
neg_force[ax] += mult * summary[j * offset + ax]
221-
if take_timing:
222-
t3 = clock()
223-
for ax in range(n_dimensions):
224-
neg_f[i * n_dimensions + ax] = neg_force[ax]
225-
if take_timing:
226-
dta += t2 - t1
227-
dtb += t3 - t2
214+
force[ax] = 0.0
215+
neg_force[ax] = 0.0
216+
pos[ax] = pos_reference[i, ax]
217+
218+
# Find which nodes are summarizing and collect their centers of mass
219+
# deltas, and sizes, into vectorized arrays
220+
if take_timing:
221+
t1 = clock()
222+
idx = qt.summarize(pos, summary, theta*theta)
223+
if take_timing:
224+
t2 = clock()
225+
# Compute the t-SNE negative force
226+
# for the digits dataset, walking the tree
227+
# is about 10-15x more expensive than the
228+
# following for loop
229+
for j in range(idx // offset):
230+
231+
dist2s = summary[j * offset + n_dimensions]
232+
size = summary[j * offset + n_dimensions + 1]
233+
qijZ = float_dof / (float_dof + dist2s) # 1/(1+dist)
234+
if dof != 1: # i.e. exponent != 1
235+
qijZ = qijZ ** exponent
236+
237+
sum_Q += size * qijZ # size of the node * q
238+
mult = size * qijZ * qijZ
239+
for ax in range(n_dimensions):
240+
neg_force[ax] += mult * summary[j * offset + ax]
241+
if take_timing:
242+
t3 = clock()
243+
for ax in range(n_dimensions):
244+
neg_f[i * n_dimensions + ax] = neg_force[ax]
245+
if take_timing:
246+
dta += t2 - t1
247+
dtb += t3 - t2
248+
free(pos)
249+
free(force)
250+
free(neg_force)
251+
free(summary)
228252
if take_timing:
229253
printf("[t-SNE] Tree: %li clock ticks | ", dta)
230254
printf("Force computation: %li clock ticks\n", dtb)
231255

232256
# Put sum_Q to machine EPSILON to avoid divisions by 0
233-
sum_Q[0] = max(sum_Q[0], FLOAT64_EPS)
234-
free(summary)
257+
sum_Q = max(sum_Q, FLOAT64_EPS)
258+
return sum_Q
235259

236260

237261
def gradient(float[:] val_P,
@@ -244,7 +268,8 @@ def gradient(float[:] val_P,
244268
int verbose,
245269
int dof=1,
246270
long skip_num_points=0,
247-
bint compute_error=1):
271+
bint compute_error=1,
272+
int num_threads=1):
248273
# This function is designed to be called from external Python
249274
# it passes the 'forces' array by reference and fills thats array
250275
# up in-place
@@ -269,8 +294,11 @@ def gradient(float[:] val_P,
269294
# in the generated C code that triggers error with gcc 4.9
270295
# and -Werror=format-security
271296
printf("[t-SNE] Computing gradient\n%s", EMPTY_STRING)
297+
272298
C = compute_gradient(val_P, pos_output, neighbors, indptr, forces,
273-
qt, theta, dof, skip_num_points, -1, compute_error)
299+
qt, theta, dof, skip_num_points, -1, compute_error,
300+
num_threads)
301+
274302
if verbose > 10:
275303
# XXX: format hack to workaround lack of `const char *` type
276304
# in the generated C code

0 commit comments

Comments
 (0)