Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion modAL/models/learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,22 @@ def _set_classes(self):
def _add_training_data(self, X: modALinput, y: modALinput):
super()._add_training_data(X, y)

def fit(self, X: modALinput, y: modALinput, **fit_kwargs) -> 'BaseCommittee':
"""
Fits every learner to a subset sampled with replacement from X. Calling this method makes the learner forget the
data it has seen up until this point and replaces it with X! If you would like to perform bootstrapping on each
learner using the data it has seen, use the method .rebag()!

Calling this method makes the learner forget the data it has seen up until this point and replaces it with X!

Args:
X: The samples to be fitted on.
y: The corresponding labels.
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
"""
super().fit(X, y, **fit_kwargs)
self._set_classes()

def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None:
"""
Adds X and y to the known training data for each learner and retrains learners with the augmented dataset.
Expand All @@ -323,7 +339,6 @@ def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new:
only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples.
**fit_kwargs: Keyword arguments to be passed to the fit method of the predictor.
"""

super().teach(X, y, bootstrap=bootstrap, only_new=only_new, **fit_kwargs)
self._set_classes()

Expand Down