Skip to content

Commit a7f6bde

Browse files
Chipe1norvig
authored andcommitted
Minor fixes (aimacode#581)
* Minor fixes * Typo fix
1 parent 5146f77 commit a7f6bde

File tree

3 files changed

+91
-59
lines changed

3 files changed

+91
-59
lines changed

learning.ipynb

Lines changed: 65 additions & 47 deletions
Large diffs are not rendered by default.

learning.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -381,16 +381,21 @@ class DecisionFork:
381381
"""A fork of a decision tree holds an attribute to test, and a dict
382382
of branches, one for each of the attribute's values."""
383383

384-
def __init__(self, attr, attrname=None, branches=None):
384+
def __init__(self, attr, attrname=None, default_child=None, branches=None):
385385
"""Initialize by saying what attribute this node tests."""
386386
self.attr = attr
387387
self.attrname = attrname or attr
388+
self.default_child = default_child
388389
self.branches = branches or {}
389390

390391
def __call__(self, example):
391392
"""Given an example, classify it using the attribute and the branches."""
392393
attrvalue = example[self.attr]
393-
return self.branches[attrvalue](example)
394+
if attrvalue in self.branches:
395+
return self.branches[attrvalue](example)
396+
else:
397+
# return default class when attribute is unknown
398+
return self.default_child(example)
394399

395400
def add(self, val, subtree):
396401
"""Add a branch. If self.attr = val, go to the given subtree."""
@@ -440,7 +445,7 @@ def decision_tree_learning(examples, attrs, parent_examples=()):
440445
return plurality_value(examples)
441446
else:
442447
A = choose_attribute(attrs, examples)
443-
tree = DecisionFork(A, dataset.attrnames[A])
448+
tree = DecisionFork(A, dataset.attrnames[A], plurality_value(examples))
444449
for (v_k, exs) in split_by(A, examples):
445450
subtree = decision_tree_learning(
446451
exs, removeall(A, attrs), examples)
@@ -495,27 +500,28 @@ def information_content(values):
495500

496501

497502
def RandomForest(dataset, n=5):
498-
"""A ensemble of Decision trese trained using bagging and feature bagging."""
499-
500-
predictors = [DecisionTreeLearner(examples=data_bagging(dataset),
501-
attrs=dataset.attrs,
502-
attrnames=dataset.attrnames,
503-
target=dataset.target,
504-
inputs=feature_bagging(datatset)) for _ in range(n)]
503+
"""An ensemble of Decision Trees trained using bagging and feature bagging."""
505504

506505
def data_bagging(dataset, m=0):
507506
"""Sample m examples with replacement"""
508507
n = len(dataset.examples)
509-
return weighted_sample_with_replacement(m or n, examples, [1]*n)
508+
return weighted_sample_with_replacement(m or n, dataset.examples, [1]*n)
510509

511510
def feature_bagging(dataset, p=0.7):
512511
"""Feature bagging with probability p to retain an attribute"""
513512
inputs = [i for i in dataset.inputs if probability(p)]
514513
return inputs or dataset.inputs
515514

516515
def predict(example):
516+
print([predictor(example) for predictor in predictors])
517517
return mode(predictor(example) for predictor in predictors)
518518

519+
predictors = [DecisionTreeLearner(DataSet(examples=data_bagging(dataset),
520+
attrs=dataset.attrs,
521+
attrnames=dataset.attrnames,
522+
target=dataset.target,
523+
inputs=feature_bagging(dataset))) for _ in range(n)]
524+
519525
return predict
520526

521527
# ______________________________________________________________________________
@@ -1046,7 +1052,7 @@ def T(attrname, branches):
10461052
branches = {value: (child if isinstance(child, DecisionFork)
10471053
else DecisionLeaf(child))
10481054
for value, child in branches.items()}
1049-
return DecisionFork(restaurant.attrnum(attrname), attrname, branches)
1055+
return DecisionFork(restaurant.attrnum(attrname), attrname, print, branches)
10501056

10511057

10521058
""" [Figure 18.2]

tests/test_learning.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ def test_decision_tree_learner():
123123
assert dTL([7.5, 4, 6, 2]) == "virginica"
124124

125125

126+
def test_random_forest():
127+
iris = DataSet(name="iris")
128+
rF = RandomForest(iris)
129+
assert rF([5, 3, 1, 0.1]) == "setosa"
130+
assert rF([6, 5, 3, 1.5]) == "versicolor"
131+
assert rF([7.5, 4, 6, 2]) == "virginica"
132+
133+
126134
def test_neural_network_learner():
127135
iris = DataSet(name="iris")
128136
classes = ["setosa", "versicolor", "virginica"]

0 commit comments

Comments
 (0)