Skip to content

Commit 9061bfc

Browse files
authored
Merge pull request MLBazaar#78 from HDI-Project/6_add_unit_tests
Add unit tests and integrate with codecov
2 parents b309403 + 52a150d commit 9061bfc

File tree

12 files changed

+438
-36
lines changed

12 files changed

+438
-36
lines changed

.travis.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ python:
66

77
# Command to install dependencies
88
install:
9-
- pip install -U tox-travis
9+
- pip install -U tox-travis codecov
1010
- sudo apt-get install graphviz
1111

1212
# Command to run tests
1313
script: tox
1414

15+
after_success: codecov
16+
1517
deploy:
1618

1719
- provider: pages

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ fix-lint: ## fix lint issues using autoflake, autopep8, and isort
103103

104104
.PHONY: test
105105
test: ## run tests quickly with the default Python
106-
python -m pytest
106+
python -m pytest --cov=mlblocks
107107

108108
.PHONY: test-all
109109
test-all: ## run tests on every Python version with tox

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ Pipelines and Primitives for Machine Learning and Data Science.
1111

1212
[![PyPi][pypi-img]][pypi-url]
1313
[![Travis][travis-img]][travis-url]
14+
[![CodeCov][codecov-img]][codecov-url]
1415

1516
[pypi-img]: https://img.shields.io/pypi/v/mlblocks.svg
1617
[pypi-url]: https://pypi.python.org/pypi/mlblocks
1718
[travis-img]: https://travis-ci.org/HDI-Project/MLBlocks.svg?branch=master
1819
[travis-url]: https://travis-ci.org/HDI-Project/MLBlocks
20+
[codecov-img]: https://codecov.io/gh/HDI-Project/MLBlocks/branch/master/graph/badge.svg
21+
[codecov-url]: https://codecov.io/gh/HDI-Project/MLBlocks
1922

2023
MLBlocks is a simple framework for composing end-to-end tunable Machine Learning Pipelines by
2124
seamlessly combining tools from any python library with a simple, common and uniform interface.

mlblocks/mlblock.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _extract_params(self, kwargs, hyperparameters):
9494
fit_args = [arg['name'] for arg in self.fit_args]
9595
produce_args = [arg['name'] for arg in self.produce_args]
9696

97-
for name in kwargs.keys():
97+
for name in list(kwargs.keys()):
9898
if name in fit_args:
9999
fit_params[name] = kwargs.pop(name)
100100

@@ -127,9 +127,9 @@ def __init__(self, name, **kwargs):
127127
self._class = bool(self.produce_method)
128128

129129
hyperparameters = metadata.get('hyperparameters', dict())
130-
init_params, fit_params, produce_params = self._extract_params(
131-
kwargs, hyperparameters)
132-
self._hyperparamters = init_params
130+
init_params, fit_params, produce_params = self._extract_params(kwargs, hyperparameters)
131+
132+
self._hyperparameters = init_params
133133
self._fit_params = fit_params
134134
self._produce_params = produce_params
135135

@@ -175,7 +175,7 @@ def get_hyperparameters(self):
175175
the dictionary containing the hyperparameter values that the
176176
MLBlock is currently using.
177177
"""
178-
return self._hyperparamters
178+
return self._hyperparameters.copy()
179179

180180
def set_hyperparameters(self, hyperparameters):
181181
"""Set new hyperparameters.
@@ -190,10 +190,10 @@ def set_hyperparameters(self, hyperparameters):
190190
of the hyperparameters and as values
191191
the values to be used.
192192
"""
193-
self._hyperparamters.update(hyperparameters)
193+
self._hyperparameters.update(hyperparameters)
194194

195195
if self._class:
196-
self.instance = self.primitive(**self._hyperparamters)
196+
self.instance = self.primitive(**self._hyperparameters)
197197

198198
def fit(self, **kwargs):
199199
"""Call the fit method of the primitive.
@@ -239,5 +239,5 @@ def produce(self, **kwargs):
239239
if self._class:
240240
return getattr(self.instance, self.produce_method)(**produce_args)
241241

242-
produce_args.update(self._hyperparamters)
242+
produce_args.update(self._hyperparameters)
243243
return self.primitive(**produce_args)

setup.cfg

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,3 @@ collect_ignore = ['setup.py']
4545

4646
[tool:pylint]
4747
good-names = X,y
48-

setup.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515

1616

1717
install_requires = [
18+
'mlprimitives>=0.1.3',
1819
]
1920

2021

2122
tests_require = [
22-
'mock>=2.0.0',
2323
'pytest>=3.4.2',
24+
'pytest-cov>=2.6.0',
2425
]
2526

2627

@@ -62,11 +63,6 @@
6263
]
6364

6465

65-
demo_requires = [
66-
'mlprimitives==0.1.1',
67-
]
68-
69-
7066
setup(
7167
author='MIT Data To AI Lab',
7268
author_email='[email protected]',
@@ -81,8 +77,7 @@
8177
],
8278
description="Pipelines and primitives for machine learning and data science.",
8379
extras_require={
84-
'demo': demo_requires,
85-
'dev': demo_requires + development_requires + tests_require,
80+
'dev': development_requires + tests_require,
8681
'test': tests_require,
8782
},
8883
include_package_data=True,

tests/test__primitives.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

tests/test_datasets.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from unittest import TestCase
4+
from unittest.mock import Mock
5+
6+
from mlblocks import datasets
7+
8+
9+
class TestDataset(TestCase):
10+
11+
def setUp(self):
12+
self.description = """Dataset Name.
13+
14+
Some extended description.
15+
"""
16+
self.score = Mock()
17+
self.score.return_value = 1.0
18+
19+
self.dataset = datasets.Dataset(
20+
self.description, 'data', 'target', self.score,
21+
shuffle=False, stratify=True, some='kwargs')
22+
23+
def test___init__(self):
24+
25+
assert self.dataset.name == 'Dataset Name.'
26+
assert self.dataset.description == self.description
27+
assert self.dataset.data == 'data'
28+
assert self.dataset.target == 'target'
29+
assert self.dataset._shuffle is False
30+
assert self.dataset._stratify is True
31+
assert self.dataset._score == self.score
32+
assert self.dataset.some == 'kwargs'
33+
34+
def test_score(self):
35+
returned = self.dataset.score('a', b='c')
36+
37+
assert returned == 1.0
38+
self.score.assert_called_once_with('a', b='c')
39+
40+
def test___repr__(self):
41+
repr_ = str(self.dataset)
42+
43+
assert repr_ == "Dataset Name."
44+
45+
46+
def test_dataset_describe(capsys):
47+
"""Tested here because fixtures are not supported in TestCases."""
48+
49+
description = """Dataset Name.
50+
51+
Some extended description.
52+
"""
53+
54+
dataset = datasets.Dataset(description, 'data', 'target', 'score')
55+
dataset.describe()
56+
57+
captured = capsys.readouterr()
58+
assert captured.out == description + '\n'

tests/test_mlblock.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from unittest import TestCase
4+
from unittest.mock import patch
5+
6+
from mlblocks.mlblock import MLBlock, import_object
7+
8+
# import pytest
9+
10+
11+
class DummyClass:
12+
pass
13+
14+
15+
def test_import_object():
16+
dummy_class = import_object(__name__ + '.DummyClass')
17+
18+
assert dummy_class is DummyClass
19+
20+
21+
class TestMLBlock(TestCase):
22+
23+
def test__extract_params(self):
24+
pass
25+
26+
@patch('mlblocks.mlblock.MLBlock.set_hyperparameters')
27+
@patch('mlblocks.mlblock.import_object')
28+
@patch('mlblocks.mlblock.load_primitive')
29+
def test___init__(self, load_primitive_mock, import_object_mock, set_hps_mock):
30+
load_primitive_mock.return_value = {
31+
'primitive': 'a_primitive_name',
32+
'produce': {
33+
'args': [
34+
{
35+
'name': 'argument'
36+
}
37+
],
38+
'output': [
39+
]
40+
}
41+
}
42+
43+
mlblock = MLBlock('given_primitive_name', argument='value')
44+
45+
assert mlblock.name == 'given_primitive_name'
46+
assert mlblock.primitive == import_object_mock.return_value
47+
assert mlblock._fit == dict()
48+
assert mlblock.fit_args == list()
49+
assert mlblock.fit_method is None
50+
51+
produce = {
52+
'args': [
53+
{
54+
'name': 'argument'
55+
}
56+
],
57+
'output': [
58+
]
59+
}
60+
assert mlblock._produce == produce
61+
assert mlblock.produce_args == produce['args']
62+
assert mlblock.produce_output == produce['output']
63+
assert mlblock.produce_method is None
64+
assert mlblock._class is False
65+
66+
assert mlblock._hyperparameters == dict()
67+
assert mlblock._fit_params == dict()
68+
assert mlblock._produce_params == {'argument': 'value'}
69+
70+
assert mlblock._tunable == dict()
71+
72+
set_hps_mock.assert_called_once_with(dict())
73+
74+
@patch('mlblocks.mlblock.import_object')
75+
@patch('mlblocks.mlblock.load_primitive')
76+
def test___str__(self, load_primitive_mock, import_object_mock):
77+
load_primitive_mock.return_value = {
78+
'primitive': 'a_primitive_name',
79+
'produce': {
80+
'args': [],
81+
'output': []
82+
}
83+
}
84+
85+
mlblock = MLBlock('given_primitive_name')
86+
87+
assert str(mlblock) == 'MLBlock - given_primitive_name'
88+
89+
@patch('mlblocks.mlblock.import_object')
90+
@patch('mlblocks.mlblock.load_primitive')
91+
def test_get_tunable_hyperparameters(self, load_primitive_mock, import_object_mock):
92+
"""get_tunable_hyperparameters has to return a copy of the _tunables attribute."""
93+
load_primitive_mock.return_value = {
94+
'primitive': 'a_primitive_name',
95+
'produce': {
96+
'args': [],
97+
'output': []
98+
}
99+
}
100+
101+
mlblock = MLBlock('given_primitive_name')
102+
103+
tunable = dict()
104+
mlblock._tunable = tunable
105+
106+
returned = mlblock.get_tunable_hyperparameters()
107+
108+
assert returned == tunable
109+
assert returned is not tunable
110+
111+
@patch('mlblocks.mlblock.import_object')
112+
@patch('mlblocks.mlblock.load_primitive')
113+
def test_get_hyperparameters(self, load_primitive_mock, import_object_mock):
114+
"""get_hyperparameters has to return a copy of the _hyperparameters attribute."""
115+
load_primitive_mock.return_value = {
116+
'primitive': 'a_primitive_name',
117+
'produce': {
118+
'args': [],
119+
'output': []
120+
}
121+
}
122+
123+
mlblock = MLBlock('given_primitive_name')
124+
125+
hyperparameters = dict()
126+
mlblock._hyperparameters = hyperparameters
127+
128+
returned = mlblock.get_hyperparameters()
129+
130+
assert returned == hyperparameters
131+
assert returned is not hyperparameters
132+
133+
def test_set_hyperparameters_function(self):
134+
pass
135+
136+
def test_set_hyperparameters_class(self):
137+
pass
138+
139+
def test_fit_no_fit(self):
140+
pass
141+
142+
def test_fit(self):
143+
pass
144+
145+
def test_produce_function(self):
146+
pass
147+
148+
def test_produce_class(self):
149+
pass

0 commit comments

Comments
 (0)