diff --git a/modAL/models/learners.py b/modAL/models/learners.py index 7fdcf2c..b4ea394 100644 --- a/modAL/models/learners.py +++ b/modAL/models/learners.py @@ -311,6 +311,20 @@ def _set_classes(self): def _add_training_data(self, X: modALinput, y: modALinput): super()._add_training_data(X, y) + + def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, only_new: bool = False, **fit_kwargs) -> None: + """ + Adds X and y to the known training data for each learner and retrains learners with the augmented dataset. + + Args: + X: The new samples for which the labels are supplied by the expert. + y: Labels corresponding to the new instances in X. + bootstrap: If True, trains each learner on a bootstrapped set. Useful when building the ensemble by bagging. + only_new: If True, the model is retrained using only X and y, ignoring the previously provided examples. + **fit_kwargs: Keyword arguments to be passed to the fit method of the predictor. + """ + + super().teach(X, y, bootstrap=bootstrap, only_new=only_new, **fit_kwargs) self._set_classes() def predict(self, X: modALinput, **predict_proba_kwargs) -> Any: