Skip to content

Commit cb0916c

Browse files
Joe Jevnik authored and amueller committed
ENH: Adds CallableTransformer
CallableTransformer allows a user to convert a standard python callable into a transformer for use in a Pipeline.
1 parent 16337b2 commit cb0916c

File tree

7 files changed

+219
-1
lines changed

7 files changed

+219
-1
lines changed

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1104,6 +1104,7 @@ See the :ref:`metrics` section of the user guide for further details.
11041104
:template: class.rst
11051105

11061106
preprocessing.Binarizer
1107+
preprocessing.CallableTransformer
11071108
preprocessing.Imputer
11081109
preprocessing.KernelCenterer
11091110
preprocessing.LabelBinarizer

doc/modules/preprocessing.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,3 +508,23 @@ The features of X have been transformed from :math:`(X_1, X_2, X_3)` to :math:`(
508508
Note that polynomial features are used implicitly in `kernel methods <http://en.wikipedia.org/wiki/Kernel_method>`_ (e.g., :class:`sklearn.svm.SVC`, :class:`sklearn.decomposition.KernelPCA`) when using polynomial :ref:`svm_kernels`.
509509

510510
See :ref:`example_linear_model_plot_polynomial_interpolation.py` for Ridge regression using created polynomial features.
511+
512+
Custom Transformers
513+
===================
514+
515+
Often, you will want to convert an existing python function into a transformer
516+
to assist in data cleaning or processing. Users may implement a transformer from
517+
an arbitrary callable with :class:`CallableTransformer`. For example, one could
518+
apply a log transformation in a pipeline like::
519+
520+
>>> import numpy as np
521+
>>> from sklearn.preprocessing import CallableTransformer
522+
>>> transformer = CallableTransformer(np.log)
523+
>>> X = np.array([[1, 2], [3, 4]])
524+
>>> transformer.transform(X)
525+
array([[ 0. , 0.69314718],
526+
[ 1.09861229, 1.38629436]])
527+
528+
For a full code example that demonstrates using a :class:`CallableTransformer`
529+
to do column selection,
530+
see :ref:`example_preprocessing_plot_callable_transformer.py`
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
=========================================================
3+
Using CallableTransformer to select columns
4+
=========================================================
5+
6+
Shows how to use a callable transformer in a pipeline. If you know your
7+
dataset's first principal component is irrelevant for a classification task,
8+
you can use the CallableTransformer to select all but the first column of the
9+
PCA transformed data.
10+
"""
11+
import matplotlib.pyplot as plt
12+
import numpy as np
13+
14+
from sklearn.cross_validation import train_test_split
15+
from sklearn.decomposition import PCA
16+
from sklearn.pipeline import make_pipeline
17+
from sklearn.preprocessing import CallableTransformer
18+
19+
20+
def _generate_vector(shift=0.5, noise=15):
    """Return the ramp 0..999 perturbed by uniform random noise.

    Parameters
    ----------
    shift : float, default=0.5
        Value subtracted from each uniform [0, 1) sample before scaling;
        the default of 0.5 centres the noise around zero.
    noise : float, default=15
        Scale applied to the shifted uniform samples.

    Returns
    -------
    ndarray of shape (1000,)
        ``np.arange(1000)`` plus the scaled, shifted random noise.
    """
    ramp = np.arange(1000)
    jitter = (np.random.rand(1000) - shift) * noise
    return ramp + jitter
22+
23+
24+
def generate_dataset():
    """Build a two-class, two-feature toy dataset.

    This dataset is two lines with a slope ~ 1, where one has
    a y offset of ~100.

    Returns
    -------
    X : ndarray of shape (2000, 2)
        The first 1000 rows are the line whose second coordinate is
        shifted by 100; the last 1000 rows are the unshifted line.
    y : ndarray of shape (2000,)
        Class labels: 0 for the shifted line, 1 for the unshifted line.
    """
    # The four noise vectors are drawn in the same order as the rows are
    # assembled: shifted line (x, y + 100) first, then the plain line.
    shifted_line = np.vstack((
        _generate_vector(),
        _generate_vector() + 100,
    )).T
    plain_line = np.vstack((
        _generate_vector(),
        _generate_vector(),
    )).T
    X = np.vstack((shifted_line, plain_line))
    y = np.hstack((np.zeros(1000), np.ones(1000)))
    return X, y
39+
40+
41+
def all_but_first_column(X, y):
    """Return every column of ``X`` except the first.

    Parameters
    ----------
    X : 2-D array
        Input data; it is sliced along its second axis.
    y : object
        Unused. Accepted because CallableTransformer calls the wrapped
        function with both ``X`` and ``y``.

    Returns
    -------
    2-D array
        ``X[:, 1:]``.
    """
    return X[:, 1:]
43+
44+
45+
def drop_first_component(X, y):
    """Project the data with PCA and discard the first component.

    Create a pipeline with PCA and the column selector and use it to
    transform the dataset. The pipeline is fitted on a training split and
    applied to the held-out split.

    Parameters
    ----------
    X : array
        Feature matrix passed to ``train_test_split``.
    y : array
        Labels passed to ``train_test_split``.

    Returns
    -------
    X_test_transformed : array
        The held-out samples after PCA with the first column removed.
    y_test : array
        The labels of the held-out samples.
    """
    pipeline = make_pipeline(
        PCA(), CallableTransformer(all_but_first_column),
    )
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    pipeline.fit(X_train, y_train)
    return pipeline.transform(X_test), y_test
56+
57+
58+
if __name__ == '__main__':
    # First figure: the raw two-feature dataset, coloured by class.
    X, y = generate_dataset()
    plt.scatter(X[:, 0], X[:, 1], c=y, s=50)
    plt.show()

    # Second figure: the remaining component after PCA with the first
    # component dropped. Reuse the dataset plotted above rather than
    # generating a new random one, so both figures show the same samples.
    X_transformed, y_transformed = drop_first_component(X, y)
    plt.scatter(
        X_transformed[:, 0],
        # Only one component is left, so plot it along a flat line.
        np.zeros(len(X_transformed)),
        c=y_transformed,
        s=50,
    )
    plt.show()

sklearn/preprocessing/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
normalization, binarization and imputation methods.
44
"""
55

6+
from .callable_transformer import CallableTransformer
7+
68
from .data import Binarizer
79
from .data import KernelCenterer
810
from .data import MinMaxScaler
@@ -28,8 +30,10 @@
2830

2931
from .imputation import Imputer
3032

33+
3134
__all__ = [
3235
'Binarizer',
36+
'CallableTransformer',
3337
'Imputer',
3438
'KernelCenterer',
3539
'LabelBinarizer',
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from ..base import BaseEstimator, TransformerMixin
2+
from ..utils import check_array
3+
4+
5+
class CallableTransformer(BaseEstimator, TransformerMixin):
    """Allows the construction of a transformer from an arbitrary callable.

    The transformer is stateless: ``fit`` only (optionally) validates its
    input, and ``transform`` forwards its arguments to ``func``.

    Parameters
    ----------
    func : callable, optional default=None
        The callable to use for the transformation. This will be passed
        the same arguments as transform, with args and kwargs forwarded.
        If func is None, then func will be the identity function.
    validate : bool, optional default=True
        Indicate that the input X array should be checked before calling
        func. If validate is false, there will be no input validation.
    accept_sparse : boolean, optional
        Indicate that func accepts a sparse matrix as input. Forwarded to
        ``check_array`` when ``validate`` is true.
    args : tuple, optional
        A tuple of positional arguments to be passed to func. These will
        be passed after X and y.
    kwargs : dict, optional
        A dictionary of keyword arguments to be passed to func.

    """
    def __init__(self, func=None, validate=True, accept_sparse=False,
                 args=None, kwargs=None):
        # Parameters are stored unmodified; defaults for ``args`` and
        # ``kwargs`` are resolved lazily in ``transform``.
        self.func = func
        self.validate = validate
        self.accept_sparse = accept_sparse
        self.args = args
        self.kwargs = kwargs

    def fit(self, X, y=None):
        """Validate ``X`` if requested; nothing is learned.

        Parameters
        ----------
        X : array-like
            Input data, checked with ``check_array`` when ``validate`` is
            true.
        y : object, optional
            Ignored.

        Returns
        -------
        self
        """
        if self.validate:
            check_array(X, accept_sparse=self.accept_sparse)
        return self

    def transform(self, X, y=None):
        """Apply ``func`` to ``X``.

        Parameters
        ----------
        X : array-like
            Input data. When ``validate`` is true it is replaced by the
            result of ``check_array`` before being handed to ``func``.
        y : object, optional
            Passed to ``func`` as its second positional argument.

        Returns
        -------
        The return value of ``func(X, y, *args, **kwargs)``, or ``X`` when
        ``func`` is None.
        """
        if self.validate:
            X = check_array(X, accept_sparse=self.accept_sparse)

        transform_func = self.func
        # Compare against None explicitly: a callable object may be falsy
        # (e.g. it defines __bool__ or __len__) and must still be used.
        if transform_func is None:
            def transform_func(X, y, *args, **kwargs):
                return X

        return transform_func(
            X, y, *(self.args or ()), **(self.kwargs or {})
        )
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
from nose.tools import assert_equal
2+
import numpy as np
3+
4+
from ..callable_transformer import CallableTransformer
5+
6+
7+
def _make_func(args_store, kwargs_store, func=lambda X, *a, **k: X):
    # Build a wrapper around ``func`` that records how it was called.
    #
    # Every call appends X followed by the extra positional arguments to
    # ``args_store`` and merges the keyword arguments into ``kwargs_store``.
    # Only X is forwarded to ``func``; its result is returned.
    def _func(X, *args, **kwargs):
        args_store.append(X)
        args_store.extend(args)
        kwargs_store.update(kwargs)
        return func(X)

    return _func
15+
16+
17+
def test_delegate_to_func():
    # With no extra args/kwargs configured, transform must call func with
    # exactly (X, None) and return whatever func returns.

    # (args|kwargs)_store will hold the positional and keyword arguments
    # passed to the function inside the CallableTransformer.
    args_store = []
    kwargs_store = {}
    X = np.arange(10).reshape((5, 2))
    np.testing.assert_array_equal(
        CallableTransformer(_make_func(args_store, kwargs_store)).transform(X),
        X,
        'transform should have returned X unchanged',
    )

    # The function should only have received X and y, where y is None.
    assert_equal(
        args_store,
        [X, None],
        'Incorrect positional arguments passed to func: {args}'.format(
            args=args_store,
        ),
    )
    assert_equal(
        kwargs_store,
        {},
        'Unexpected keyword arguments passed to func: {args}'.format(
            args=kwargs_store,
        ),
    )
44+
45+
46+
def test_argument_closure():
    # The args and kwargs given to the constructor must be forwarded to
    # func after X and y on every transform call.

    # (args|kwargs)_store will hold the positional and keyword arguments
    # passed to the function inside the CallableTransformer.
    args_store = []
    kwargs_store = {}
    # Unique sentinel objects, so the comparisons below cannot pass by
    # accident.
    args = (object(), object())
    kwargs = {'a': object(), 'b': object()}
    X = np.arange(10).reshape((5, 2))
    np.testing.assert_array_equal(
        CallableTransformer(
            _make_func(args_store, kwargs_store),
            args=args,
            kwargs=kwargs,
        ).transform(X),
        X,
        'transform should have returned X unchanged',
    )

    # The function should have been passed X, y (None), and the args
    # that were passed to the CallableTransformer.
    assert_equal(
        args_store,
        [X, None] + list(args),
        'Incorrect positional arguments passed to func: {args}'.format(
            args=args_store,
        ),
    )
    assert_equal(
        kwargs_store,
        kwargs,
        'Incorrect keyword arguments passed to func: {args}'.format(
            args=kwargs_store,
        ),
    )

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def _yield_transformer_checks(name, Transformer):
138138
'PLSCanonical', 'PLSRegression', 'CCA', 'PLSSVD']:
139139
yield check_transformer_data_not_an_array
140140
# these don't actually fit the data, so don't raise errors
141-
if name not in ['AdditiveChi2Sampler', 'Binarizer', 'Normalizer']:
141+
if name not in ['AdditiveChi2Sampler', 'Binarizer',
142+
'Normalizer', 'CallableTransformer']:
142143
# basic tests
143144
yield check_transformer_general
144145
yield check_transformers_unfitted

0 commit comments

Comments
 (0)