Skip to content

Commit d41fb76

Browse files
author
maheshakya
committed
Implemented median and constant strategies in DummyRegressor
1 parent c52476a commit d41fb76

File tree

2 files changed

+135
-11
lines changed

2 files changed

+135
-11
lines changed

sklearn/dummy.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
21
# Author: Mathieu Blondel <[email protected]>
32
# Arnaud Joly <[email protected]>
3+
# Maheshakya Wijewardena<[email protected]>
44
# License: BSD 3 clause
55

66
import numpy as np
@@ -13,6 +13,7 @@
1313

1414

1515
class DummyClassifier(BaseEstimator, ClassifierMixin):
16+
1617
"""
1718
DummyClassifier is a classifier that makes predictions using simple rules.
1819
@@ -273,6 +274,7 @@ def predict_log_proba(self, X):
273274

274275

275276
class DummyRegressor(BaseEstimator, RegressorMixin):
277+
276278
"""
277279
DummyRegressor is a regressor that always predicts the mean of the training
278280
targets.
@@ -282,8 +284,9 @@ class DummyRegressor(BaseEstimator, RegressorMixin):
282284
283285
Attributes
284286
----------
285-
`y_mean_` : float or array of shape [n_outputs]
286-
Mean of the training targets.
287+
`constant_' : float or array of shape [n_outputs]
288+
Mean or median of the training targets or constant value given the by
289+
the user.
287290
288291
`n_outputs_` : int,
289292
Number of outputs.
@@ -292,6 +295,10 @@ class DummyRegressor(BaseEstimator, RegressorMixin):
292295
True if the output at fit is 2d, else false.
293296
"""
294297

298+
def __init__(self, strategy="mean", constant=None):
299+
self.strategy = strategy
300+
self.constant = constant
301+
295302
def fit(self, X, y):
296303
"""Fit the random regressor.
297304
@@ -309,11 +316,47 @@ def fit(self, X, y):
309316
self : object
310317
Returns self.
311318
"""
319+
320+
if self.strategy not in ("mean", "median", "constant"):
321+
raise ValueError("Unknown strategy type.")
322+
312323
y = safe_asarray(y)
313-
self.y_mean_ = np.reshape(np.mean(y, axis=0), (1, -1))
314-
self.n_outputs_ = np.size(self.y_mean_) # y.shape[1] is not safe
315-
self.output_2d_ = (y.ndim == 2)
316-
return self
324+
325+
if self.strategy == "mean":
326+
self.constant_ = np.reshape(np.mean(y, axis=0), (1, -1))
327+
self.n_outputs_ = np.size(self.constant_) # y.shape[1] is not safe
328+
self.output_2d_ = (y.ndim == 2)
329+
return self
330+
331+
elif self.strategy == "median":
332+
self.constant_ = np.reshape(np.median(y, axis=0), (1, -1))
333+
self.n_outputs_ = np.size(self.constant_) # y.shape[1] is not safe
334+
self.output_2d_ = (y.ndim == 2)
335+
return self
336+
337+
elif self.strategy == "constant":
338+
if self.constant is None:
339+
raise ValueError("Constant not defined.")
340+
341+
if not (isinstance(self.constant, np.ndarray) or isinstance(self.constant, list)):
342+
raise ValueError(
343+
"Constants should be in type list or numpy.ndarray.")
344+
345+
self.output_2d_ = (y.ndim == 2)
346+
self.constant = safe_asarray(self.constant)
347+
348+
if self.output_2d_:
349+
if self.constant.shape[1] != y.shape[1]:
350+
raise ValueError(
351+
"Number of outputs and number of constants do not match.")
352+
else:
353+
if len(self.constant) != 1:
354+
raise ValueError(
355+
"Number of constants should be equal to one.")
356+
357+
self.constant_ = np.reshape(self.constant, (1, -1))
358+
self.n_outputs_ = np.size(self.constant_) # y.shape[1] is not safe
359+
return self
317360

318361
def predict(self, X):
319362
"""
@@ -330,14 +373,16 @@ def predict(self, X):
330373
y : array, shape = [n_samples] or [n_samples, n_outputs]
331374
Predicted target values for X.
332375
"""
333-
if not hasattr(self, "y_mean_"):
376+
if not hasattr(self, "constant_"):
334377
raise ValueError("DummyRegressor not fitted.")
335378

336379
X = safe_asarray(X)
337380
n_samples = X.shape[0]
338-
y = np.ones((n_samples, 1)) * self.y_mean_
381+
382+
y = np.ones((n_samples, 1)) * self.constant_
339383

340384
if self.n_outputs_ == 1 and not self.output_2d_:
341385
y = np.ravel(y)
342386

343387
return y
388+

sklearn/tests/test_dummy.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
sklearn/dummy.py:290:1: W293 blank line contains whitespace
2+
sklearn/dummy.py:341:80: E501 line too long (94 > 79 characters)
3+
sklearn/dummy.py:351:80: E501 line too long (82 > 79 characters)
4+
sklearn/dummy.py:388:1: W391 blank line at end of file
5+
maheshakya@maheshakya-TECRA-M11:~/scikit-learn$ autopep8 sklearn/tests/test_dummy.py
16
import warnings
27
import numpy as np
38

@@ -59,6 +64,19 @@ def _check_behavior_2d(clf):
5964
assert_equal(y.shape, y_pred.shape)
6065

6166

67+
def _check_behavior_2d_for_constant(clf):
68+
# 2d case only
69+
X = np.array([[0], [0], [0], [0]]) # ignored
70+
y = np.array([[1, 0, 5, 4, 3],
71+
[2, 0, 1, 2, 5],
72+
[1, 0, 4, 5, 2],
73+
[1, 3, 3, 2, 0]])
74+
est = clone(clf)
75+
est.fit(X, y)
76+
y_pred = est.predict(X)
77+
assert_equal(y.shape, y_pred.shape)
78+
79+
6280
def test_most_frequent_strategy():
6381
X = [[0], [0], [0], [0]] # ignored
6482
y = [1, 2, 1, 1]
@@ -175,7 +193,7 @@ def test_classifier_exceptions():
175193
assert_raises(ValueError, clf.predict_proba, [])
176194

177195

178-
def test_regressor():
196+
def test_mean_strategy_regressor():
179197
X = [[0]] * 4 # ignored
180198
y = [1, 2, 1, 1]
181199

@@ -184,7 +202,7 @@ def test_regressor():
184202
assert_array_equal(reg.predict(X), [5. / 4] * len(X))
185203

186204

187-
def test_multioutput_regressor():
205+
def test_mean_strategy_multioutput_regressor():
188206

189207
X_learn = np.random.randn(10, 10)
190208
y_learn = np.random.randn(10, 5)
@@ -210,6 +228,66 @@ def test_regressor_exceptions():
210228
assert_raises(ValueError, reg.predict, [])
211229

212230

231+
def test_median_strategy_regressor():
232+
X = [[0]] * 5 # ignored
233+
y = [1, 2, 4, 6, 8]
234+
235+
reg = DummyRegressor(strategy="median")
236+
reg.fit(X, y)
237+
assert_array_equal(reg.predict(X), [4] * len(X))
238+
239+
240+
def test_median_strategy_multioutput_regressor():
241+
242+
X_learn = np.random.randn(10, 10)
243+
y_learn = np.random.randn(10, 5)
244+
245+
median = np.median(y_learn, axis=0).reshape((1, -1))
246+
247+
X_test = np.random.randn(20, 10)
248+
y_test = np.random.randn(20, 5)
249+
250+
# Correctness oracle
251+
est = DummyRegressor(strategy="median")
252+
est.fit(X_learn, y_learn)
253+
y_pred_learn = est.predict(X_learn)
254+
y_pred_test = est.predict(X_test)
255+
256+
assert_array_equal(np.tile(median, (y_learn.shape[0], 1)), y_pred_learn)
257+
assert_array_equal(np.tile(median, (y_test.shape[0], 1)), y_pred_test)
258+
_check_behavior_2d(est)
259+
260+
261+
def test_constant_strategy_regressor():
262+
X = [[0]] * 5 # ignored
263+
y = [1, 2, 4, 6, 8]
264+
265+
reg = DummyRegressor(strategy="constant", constant=[43])
266+
reg.fit(X, y)
267+
assert_array_equal(reg.predict(X), [43] * len(X))
268+
269+
270+
def test_constant_strategy_multioutput_regressor():
271+
272+
X_learn = np.random.randn(10, 10)
273+
y_learn = np.random.randn(10, 5)
274+
275+
constants = np.random.randn(1, 5)
276+
277+
X_test = np.random.randn(20, 10)
278+
y_test = np.random.randn(20, 5)
279+
280+
# Correctness oracle
281+
est = DummyRegressor(strategy="constant", constant=constants)
282+
est.fit(X_learn, y_learn)
283+
y_pred_learn = est.predict(X_learn)
284+
y_pred_test = est.predict(X_test)
285+
286+
assert_array_equal(np.tile(constants, (y_learn.shape[0], 1)), y_pred_learn)
287+
assert_array_equal(np.tile(constants, (y_test.shape[0], 1)), y_pred_test)
288+
_check_behavior_2d_for_constant(est)
289+
290+
213291
def test_constant_strategy():
214292
X = [[0], [0], [0], [0]] # ignored
215293
y = [2, 1, 2, 2]
@@ -253,3 +331,4 @@ def test_constant_strategy_exceptions():
253331
clf = DummyClassifier(strategy="constant", random_state=0,
254332
constant=[2, 0])
255333
assert_raises(ValueError, clf.fit, X, y)
334+

0 commit comments

Comments
 (0)