Commit 7b8b986

harsh020 authored and jnothman committed
ENH Change fit_transform of MissingIndicator class to get mask only once (scikit-learn#14356)
1 parent 4265923 commit 7b8b986

2 files changed: 39 additions, 8 deletions
doc/whats_new/v0.22.rst

Lines changed: 8 additions & 3 deletions
@@ -56,7 +56,7 @@ Changelog
 :mod:`sklearn.datasets`
 .......................

-- |Feature| :func:`datasets.fetch_openml` now supports heterogeneous data using pandas
+- |Feature| :func:`datasets.fetch_openml` now supports heterogeneous data using pandas
   by setting `as_frame=True`. :pr:`13902` by `Thomas Fan`_.

 - |Enhancement| The parameter `return_X_y` was added to
@@ -107,6 +107,11 @@ Changelog
   preserve the class balance of the original training set. :pr:`14194`
   by :user:`Johann Faouzi <johannfaouzi>`.

+- |Efficiency| :func:`ensemble.MissingIndicator.fit_transform` the
+  _get_missing_features_info function is now called once when calling
+  fit_transform for MissingIndicator class. :pr:`14356` by :user:
+  `Harsh Soni <harsh020>`.
+
 - |Fix| :class:`ensemble.AdaBoostClassifier` computes probabilities based on
   the decision function as in the literature. Thus, `predict` and
   `predict_proba` give consistent results.
@@ -139,7 +144,7 @@ Changelog
 - |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
   :issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
   :user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.
-
+
 - |Feature| Add :class:`metrics.mean_tweedie_deviance` measuring the
   Tweedie deviance for a power parameter ``p``. Also add mean Poisson deviance
   :class:`metrics.mean_poisson_deviance` and mean Gamma deviance
@@ -190,7 +195,7 @@ Changelog
 :mod:`sklearn.cluster`
 ......................

-- |Enhancement| :class:`cluster.SpectralClustering` now accepts a ``n_components``
+- |Enhancement| :class:`cluster.SpectralClustering` now accepts a ``n_components``
   parameter. This parameter extends `SpectralClustering` class functionality to
   match `spectral_clustering`.
   :pr:`13726` by :user:`Shuzhe Xiao <fdas3213>`.

sklearn/impute/_base.py

Lines changed: 31 additions & 5 deletions
@@ -586,7 +586,7 @@ def _validate_input(self, X):

         return X

-    def fit(self, X, y=None):
+    def _fit(self, X, y=None):
         """Fit the transformer on X.

         Parameters
@@ -597,8 +597,10 @@ def fit(self, X, y=None):

         Returns
         -------
-        self : object
-            Returns self.
+        imputer_mask : {ndarray or sparse matrix}, shape (n_samples, \
+        n_features)
+            The imputer mask of the original data.
+
         """
         X = self._validate_input(X)
         self._n_features = X.shape[1]
@@ -612,7 +614,26 @@ def fit(self, X, y=None):
             raise ValueError("'sparse' has to be a boolean or 'auto'. "
                              "Got {!r} instead.".format(self.sparse))

-        self.features_ = self._get_missing_features_info(X)[1]
+        missing_features_info = self._get_missing_features_info(X)
+        self.features_ = missing_features_info[1]
+
+        return missing_features_info[0]
+
+    def fit(self, X, y=None):
+        """Fit the transformer on X.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix}, shape (n_samples, n_features)
+            Input data, where ``n_samples`` is the number of samples and
+            ``n_features`` is the number of features.
+
+        Returns
+        -------
+        self : object
+            Returns self.
+        """
+        self._fit(X, y)

         return self

@@ -667,7 +688,12 @@ def fit_transform(self, X, y=None):
             will be boolean.

         """
-        return self.fit(X, y).transform(X)
+        imputer_mask = self._fit(X, y)
+
+        if self.features_.size < self._n_features:
+            imputer_mask = imputer_mask[:, self.features_]
+
+        return imputer_mask

     def _more_tags(self):
         return {'allow_nan': True,
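
The refactoring in sklearn/impute/_base.py can be illustrated with a small pure-Python sketch. This is not scikit-learn's code: `TinyMissingIndicator`, `_compute_mask`, `_select`, and the `mask_calls` counter are hypothetical names introduced here, `None` stands in for a missing value, and plain lists stand in for arrays. It only shows the shape of the change: a private `_fit` returns the mask, the public `fit` still returns `self`, and `fit_transform` reuses the mask instead of calling `fit` and then `transform`.

```python
class TinyMissingIndicator:
    """Toy stand-in for the pattern in this commit (not scikit-learn code)."""

    def __init__(self):
        self.mask_calls = 0  # hypothetical counter, for illustration only

    def _compute_mask(self, X):
        # Plays the role of _get_missing_features_info: returns the full
        # boolean mask and the indices of columns containing a missing value.
        self.mask_calls += 1
        mask = [[value is None for value in row] for row in X]
        n_features = len(X[0]) if X else 0
        features = [j for j in range(n_features) if any(row[j] for row in mask)]
        return mask, features

    def _fit(self, X):
        # Private fit: stores the fitted state and hands the mask back.
        self._n_features = len(X[0]) if X else 0
        mask, self.features_ = self._compute_mask(X)
        return mask

    def fit(self, X):
        # Public fit keeps its contract: discard the mask, return self.
        self._fit(X)
        return self

    def transform(self, X):
        # Computes the mask again, which is needed for data not seen in fit.
        mask, _ = self._compute_mask(X)
        return self._select(mask)

    def fit_transform(self, X):
        # Reuses the mask returned by _fit, so it is computed only once.
        return self._select(self._fit(X))

    def _select(self, mask):
        # Keep only the columns that had missing values during fit.
        if len(self.features_) < self._n_features:
            return [[row[j] for j in self.features_] for row in mask]
        return mask


X = [[1.0, None, 3.0], [4.0, 5.0, None], [7.0, 8.0, 9.0]]

two_step = TinyMissingIndicator()
two_step_result = two_step.fit(X).transform(X)

one_step = TinyMissingIndicator()
one_step_result = one_step.fit_transform(X)
```

In this sketch both paths return the same mask, but `two_step.mask_calls` ends at 2 while `one_step.mask_calls` ends at 1, which is the saving the commit title describes.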
