Skip to content

Commit 895dfd3

Browse files
thomasjpfanjnothman
authored andcommitted
ENH Adds transformer support in ColumnTransformer.remainder (scikit-learn#11315)
1 parent f89131b commit 895dfd3

File tree

4 files changed

+345
-95
lines changed

4 files changed

+345
-95
lines changed

doc/modules/compose.rst

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,22 +404,26 @@ preprocessing or a specific feature extraction method::
404404
>>> X = pd.DataFrame(
405405
... {'city': ['London', 'London', 'Paris', 'Sallisaw'],
406406
... 'title': ["His Last Bow", "How Watson Learned the Trick",
407-
... "A Moveable Feast", "The Grapes of Wrath"]})
407+
... "A Moveable Feast", "The Grapes of Wrath"],
408+
... 'expert_rating': [5, 3, 4, 5],
409+
... 'user_rating': [4, 5, 4, 3]})
408410

409411
For this data, we might want to encode the ``'city'`` column as a categorical
410412
variable, but apply a :class:`feature_extraction.text.CountVectorizer
411413
<sklearn.feature_extraction.text.CountVectorizer>` to the ``'title'`` column.
412414
As we might use multiple feature extraction methods on the same column, we give
413-
each transformer a unique name, say ``'city_category'`` and ``'title_bow'``::
415+
each transformer a unique name, say ``'city_category'`` and ``'title_bow'``.
416+
We can ignore the remaining rating columns by setting ``remainder='drop'``::
414417

415418
>>> from sklearn.compose import ColumnTransformer
416419
>>> from sklearn.feature_extraction.text import CountVectorizer
417420
>>> column_trans = ColumnTransformer(
418421
... [('city_category', CountVectorizer(analyzer=lambda x: [x]), 'city'),
419-
... ('title_bow', CountVectorizer(), 'title')])
422+
... ('title_bow', CountVectorizer(), 'title')],
423+
... remainder='drop')
420424

421425
>>> column_trans.fit(X) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
422-
ColumnTransformer(n_jobs=1, remainder='passthrough', transformer_weights=None,
426+
ColumnTransformer(n_jobs=1, remainder='drop', transformer_weights=None,
423427
transformers=...)
424428

425429
>>> column_trans.get_feature_names()
@@ -448,6 +452,39 @@ as a list of multiple items, an integer array, a slice, or a boolean mask.
448452
Strings can reference columns if the input is a DataFrame, integers are always
449453
interpreted as the positional columns.
450454

455+
We can keep the remaining rating columns by setting
456+
``remainder='passthrough'``. The values are appended to the end of the
457+
transformation::
458+
459+
>>> column_trans = ColumnTransformer(
460+
... [('city_category', CountVectorizer(analyzer=lambda x: [x]), 'city'),
461+
... ('title_bow', CountVectorizer(), 'title')],
462+
... remainder='passthrough')
463+
464+
>>> column_trans.fit_transform(X).toarray()
465+
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
466+
array([[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 5, 4],
467+
[1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 3, 5],
468+
[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 4],
469+
[0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 5, 3]]...)
470+
471+
The ``remainder`` parameter can be set to an estimator to transform the
472+
remaining rating columns. The transformed values are appended to the end of
473+
the transformation::
474+
475+
>>> from sklearn.preprocessing import MinMaxScaler
476+
>>> column_trans = ColumnTransformer(
477+
... [('city_category', CountVectorizer(analyzer=lambda x: [x]), 'city'),
478+
... ('title_bow', CountVectorizer(), 'title')],
479+
... remainder=MinMaxScaler())
480+
481+
>>> column_trans.fit_transform(X)[:, -2:].toarray()
482+
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
483+
array([[1. , 0.5],
484+
[0. , 1. ],
485+
[0.5, 0.5],
486+
[1. , 0. ]])
487+
451488
The :func:`~sklearn.compose.make_columntransformer` function is available
452489
to more easily create a :class:`~sklearn.compose.ColumnTransformer` object.
453490
Specifically, the names will be given automatically. The equivalent for the

sklearn/compose/_column_transformer.py

Lines changed: 66 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Author: Andreas Mueller
77
# Joris Van den Bossche
88
# License: BSD
9-
9+
from itertools import chain
1010

1111
import numpy as np
1212
from scipy import sparse
@@ -69,14 +69,17 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
6969
``transformer`` expects X to be a 1d array-like (vector),
7070
otherwise a 2d array will be passed to the transformer.
7171
72-
remainder : {'passthrough', 'drop'}, default 'passthrough'
72+
remainder : {'passthrough', 'drop'} or estimator, default 'passthrough'
7373
By default, all remaining columns that were not specified in
7474
`transformers` will be automatically passed through (default of
7575
``'passthrough'``). This subset of columns is concatenated with the
7676
output of the transformers.
7777
By using ``remainder='drop'``, only the specified columns in
7878
`transformers` are transformed and combined in the output, and the
7979
non-specified columns are dropped.
80+
By setting ``remainder`` to be an estimator, the remaining
81+
non-specified columns will use the ``remainder`` estimator. The
82+
estimator must support `fit` and `transform`.
8083
8184
n_jobs : int, optional
8285
Number of jobs to run in parallel (default 1).
@@ -90,7 +93,13 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
9093
----------
9194
transformers_ : list
9295
The collection of fitted transformers as tuples of
93-
(name, fitted_transformer, column).
96+
(name, fitted_transformer, column). `fitted_transformer` can be an
97+
estimator, 'drop', or 'passthrough'. If there are remaining columns,
98+
the final element is a tuple of the form:
99+
('remainder', transformer, remaining_columns) corresponding to the
100+
``remainder`` parameter. If there are remaining columns, then
101+
``len(transformers_)==len(transformers)+1``, otherwise
102+
``len(transformers_)==len(transformers)``.
94103
95104
named_transformers_ : Bunch object, a dictionary with attribute access
96105
Read-only attribute to access any transformer by given name.
@@ -188,13 +197,12 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
188197
transformers = self.transformers_
189198
else:
190199
transformers = self.transformers
200+
if self._remainder[2] is not None:
201+
transformers = chain(transformers, [self._remainder])
191202
get_weight = (self.transformer_weights or {}).get
192203

193204
for name, trans, column in transformers:
194-
if X is None:
195-
sub = X
196-
else:
197-
sub = _get_column(X, column)
205+
sub = None if X is None else _get_column(X, column)
198206

199207
if replace_strings:
200208
# replace 'passthrough' with identity transformer and
@@ -209,7 +217,10 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
209217
yield (name, trans, sub, get_weight(name))
210218

211219
def _validate_transformers(self):
212-
names, transformers, _, _ = zip(*self._iter())
220+
if not self.transformers:
221+
return
222+
223+
names, transformers, _ = zip(*self.transformers)
213224

214225
# validate names
215226
self._validate_names(names)
@@ -226,24 +237,27 @@ def _validate_transformers(self):
226237
(t, type(t)))
227238

228239
def _validate_remainder(self, X):
229-
"""Generate list of passthrough columns for 'remainder' case."""
230-
if self.remainder not in ('drop', 'passthrough'):
240+
"""
241+
Validates ``remainder`` and defines ``_remainder`` targeting
242+
the remaining columns.
243+
"""
244+
is_transformer = ((hasattr(self.remainder, "fit")
245+
or hasattr(self.remainder, "fit_transform"))
246+
and hasattr(self.remainder, "transform"))
247+
if (self.remainder not in ('drop', 'passthrough')
248+
and not is_transformer):
231249
raise ValueError(
232-
"The remainder keyword needs to be one of 'drop' or "
233-
"'passthrough'. {0:r} was passed instead")
250+
"The remainder keyword needs to be one of 'drop', "
251+
"'passthrough', or estimator. '%s' was passed instead" %
252+
self.remainder)
234253

235254
n_columns = X.shape[1]
255+
cols = []
256+
for _, _, columns in self.transformers:
257+
cols.extend(_get_column_indices(X, columns))
258+
remaining_idx = sorted(list(set(range(n_columns)) - set(cols))) or None
236259

237-
if self.remainder == 'passthrough':
238-
cols = []
239-
for _, _, columns in self.transformers:
240-
cols.extend(_get_column_indices(X, columns))
241-
self._passthrough = sorted(list(set(range(n_columns)) - set(cols)))
242-
if not self._passthrough:
243-
# empty list -> no need to select passthrough columns
244-
self._passthrough = None
245-
else:
246-
self._passthrough = None
260+
self._remainder = ('remainder', self.remainder, remaining_idx)
247261

248262
@property
249263
def named_transformers_(self):
@@ -267,12 +281,6 @@ def get_feature_names(self):
267281
Names of the features produced by transform.
268282
"""
269283
check_is_fitted(self, 'transformers_')
270-
if self._passthrough is not None:
271-
raise NotImplementedError(
272-
"get_feature_names is not yet supported when having columns"
273-
"that are passed through (you specify remainder='drop' to not "
274-
"pass through the unspecified columns).")
275-
276284
feature_names = []
277285
for name, trans, _, _ in self._iter(fitted=True):
278286
if trans == 'drop':
@@ -294,7 +302,11 @@ def _update_fitted_transformers(self, transformers):
294302
transformers = iter(transformers)
295303
transformers_ = []
296304

297-
for name, old, column in self.transformers:
305+
transformer_iter = self.transformers
306+
if self._remainder[2] is not None:
307+
transformer_iter = chain(transformer_iter, [self._remainder])
308+
309+
for name, old, column in transformer_iter:
298310
if old == 'drop':
299311
trans = 'drop'
300312
elif old == 'passthrough':
@@ -304,7 +316,6 @@ def _update_fitted_transformers(self, transformers):
304316
trans = 'passthrough'
305317
else:
306318
trans = next(transformers)
307-
308319
transformers_.append((name, trans, column))
309320

310321
# sanity check that transformers is exhausted
@@ -335,7 +346,7 @@ def _fit_transform(self, X, y, func, fitted=False):
335346
return Parallel(n_jobs=self.n_jobs)(
336347
delayed(func)(clone(trans) if not fitted else trans,
337348
X_sel, y, weight)
338-
for name, trans, X_sel, weight in self._iter(
349+
for _, trans, X_sel, weight in self._iter(
339350
X=X, fitted=fitted, replace_strings=True))
340351
except ValueError as e:
341352
if "Expected 2D array, got 1D array instead" in str(e):
@@ -361,12 +372,12 @@ def fit(self, X, y=None):
361372
This estimator
362373
363374
"""
364-
self._validate_transformers()
365375
self._validate_remainder(X)
376+
self._validate_transformers()
366377

367378
transformers = self._fit_transform(X, y, _fit_one_transformer)
368-
369379
self._update_fitted_transformers(transformers)
380+
370381
return self
371382

372383
def fit_transform(self, X, y=None):
@@ -390,31 +401,21 @@ def fit_transform(self, X, y=None):
390401
sparse matrices.
391402
392403
"""
393-
self._validate_transformers()
394404
self._validate_remainder(X)
405+
self._validate_transformers()
395406

396407
result = self._fit_transform(X, y, _fit_transform_one)
397408

398409
if not result:
399410
# All transformers are None
400-
if self._passthrough is None:
401-
return np.zeros((X.shape[0], 0))
402-
else:
403-
return _get_column(X, self._passthrough)
411+
return np.zeros((X.shape[0], 0))
404412

405413
Xs, transformers = zip(*result)
406414

407415
self._update_fitted_transformers(transformers)
408416
self._validate_output(Xs)
409417

410-
if self._passthrough is not None:
411-
Xs = list(Xs) + [_get_column(X, self._passthrough)]
412-
413-
if any(sparse.issparse(f) for f in Xs):
414-
Xs = sparse.hstack(Xs).tocsr()
415-
else:
416-
Xs = np.hstack(Xs)
417-
return Xs
418+
return _hstack(list(Xs))
418419

419420
def transform(self, X):
420421
"""Transform X separately by each transformer, concatenate results.
@@ -440,19 +441,9 @@ def transform(self, X):
440441

441442
if not Xs:
442443
# All transformers are None
443-
if self._passthrough is None:
444-
return np.zeros((X.shape[0], 0))
445-
else:
446-
return _get_column(X, self._passthrough)
447-
448-
if self._passthrough is not None:
449-
Xs = list(Xs) + [_get_column(X, self._passthrough)]
444+
return np.zeros((X.shape[0], 0))
450445

451-
if any(sparse.issparse(f) for f in Xs):
452-
Xs = sparse.hstack(Xs).tocsr()
453-
else:
454-
Xs = np.hstack(Xs)
455-
return Xs
446+
return _hstack(list(Xs))
456447

457448

458449
def _check_key_type(key, superclass):
@@ -486,6 +477,19 @@ def _check_key_type(key, superclass):
486477
return False
487478

488479

480+
def _hstack(X):
481+
"""
482+
Stacks X horizontally.
483+
484+
Supports input types (X): list of
485+
numpy arrays, sparse arrays and DataFrames
486+
"""
487+
if any(sparse.issparse(f) for f in X):
488+
return sparse.hstack(X).tocsr()
489+
else:
490+
return np.hstack(X)
491+
492+
489493
def _get_column(X, key):
490494
"""
491495
Get feature column(s) from input data X.
@@ -612,14 +616,17 @@ def make_column_transformer(*transformers, **kwargs):
612616
----------
613617
*transformers : tuples of column selections and transformers
614618
615-
remainder : {'passthrough', 'drop'}, default 'passthrough'
619+
remainder : {'passthrough', 'drop'} or estimator, default 'passthrough'
616620
By default, all remaining columns that were not specified in
617621
`transformers` will be automatically passed through (default of
618622
``'passthrough'``). This subset of columns is concatenated with the
619623
output of the transformers.
620624
By using ``remainder='drop'``, only the specified columns in
621625
`transformers` are transformed and combined in the output, and the
622626
non-specified columns are dropped.
627+
By setting ``remainder`` to be an estimator, the remaining
628+
non-specified columns will use the ``remainder`` estimator. The
629+
estimator must support `fit` and `transform`.
623630
624631
n_jobs : int, optional
625632
Number of jobs to run in parallel (default 1).

0 commit comments

Comments
 (0)