Skip to content

Commit 7b46c13

Browse files
committed
Merge pull request scikit-learn#3049 from glouppe/tree-bestfirst
[MRG+1] Trees: set correct impurity values in BestFirstBuilder
2 parents fb795eb + 2633774 commit 7b46c13

File tree

6 files changed

+1076
-935
lines changed

6 files changed

+1076
-935
lines changed

sklearn/tree/_tree.c

Lines changed: 686 additions & 648 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/tree/_tree.pyx

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,7 +1985,8 @@ cdef inline int _add_to_frontier(PriorityHeapRecord* rec,
19851985
"""Adds record ``rec`` to the priority queue ``frontier``; returns -1
19861986
on memory-error. """
19871987
return frontier.push(rec.node_id, rec.start, rec.end, rec.pos, rec.depth,
1988-
rec.is_leaf, rec.improvement, rec.impurity)
1988+
rec.is_leaf, rec.improvement, rec.impurity,
1989+
rec.impurity_left, rec.impurity_right)
19891990

19901991

19911992
cdef class BestFirstTreeBuilder(TreeBuilder):
@@ -2084,8 +2085,9 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
20842085
# Compute left split node
20852086
rc = self._add_split_node(splitter, tree,
20862087
record.start, record.pos,
2087-
record.impurity, IS_NOT_FIRST,
2088-
IS_LEFT, node, record.depth + 1,
2088+
record.impurity_left,
2089+
IS_NOT_FIRST, IS_LEFT, node,
2090+
record.depth + 1,
20892091
&split_node_left)
20902092
if rc == -1:
20912093
break
@@ -2095,7 +2097,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
20952097

20962098
# Compute right split node
20972099
rc = self._add_split_node(splitter, tree, record.pos,
2098-
record.end, record.impurity,
2100+
record.end,
2101+
record.impurity_right,
20992102
IS_NOT_FIRST, IS_NOT_LEFT, node,
21002103
record.depth + 1,
21012104
&split_node_right)
@@ -2183,11 +2186,16 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
21832186
res.pos = pos
21842187
res.is_leaf = 0
21852188
res.improvement = split_improvement
2189+
res.impurity_left = split_impurity_left
2190+
res.impurity_right = split_impurity_right
2191+
21862192
else:
21872193
# is leaf => 0 improvement
21882194
res.pos = end
21892195
res.is_leaf = 1
21902196
res.improvement = 0.0
2197+
res.impurity_left = impurity
2198+
res.impurity_right = impurity
21912199

21922200
return 0
21932201

0 commit comments

Comments
 (0)