Skip to content

Commit 65a00f5

Browse files
committed
Merge pull request aimacode#190 from SnShine/addRounder
Removed truncate() and added rounder() as discussed in aimacode#183
2 parents 410ac5f + 273a47b commit 65a00f5

File tree

3 files changed

+22
-25
lines changed

3 files changed

+22
-25
lines changed

tests/test_probability.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@ def test_forward_backward():
110110
umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor)
111111

112112
umbrella_evidence = [T, T, F, T, T]
113-
assert truncate(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.6469, 0.3531],
113+
assert rounder(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.6469, 0.3531],
114114
[0.8673, 0.1327], [0.8204, 0.1796], [0.3075, 0.6925], [0.8204, 0.1796], [0.8673, 0.1327]]
115115

116116
umbrella_evidence = [T, F, T, F, T]
117-
assert truncate(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.5871, 0.4129],
117+
assert rounder(forward_backward(umbrellaHMM, umbrella_evidence, umbrella_prior)) == [[0.5871, 0.4129],
118118
[0.7177, 0.2823], [0.2324, 0.7676], [0.6072, 0.3928], [0.2324, 0.7676], [0.7177, 0.2823]]
119119

120120
def test_fixed_lag_smoothing():
@@ -126,16 +126,16 @@ def test_fixed_lag_smoothing():
126126
umbrellaHMM = HiddenMarkovModel(umbrella_transition, umbrella_sensor)
127127

128128
d = 2
129-
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.1111, 0.8889]
129+
assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.1111, 0.8889]
130130
d = 5
131-
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) is None
131+
assert fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t) is None
132132

133133
umbrella_evidence = [T, T, F, T, T]
134134
# t = 4
135135
e_t = T
136136

137137
d = 1
138-
assert truncate(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.9939, 0.0061]
138+
assert rounder(fixed_lag_smoothing(e_t, umbrellaHMM, d, umbrella_evidence, t)) == [0.9939, 0.0061]
139139

140140

141141
if __name__ == '__main__':

tests/test_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,20 @@ def test_scalar_vector_product():
113113
assert scalar_vector_product(2, [1, 2, 3]) == [2, 4, 6]
114114

115115
def test_scalar_matrix_product():
116-
assert truncate(scalar_matrix_product(-5, [[1, 2], [3, 4], [0, 6]])) == [[-5, -10], [-15, -20], [0, -30]]
117-
assert truncate(scalar_matrix_product(0.2, [[1, 2], [2, 3]])) == [[0.2, 0.4], [0.4, 0.6]]
116+
assert rounder(scalar_matrix_product(-5, [[1, 2], [3, 4], [0, 6]])) == [[-5, -10], [-15, -20], [0, -30]]
117+
assert rounder(scalar_matrix_product(0.2, [[1, 2], [2, 3]])) == [[0.2, 0.4], [0.4, 0.6]]
118118

119119

120120
def test_inverse_matrix():
121-
assert truncate(inverse_matrix([[1, 0], [0, 1]])) == [[1, 0], [0, 1]]
122-
assert truncate(inverse_matrix([[2, 1], [4, 3]])) == [[1.5, -0.5], [-2.0, 1.0]]
123-
assert truncate(inverse_matrix([[4, 7], [2, 6]])) == [[0.6, -0.7], [-0.2, 0.4]]
124-
125-
def test_truncate():
126-
assert truncate(5.3330000300330) == 5.3330
127-
assert truncate(10.234566) == 10.2346
128-
assert truncate([1.234566, 0.555555, 6.010101]) == [1.2346, 0.5556, 6.0101]
129-
assert truncate([[1.234566, 0.555555, 6.010101],
121+
assert rounder(inverse_matrix([[1, 0], [0, 1]])) == [[1, 0], [0, 1]]
122+
assert rounder(inverse_matrix([[2, 1], [4, 3]])) == [[1.5, -0.5], [-2.0, 1.0]]
123+
assert rounder(inverse_matrix([[4, 7], [2, 6]])) == [[0.6, -0.7], [-0.2, 0.4]]
124+
125+
def test_rounder():
126+
assert rounder(5.3330000300330) == 5.3330
127+
assert rounder(10.234566) == 10.2346
128+
assert rounder([1.234566, 0.555555, 6.010101]) == [1.2346, 0.5556, 6.0101]
129+
assert rounder([[1.234566, 0.555555, 6.010101],
130130
[10.505050, 12.121212, 6.030303]]) == [[1.2346, 0.5556, 6.0101],
131131
[10.5051, 12.1212, 6.0303]]
132132

utils.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,13 @@ def weighted_sampler(seq, weights):
243243

244244
return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
245245

246-
def truncate(x, n = 4):
247-
"""Truncates floats, vectors, matrices to n decimal values"""
248-
if isinstance(x, float):
249-
return(float("{0:.{1}f}".format(x, n)))
250-
elif isinstance(x, list) and isinstance(x[0], float):
251-
return([float("{0:.{1}f}".format(i, n)) for i in x])
252-
elif isinstance(x, list) and isinstance(x[0], list) and isinstance(x[0][0], float):
253-
return([[float("{0:.{1}f}".format(i, n)) for i in row] for row in x])
246+
def rounder(numbers, d = 4):
247+
"Round a single number, or sequence of numbers, to d decimal places."
248+
if isinstance(numbers, (int, float)):
249+
return round(numbers, d)
254250
else:
255-
return x
251+
constructor = type(numbers) # Can be list, set, tuple, etc.
252+
return constructor(rounder(n, d) for n in numbers)
256253

257254
def num_or_str(x):
258255
"""The argument is a string; convert to a number if

0 commit comments

Comments
 (0)