Skip to content

Commit f5300d1

Browse files
stefanvlarsmans
authored andcommitted
ENH Speed up and simplify cartesian product
1 parent 5c5e83c commit f5300d1

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

sklearn/utils/extmath.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Alexandre T. Passos
77
# Olivier Grisel
88
# Lars Buitinck
9+
# Stefan van der Walt
910
# License: BSD 3 clause
1011

1112
from functools import partial
@@ -501,24 +502,20 @@ def cartesian(arrays, out=None):
501502
[3, 5, 6],
502503
[3, 5, 7]])
503504
504-
References
505-
----------
506-
http://stackoverflow.com/q/1208118
507-
508505
"""
509-
arrays = [np.asarray(x).ravel() for x in arrays]
506+
arrays = [np.asarray(x) for x in arrays]
507+
shape = (len(x) for x in arrays)
510508
dtype = arrays[0].dtype
511509

512-
n = np.prod([x.size for x in arrays])
510+
ix = np.indices(shape)
511+
ix = ix.reshape(len(arrays), -1).T
512+
513513
if out is None:
514-
out = np.empty([n, len(arrays)], dtype=dtype)
515-
516-
m = n // arrays[0].size
517-
out[:, 0] = np.repeat(arrays[0], m)
518-
if arrays[1:]:
519-
cartesian(arrays[1:], out=out[0:m, 1:])
520-
for j in xrange(1, arrays[0].size):
521-
out[j * m:(j + 1) * m, 1:] = out[0:m, 1:]
514+
out = np.empty_like(ix, dtype=dtype)
515+
516+
for n, arr in enumerate(arrays):
517+
out[:, n] = arrays[n][ix[:, n]]
518+
522519
return out
523520

524521

0 commit comments

Comments
 (0)