Skip to content

Commit 5ea6b3a

Browse files
jnothmanarjoly
authored andcommitted
COSMIT in response to @arjoly's comments
1 parent e633f25 commit 5ea6b3a

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

sklearn/datasets/samples_generator.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -311,34 +311,34 @@ def sample_example():
311311
_, n_classes = p_w_c.shape
312312

313313
# pick a nonzero number of labels per document by rejection sampling
314-
n = n_classes + 1
315-
while (not allow_unlabeled and n == 0) or n > n_classes:
316-
n = generator.poisson(n_labels)
314+
y_size = n_classes + 1
315+
while (not allow_unlabeled and y_size == 0) or y_size > n_classes:
316+
y_size = generator.poisson(n_labels)
317317

318318
# pick n classes
319-
y = []
320-
while len(y) != n:
319+
y = set()
320+
while len(y) != y_size:
321321
# pick a class with probability P(c)
322-
c = np.searchsorted(cumulative_p_c, generator.rand())
323-
324-
if not c in y:
325-
y.append(c)
322+
c = np.searchsorted(cumulative_p_c,
323+
generator.rand(y_size - len(y)))
324+
y.update(c)
325+
y = list(y)
326326

327327
# pick a non-zero document length by rejection sampling
328-
k = 0
329-
while k == 0:
330-
k = generator.poisson(length)
328+
n_words = 0
329+
while n_words == 0:
330+
n_words = generator.poisson(length)
331331

332-
# generate a document of length k words
332+
# generate a document of length n_words
333333
if len(y) == 0:
334334
# if sample does not belong to any class, generate noise word
335-
words = generator.randint(n_features, size=k)
335+
words = generator.randint(n_features, size=n_words)
336336
return words, y
337337

338338
# sample words with replacement from selected classes
339-
cumulative_p_w_sample = np.cumsum(p_w_c[:, y].sum(axis=1))
339+
cumulative_p_w_sample = p_w_c.take(y, axis=1).sum(axis=1).cumsum()
340340
cumulative_p_w_sample /= cumulative_p_w_sample[-1]
341-
words = np.searchsorted(cumulative_p_w_sample, generator.rand(k))
341+
words = np.searchsorted(cumulative_p_w_sample, generator.rand(n_words))
342342
return words, y
343343

344344
X_indices = array.array('i')

0 commit comments

Comments
 (0)