66# Author: Andreas Mueller
77# Joris Van den Bossche
88# License: BSD
9-
9+ from itertools import chain
1010
1111import numpy as np
1212from scipy import sparse
@@ -69,14 +69,17 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
6969 ``transformer`` expects X to be a 1d array-like (vector),
7070 otherwise a 2d array will be passed to the transformer.
7171
72- remainder : {'passthrough', 'drop'}, default 'passthrough'
72+ remainder : {'passthrough', 'drop'} or estimator , default 'passthrough'
7373 By default, all remaining columns that were not specified in
7474 `transformers` will be automatically passed through (default of
7575 ``'passthrough'``). This subset of columns is concatenated with the
7676 output of the transformers.
7777 By using ``remainder='drop'``, only the specified columns in
7878 `transformers` are transformed and combined in the output, and the
7979 non-specified columns are dropped.
80+ By setting ``remainder`` to be an estimator, the remaining
81+ non-specified columns will use the ``remainder`` estimator. The
82+ estimator must support `fit` and `transform`.
8083
8184 n_jobs : int, optional
8285 Number of jobs to run in parallel (default 1).
@@ -90,7 +93,13 @@ class ColumnTransformer(_BaseComposition, TransformerMixin):
9093 ----------
9194 transformers_ : list
9295 The collection of fitted transformers as tuples of
93- (name, fitted_transformer, column).
96+ (name, fitted_transformer, column). `fitted_transformer` can be an
97+ estimator, 'drop', or 'passthrough'. If there are remaining columns,
98+ the final element is a tuple of the form:
99+ ('remainder', transformer, remaining_columns) corresponding to the
100+ ``remainder`` parameter. If there are remaining columns, then
101+ ``len(transformers_)==len(transformers)+1``, otherwise
102+ ``len(transformers_)==len(transformers)``.
94103
95104 named_transformers_ : Bunch object, a dictionary with attribute access
96105 Read-only attribute to access any transformer by given name.
@@ -188,13 +197,12 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
188197 transformers = self .transformers_
189198 else :
190199 transformers = self .transformers
200+ if self ._remainder [2 ] is not None :
201+ transformers = chain (transformers , [self ._remainder ])
191202 get_weight = (self .transformer_weights or {}).get
192203
193204 for name , trans , column in transformers :
194- if X is None :
195- sub = X
196- else :
197- sub = _get_column (X , column )
205+ sub = None if X is None else _get_column (X , column )
198206
199207 if replace_strings :
200208 # replace 'passthrough' with identity transformer and
@@ -209,7 +217,10 @@ def _iter(self, X=None, fitted=False, replace_strings=False):
209217 yield (name , trans , sub , get_weight (name ))
210218
211219 def _validate_transformers (self ):
212- names , transformers , _ , _ = zip (* self ._iter ())
220+ if not self .transformers :
221+ return
222+
223+ names , transformers , _ = zip (* self .transformers )
213224
214225 # validate names
215226 self ._validate_names (names )
@@ -226,24 +237,27 @@ def _validate_transformers(self):
226237 (t , type (t )))
227238
def _validate_remainder(self, X):
    """Validate ``remainder`` and define ``_remainder``.

    ``remainder`` is accepted when it is the string ``'drop'`` or
    ``'passthrough'``, or when it is an object exposing ``transform``
    together with ``fit`` or ``fit_transform``.

    Parameters
    ----------
    X : object with a ``shape`` attribute
        Input data; ``X.shape[1]`` is taken as the number of columns.

    Raises
    ------
    ValueError
        If ``remainder`` is neither one of the accepted strings nor an
        object with the required methods.

    Notes
    -----
    Sets ``self._remainder`` to the tuple
    ``('remainder', self.remainder, remaining_idx)`` where
    ``remaining_idx`` is the sorted list of column indices not selected by
    any entry of ``self.transformers``, or ``None`` when every column is
    already selected.
    """
    remainder = self.remainder
    is_transformer = ((hasattr(remainder, "fit")
                       or hasattr(remainder, "fit_transform"))
                      and hasattr(remainder, "transform"))
    # Test for an estimator first so estimator objects are never compared
    # against the string literals.
    if not (is_transformer or remainder in ('drop', 'passthrough')):
        # Wrap in a 1-tuple: a bare ``% remainder`` would unpack a tuple
        # value and raise TypeError instead of the intended ValueError.
        raise ValueError(
            "The remainder keyword needs to be one of 'drop', "
            "'passthrough', or estimator. '%s' was passed instead" %
            (remainder,))

    n_columns = X.shape[1]
    cols = []
    for _, _, columns in self.transformers:
        cols.extend(_get_column_indices(X, columns))
    # An empty list collapses to None: nothing is left for the remainder.
    remaining_idx = sorted(set(range(n_columns)) - set(cols)) or None

    self._remainder = ('remainder', remainder, remaining_idx)
247261
248262 @property
249263 def named_transformers_ (self ):
@@ -267,12 +281,6 @@ def get_feature_names(self):
267281 Names of the features produced by transform.
268282 """
269283 check_is_fitted (self , 'transformers_' )
270- if self ._passthrough is not None :
271- raise NotImplementedError (
272- "get_feature_names is not yet supported when having columns"
273- "that are passed through (you specify remainder='drop' to not "
274- "pass through the unspecified columns)." )
275-
276284 feature_names = []
277285 for name , trans , _ , _ in self ._iter (fitted = True ):
278286 if trans == 'drop' :
@@ -294,7 +302,11 @@ def _update_fitted_transformers(self, transformers):
294302 transformers = iter (transformers )
295303 transformers_ = []
296304
297- for name , old , column in self .transformers :
305+ transformer_iter = self .transformers
306+ if self ._remainder [2 ] is not None :
307+ transformer_iter = chain (transformer_iter , [self ._remainder ])
308+
309+ for name , old , column in transformer_iter :
298310 if old == 'drop' :
299311 trans = 'drop'
300312 elif old == 'passthrough' :
@@ -304,7 +316,6 @@ def _update_fitted_transformers(self, transformers):
304316 trans = 'passthrough'
305317 else :
306318 trans = next (transformers )
307-
308319 transformers_ .append ((name , trans , column ))
309320
310321 # sanity check that transformers is exhausted
@@ -335,7 +346,7 @@ def _fit_transform(self, X, y, func, fitted=False):
335346 return Parallel (n_jobs = self .n_jobs )(
336347 delayed (func )(clone (trans ) if not fitted else trans ,
337348 X_sel , y , weight )
338- for name , trans , X_sel , weight in self ._iter (
349+ for _ , trans , X_sel , weight in self ._iter (
339350 X = X , fitted = fitted , replace_strings = True ))
340351 except ValueError as e :
341352 if "Expected 2D array, got 1D array instead" in str (e ):
@@ -361,12 +372,12 @@ def fit(self, X, y=None):
361372 This estimator
362373
363374 """
364- self ._validate_transformers ()
365375 self ._validate_remainder (X )
376+ self ._validate_transformers ()
366377
367378 transformers = self ._fit_transform (X , y , _fit_one_transformer )
368-
369379 self ._update_fitted_transformers (transformers )
380+
370381 return self
371382
372383 def fit_transform (self , X , y = None ):
@@ -390,31 +401,21 @@ def fit_transform(self, X, y=None):
390401 sparse matrices.
391402
392403 """
393- self ._validate_transformers ()
394404 self ._validate_remainder (X )
405+ self ._validate_transformers ()
395406
396407 result = self ._fit_transform (X , y , _fit_transform_one )
397408
398409 if not result :
399410 # All transformers are None
400- if self ._passthrough is None :
401- return np .zeros ((X .shape [0 ], 0 ))
402- else :
403- return _get_column (X , self ._passthrough )
411+ return np .zeros ((X .shape [0 ], 0 ))
404412
405413 Xs , transformers = zip (* result )
406414
407415 self ._update_fitted_transformers (transformers )
408416 self ._validate_output (Xs )
409417
410- if self ._passthrough is not None :
411- Xs = list (Xs ) + [_get_column (X , self ._passthrough )]
412-
413- if any (sparse .issparse (f ) for f in Xs ):
414- Xs = sparse .hstack (Xs ).tocsr ()
415- else :
416- Xs = np .hstack (Xs )
417- return Xs
418+ return _hstack (list (Xs ))
418419
419420 def transform (self , X ):
420421 """Transform X separately by each transformer, concatenate results.
@@ -440,19 +441,9 @@ def transform(self, X):
440441
441442 if not Xs :
442443 # All transformers are None
443- if self ._passthrough is None :
444- return np .zeros ((X .shape [0 ], 0 ))
445- else :
446- return _get_column (X , self ._passthrough )
447-
448- if self ._passthrough is not None :
449- Xs = list (Xs ) + [_get_column (X , self ._passthrough )]
444+ return np .zeros ((X .shape [0 ], 0 ))
450445
451- if any (sparse .issparse (f ) for f in Xs ):
452- Xs = sparse .hstack (Xs ).tocsr ()
453- else :
454- Xs = np .hstack (Xs )
455- return Xs
446+ return _hstack (list (Xs ))
456447
457448
458449def _check_key_type (key , superclass ):
@@ -486,6 +477,19 @@ def _check_key_type(key, superclass):
486477 return False
487478
488479
def _hstack(X):
    """Stack the blocks in X horizontally.

    Parameters
    ----------
    X : iterable of array-likes
        Blocks to concatenate column-wise, e.g. numpy arrays, scipy sparse
        matrices and DataFrames. One-shot iterables are accepted as well.

    Returns
    -------
    stacked : CSR sparse matrix or ndarray
        The result of ``sparse.hstack(...).tocsr()`` if at least one block
        is sparse, otherwise the result of ``np.hstack``.
    """
    # Materialise once: X is scanned for sparse blocks and then stacked, so
    # a generator would otherwise be consumed before stacking.
    blocks = list(X)
    if any(sparse.issparse(block) for block in blocks):
        return sparse.hstack(blocks).tocsr()
    return np.hstack(blocks)
491+
492+
489493def _get_column (X , key ):
490494 """
491495 Get feature column(s) from input data X.
@@ -612,14 +616,17 @@ def make_column_transformer(*transformers, **kwargs):
612616 ----------
613617 *transformers : tuples of column selections and transformers
614618
615- remainder : {'passthrough', 'drop'}, default 'passthrough'
619+ remainder : {'passthrough', 'drop'} or estimator , default 'passthrough'
616620 By default, all remaining columns that were not specified in
617621 `transformers` will be automatically passed through (default of
618622 ``'passthrough'``). This subset of columns is concatenated with the
619623 output of the transformers.
620624 By using ``remainder='drop'``, only the specified columns in
621625 `transformers` are transformed and combined in the output, and the
622626 non-specified columns are dropped.
627+ By setting ``remainder`` to be an estimator, the remaining
628+ non-specified columns will use the ``remainder`` estimator. The
629+ estimator must support `fit` and `transform`.
623630
624631 n_jobs : int, optional
625632 Number of jobs to run in parallel (default 1).
0 commit comments