Skip to content

Commit 9d14f73

Browse files
committed
ENH: Renames CallableTransformer -> FunctionTransformer.
Makes `pass_y` an argument to FunctionTransformer to indicate that the labels should be passed to the wrapped function.
1 parent cb0916c commit 9d14f73

File tree

7 files changed

+63
-51
lines changed

7 files changed

+63
-51
lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1104,7 +1104,7 @@ See the :ref:`metrics` section of the user guide for further details.
11041104
:template: class.rst
11051105

11061106
preprocessing.Binarizer
1107-
preprocessing.CallableTransformer
1107+
preprocessing.FunctionTransformer
11081108
preprocessing.Imputer
11091109
preprocessing.KernelCenterer
11101110
preprocessing.LabelBinarizer

doc/modules/preprocessing.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -514,17 +514,17 @@ Custom Transformers
514514

515515
Often, you will want to convert an existing python function into a transformer
516516
to assist in data cleaning or processing. Users may implement a transformer from
517-
an arbitrary callable with :class:`CallableTransformer`. For example, one could
517+
an arbitrary function with :class:`FunctionTransformer`. For example, one could
518518
apply a log transformation in a pipeline like::
519519

520520
>>> import numpy as np
521-
>>> from sklearn.preprocessing import CallableTransformer
522-
>>> transformer = CallableTransformer(np.log)
521+
>>> from sklearn.preprocessing import FunctionTransformer
522+
>>> transformer = FunctionTransformer(np.log)
523523
>>> X = np.array([[1, 2], [3, 4]])
524524
>>> transformer.transform(X)
525525
array([[ 0. , 0.69314718],
526526
[ 1.09861229, 1.38629436]])
527527

528-
For a full code example that demonstrates using a :class:`CallableTransformer`
528+
For a full code example that demonstrates using a :class:`FunctionTransformer`
529529
to do column selection,
530-
see :ref:`example_preprocessing_plot_callable_transformer.py`
530+
see :ref:`example_preprocessing_plot_function_transformer.py`

examples/preprocessing/plot_callable_transformer.py renamed to examples/preprocessing/plot_function_transformer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""
22
=========================================================
3-
Using CallableTransformer to select columns
3+
Using FunctionTransformer to select columns
44
=========================================================
55
6-
Shows how to use a callable transformer in a pipeline. If you know your
6+
Shows how to use a function transformer in a pipeline. If you know your
77
dataset's first principal component is irrelevant for a classification task,
8-
you can use the CallableTransformer to select all but the first column of the
8+
you can use the FunctionTransformer to select all but the first column of the
99
PCA transformed data.
1010
"""
1111
import matplotlib.pyplot as plt
@@ -14,7 +14,7 @@
1414
from sklearn.cross_validation import train_test_split
1515
from sklearn.decomposition import PCA
1616
from sklearn.pipeline import make_pipeline
17-
from sklearn.preprocessing import CallableTransformer
17+
from sklearn.preprocessing import FunctionTransformer
1818

1919

2020
def _generate_vector(shift=0.5, noise=15):
@@ -38,7 +38,7 @@ def generate_dataset():
3838
)), np.hstack((np.zeros(1000), np.ones(1000)))
3939

4040

41-
def all_but_first_column(X, y):
41+
def all_but_first_column(X):
4242
return X[:, 1:]
4343

4444

@@ -48,7 +48,7 @@ def drop_first_component(X, y):
4848
transform the dataset.
4949
"""
5050
pipeline = make_pipeline(
51-
PCA(), CallableTransformer(all_but_first_column),
51+
PCA(), FunctionTransformer(all_but_first_column),
5252
)
5353
X_train, X_test, y_train, y_test = train_test_split(X, y)
5454
pipeline.fit(X_train, y_train)

sklearn/preprocessing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
normalization, binarization and imputation methods.
44
"""
55

6-
from .callable_transformer import CallableTransformer
6+
from .function_transformer import FunctionTransformer
77

88
from .data import Binarizer
99
from .data import KernelCenterer
@@ -33,7 +33,7 @@
3333

3434
__all__ = [
3535
'Binarizer',
36-
'CallableTransformer',
36+
'FunctionTransformer',
3737
'Imputer',
3838
'KernelCenterer',
3939
'LabelBinarizer',

sklearn/preprocessing/callable_transformer.py renamed to sklearn/preprocessing/function_transformer.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,43 @@
22
from ..utils import check_array
33

44

5-
class CallableTransformer(BaseEstimator, TransformerMixin):
6-
"""Allows the construction of a transformer from an arbitrary callable.
5+
def _identity(X):
6+
"""The identity function.
7+
"""
8+
return X
9+
10+
11+
class FunctionTransformer(BaseEstimator, TransformerMixin):
12+
"""Constructs a transformer from an arbitrary callable.
13+
14+
Note: If a lambda is used as the function, then the resulting
15+
transformer will not be pickleable.
716
817
Parameters
918
----------
1019
func : callable, optional default=None
1120
The callable to use for the transformation. This will be passed
1221
the X argument given to transform, plus y when pass_y is True.
1322
If func is None, then func will be the identity function.
23+
1424
validate : bool, optional default=True
1525
Indicate that the input X array should be checked before calling
1626
func. If validate is false, there will be no input validation.
27+
1728
accept_sparse : boolean, optional
1829
Indicate that func accepts a sparse matrix as input.
19-
args : tuple, optional
20-
A tuple of positional arguments to be passed to func. These will
21-
be passed after X and y.
22-
kwargs : dict, optional
23-
A dictionary of keyword arguments to be passed to func.
30+
31+
pass_y : bool, optional default=False
32+
Indicate that transform should forward the y argument to the
33+
inner callable.
2434
2535
"""
26-
def __init__(self, func=None, validate=True, accept_sparse=False,
27-
args=None, kwargs=None):
36+
def __init__(self, func=None, validate=True,
37+
accept_sparse=False, pass_y=False):
2838
self.func = func
2939
self.validate = validate
3040
self.accept_sparse = accept_sparse
31-
self.args = args
32-
self.kwargs = kwargs
41+
self.pass_y = pass_y
3342

3443
def fit(self, X, y=None):
3544
if self.validate:
@@ -39,6 +48,5 @@ def fit(self, X, y=None):
3948
def transform(self, X, y=None):
4049
if self.validate:
4150
X = check_array(X, self.accept_sparse)
42-
return (self.func or (lambda X, y, *args, **kwargs: X))(
43-
X, y, *(self.args or ()), **(self.kwargs or {})
44-
)
51+
52+
return (self.func or _identity)(X, *((y,) if self.pass_y else ()))
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from nose.tools import assert_equal
22
import numpy as np
33

4-
from ..callable_transformer import CallableTransformer
4+
from ..function_transformer import FunctionTransformer
55

66

77
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
@@ -16,20 +16,20 @@ def _func(X, *args, **kwargs):
1616

1717
def test_delegate_to_func():
1818
# (args|kwargs)_store will hold the positional and keyword arguments
19-
# passed to the function inside the CallableTransformer.
19+
# passed to the function inside the FunctionTransformer.
2020
args_store = []
2121
kwargs_store = {}
2222
X = np.arange(10).reshape((5, 2))
2323
np.testing.assert_array_equal(
24-
CallableTransformer(_make_func(args_store, kwargs_store)).transform(X),
24+
FunctionTransformer(_make_func(args_store, kwargs_store)).transform(X),
2525
X,
2626
'transform should have returned X unchanged',
2727
)
2828

29-
# The function should only have recieved X and y, where y is None.
29+
# The function should only have received X.
3030
assert_equal(
3131
args_store,
32-
[X, None],
32+
[X],
3333
'Incorrect positional arguments passed to func: {args}'.format(
3434
args=args_store,
3535
),
@@ -42,38 +42,42 @@ def test_delegate_to_func():
4242
),
4343
)
4444

45+
# reset the argument stores.
46+
args_store.clear()
47+
kwargs_store.clear()
48+
y = object()
4549

46-
def test_argument_closure():
47-
# (args|kwargs)_store will hold the positional and keyword arguments
48-
# passed to the function inside the CallableTransformer.
49-
args_store = []
50-
kwargs_store = {}
51-
args = (object(), object())
52-
kwargs = {'a': object(), 'b': object()}
53-
X = np.arange(10).reshape((5, 2))
5450
np.testing.assert_array_equal(
55-
CallableTransformer(
51+
FunctionTransformer(
5652
_make_func(args_store, kwargs_store),
57-
args=args,
58-
kwargs=kwargs,
59-
).transform(X),
53+
pass_y=True,
54+
).transform(X, y),
6055
X,
6156
'transform should have returned X unchanged',
6257
)
6358

64-
# The function should have been passed X, y (None), and the args
65-
# that were passed to the CallableTransformer.
59+
# The function should have received X and y.
6660
assert_equal(
6761
args_store,
68-
[X, None] + list(args),
62+
[X, y],
6963
'Incorrect positional arguments passed to func: {args}'.format(
7064
args=args_store,
7165
),
7266
)
7367
assert_equal(
7468
kwargs_store,
75-
kwargs,
76-
'Incorrect keyword arguments passed to func: {args}'.format(
69+
{},
70+
'Unexpected keyword arguments passed to func: {args}'.format(
7771
args=kwargs_store,
7872
),
7973
)
74+
75+
76+
def test_np_log():
77+
X = np.arange(10).reshape((5, 2))
78+
79+
# Test that the numpy.log example still works.
80+
np.testing.assert_array_equal(
81+
FunctionTransformer(np.log).transform(X),
82+
np.log(X),
83+
)

sklearn/utils/estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def _yield_transformer_checks(name, Transformer):
139139
yield check_transformer_data_not_an_array
140140
# these don't actually fit the data, so don't raise errors
141141
if name not in ['AdditiveChi2Sampler', 'Binarizer',
142-
'Normalizer', 'CallableTransformer']:
142+
'FunctionTransformer', 'Normalizer']:
143143
# basic tests
144144
yield check_transformer_general
145145
yield check_transformers_unfitted

0 commit comments

Comments
 (0)