Skip to content

Commit cdd2279

Browse files
committed
ENH fix astype usage to prevent copying
Removed astype calls where possible; where a cast is still needed, use the astype helper from utils.fixes with copy=False. Also fixes some C integer types in gradient boosting.
1 parent 4e19ef4 commit cdd2279

File tree

14 files changed

+10910
-9755
lines changed

14 files changed

+10910
-9755
lines changed

.gitattributes

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
/sklearn/linear_model/sgd_fast.c -diff
1111
/sklearn/metrics/pairwise_fast.c -diff
1212
/sklearn/neighbors/ball_tree.c -diff
13+
/sklearn/neighbors/kd_tree.c -diff
1314
/sklearn/svm/liblinear.c -diff
1415
/sklearn/svm/libsvm.c -diff
1516
/sklearn/svm/libsvm_sparse.c -diff

sklearn/covariance/robust_covariance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def reweight_covariance(self, data):
681681
location_reweighted = data[mask].mean(0)
682682
covariance_reweighted = self._nonrobust_covariance(
683683
data[mask], assume_centered=self.assume_centered)
684-
support_reweighted = np.zeros(n_samples).astype(bool)
684+
support_reweighted = np.zeros(n_samples, dtype=bool)
685685
support_reweighted[mask] = True
686686
self._set_covariance(covariance_reweighted)
687687
self.location_ = location_reweighted

sklearn/datasets/samples_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ..preprocessing import MultiLabelBinarizer
1515
from ..utils import array2d, check_random_state
1616
from ..utils import shuffle as util_shuffle
17+
from ..utils.fixes import astype
1718
from ..utils.random import sample_without_replacement
1819
from ..externals import six
1920
map = six.moves.map
@@ -26,8 +27,9 @@ def _generate_hypercube(samples, dimensions, rng):
2627
if dimensions > 30:
2728
return np.hstack([_generate_hypercube(samples, dimensions - 30, rng),
2829
_generate_hypercube(samples, 30, rng)])
29-
out = sample_without_replacement(2 ** dimensions, samples,
30-
random_state=rng).astype('>u4')
30+
out = astype(sample_without_replacement(2 ** dimensions, samples,
31+
random_state=rng),
32+
dtype='>u4', copy=False)
3133
out = np.unpackbits(out.view('>u1')).reshape((-1, 32))[:, -dimensions:]
3234
return out
3335

sklearn/ensemble/_gradient_boosting.c

Lines changed: 2009 additions & 2237 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/ensemble/_gradient_boosting.pyx

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@ from sklearn.tree._tree cimport Tree, Node
1616

1717
ctypedef np.int32_t int32
1818
ctypedef np.float64_t float64
19-
ctypedef np.int8_t int8
20-
21-
from numpy import bool as np_bool
19+
ctypedef np.uint8_t uint8
2220

2321
# no namespace lookup for numpy dtype and array creation
2422
from numpy import zeros as np_zeros
2523
from numpy import ones as np_ones
2624
from numpy import bool as np_bool
27-
from numpy import int8 as np_int8
28-
from numpy import intp as np_intp
2925
from numpy import float32 as np_float32
3026
from numpy import float64 as np_float64
3127

@@ -267,7 +263,8 @@ cpdef _partial_dependence_tree(Tree tree, DTYPE_t[:, ::1] X,
267263
total_weight)
268264

269265

270-
def _random_sample_mask(int n_total_samples, int n_total_in_bag, random_state):
266+
def _random_sample_mask(np.npy_intp n_total_samples,
267+
np.npy_intp n_total_in_bag, random_state):
271268
"""Create a random sample mask where ``n_total_in_bag`` elements are set.
272269
273270
Parameters
@@ -289,15 +286,15 @@ def _random_sample_mask(int n_total_samples, int n_total_in_bag, random_state):
289286
"""
290287
cdef np.ndarray[float64, ndim=1, mode="c"] rand = \
291288
random_state.rand(n_total_samples)
292-
cdef np.ndarray[int8, ndim=1, mode="c"] sample_mask = \
293-
np_zeros((n_total_samples,), dtype=np_int8)
289+
cdef np.ndarray[uint8, ndim=1, mode="c", cast=True] sample_mask = \
290+
np_zeros((n_total_samples,), dtype=np_bool)
294291

295-
cdef int n_bagged = 0
296-
cdef int i = 0
292+
cdef np.npy_intp n_bagged = 0
293+
cdef np.npy_intp i = 0
297294

298-
for i from 0 <= i < n_total_samples:
295+
for i in range(n_total_samples):
299296
if rand[i] * (n_total_samples - i) < (n_total_in_bag - n_bagged):
300297
sample_mask[i] = 1
301298
n_bagged += 1
302299

303-
return sample_mask.astype(np_bool)
300+
return sample_mask

sklearn/feature_extraction/image.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from numpy.lib.stride_tricks import as_strided
1717

1818
from ..utils import array2d, check_random_state
19+
from ..utils.fixes import astype
1920
from ..base import BaseEstimator
2021

2122
__all__ = ['PatchExtractor',
@@ -107,7 +108,8 @@ def _to_graph(n_x, n_y, n_z, mask=None, img=None,
107108
n_voxels = diag.size
108109
else:
109110
if mask is not None:
110-
mask = mask.astype(np.bool)
111+
mask = astype(mask, dtype=np.bool, copy=False)
112+
mask = np.asarray(mask, dtype=np.bool)
111113
edges = _mask_edges_weights(mask, edges)
112114
n_voxels = np.sum(mask)
113115
else:
@@ -147,7 +149,7 @@ def img_to_graph(img, mask=None, return_as=sparse.coo_matrix, dtype=None):
147149
dtype of img
148150
149151
Notes
150-
===========
152+
=====
151153
For sklearn versions 0.14.1 and prior, return_as=np.ndarray was handled
152154
by returning a dense np.matrix instance. Going forward, np.ndarray
153155
returns an np.ndarray, as expected.
@@ -183,7 +185,7 @@ def grid_to_graph(n_x, n_y, n_z=1, mask=None, return_as=sparse.coo_matrix,
183185
The data of the returned sparse matrix. By default it is int
184186
185187
Notes
186-
===========
188+
=====
187189
For sklearn versions 0.14.1 and prior, return_as=np.ndarray was handled
188190
by returning a dense np.matrix instance. Going forward, np.ndarray
189191
returns an np.ndarray, as expected.

sklearn/gaussian_process/gaussian_process.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def l1_cross_distances(X):
5252
ij[ll_0:ll_1, 1] = np.arange(k + 1, n_samples)
5353
D[ll_0:ll_1] = np.abs(X[k] - X[(k + 1):n_samples])
5454

55-
return D, ij.astype(np.int)
55+
return D, ij
5656

5757

5858
class GaussianProcess(BaseEstimator, RegressorMixin):

sklearn/learning_curve.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
#
55
# License: BSD 3 clause
66

7-
import numpy as np
87
import warnings
8+
9+
import numpy as np
10+
911
from .base import is_classifier, clone
1012
from .cross_validation import _check_cv
11-
from .utils import check_arrays
1213
from .externals.joblib import Parallel, delayed
1314
from .cross_validation import _safe_split, _score, _fit_and_score
1415
from .metrics.scorer import check_scoring
16+
from .utils import check_arrays
17+
from .utils.fixes import astype
1518

1619

1720
def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5),
@@ -175,8 +178,8 @@ def _translate_train_sizes(train_sizes, n_max_training_samples):
175178
"must be within (0, 1], but is within [%f, %f]."
176179
% (n_min_required_samples,
177180
n_max_required_samples))
178-
train_sizes_abs = (train_sizes_abs
179-
* n_max_training_samples).astype(np.int)
181+
train_sizes_abs = astype(train_sizes_abs * n_max_training_samples,
182+
dtype=np.int, copy=False)
180183
train_sizes_abs = np.clip(train_sizes_abs, 1,
181184
n_max_training_samples)
182185
else:

sklearn/linear_model/randomized_l1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _resample_model(estimator_func, X, y, scaling=.5, n_resampling=200,
4949
verbose=max(0, verbose - 1),
5050
**params)
5151
for _ in range(n_resampling)):
52-
scores_ += active_set.astype(np.float)
52+
scores_ += active_set
5353

5454
scores_ /= n_resampling
5555
return scores_

sklearn/naive_bayes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,6 @@ def fit(self, X, y, sample_weight=None):
298298
Returns self.
299299
"""
300300
X, y = check_arrays(X, y, sparse_format='csr')
301-
X = X.astype(np.float)
302301
y = column_or_1d(y, warn=True)
303302
_, n_features = X.shape
304303

@@ -308,7 +307,8 @@ def fit(self, X, y, sample_weight=None):
308307
if Y.shape[1] == 1:
309308
Y = np.concatenate((1 - Y, Y), axis=1)
310309

311-
# convert to float to support sample weight consistently
310+
# convert to float to support sample weight consistently;
311+
# this means we also don't have to cast X to floating point
312312
Y = Y.astype(np.float64)
313313
if sample_weight is not None:
314314
Y *= array2d(sample_weight).T

0 commit comments

Comments
 (0)