11111212# License: BSD 3 clause
1313
14+ from numbers import Integral
15+
1416import numpy as np
1517import warnings
1618
@@ -73,7 +75,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
7375 feature_names = None , class_names = None , label = 'all' ,
7476 filled = False , leaves_parallel = False , impurity = True ,
7577 node_ids = False , proportion = False , rotate = False ,
76- rounded = False , special_characters = False ):
78+ rounded = False , special_characters = False , precision = 3 ):
7779 """Export a decision tree in DOT format.
7880
7981 This function generates a GraphViz representation of the decision tree,
@@ -143,6 +145,10 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
143145 When set to ``False``, ignore special characters for PostScript
144146 compatibility.
145147
148+ precision : int, optional (default=3)
149+ Number of decimal places used when rounding the floating-point
150+ impurity, threshold and value attributes of each node.
151+
146152 Returns
147153 -------
148154 dot_data : string
@@ -162,6 +168,7 @@ def export_graphviz(decision_tree, out_file=SENTINEL, max_depth=None,
162168 >>> clf = clf.fit(iris.data, iris.target)
163169 >>> tree.export_graphviz(clf,
164170 ... out_file='tree.dot') # doctest: +SKIP
171+
165172 """
166173
167174 def get_color (value ):
@@ -226,7 +233,8 @@ def node_to_str(tree, node_id, criterion):
226233 characters [2 ])
227234 node_string += '%s %s %s%s' % (feature ,
228235 characters [3 ],
229- round (tree .threshold [node_id ], 4 ),
236+ round (tree .threshold [node_id ],
237+ precision ),
230238 characters [4 ])
231239
232240 # Write impurity
@@ -237,7 +245,7 @@ def node_to_str(tree, node_id, criterion):
237245 criterion = "impurity"
238246 if labels :
239247 node_string += '%s = ' % criterion
240- node_string += (str (round (tree .impurity [node_id ], 4 )) +
248+ node_string += (str (round (tree .impurity [node_id ], precision )) +
241249 characters [4 ])
242250
243251 # Write node sample count
@@ -260,16 +268,16 @@ def node_to_str(tree, node_id, criterion):
260268 node_string += 'value = '
261269 if tree .n_classes [0 ] == 1 :
262270 # Regression
263- value_text = np .around (value , 4 )
271+ value_text = np .around (value , precision )
264272 elif proportion :
265273 # Classification
266- value_text = np .around (value , 2 )
274+ value_text = np .around (value , precision )
267275 elif np .all (np .equal (np .mod (value , 1 ), 0 )):
268276 # Classification without floating-point weights
269277 value_text = value .astype (int )
270278 else :
271279 # Classification with floating-point weights
272- value_text = np .around (value , 4 )
280+ value_text = np .around (value , precision )
273281 # Strip whitespace
274282 value_text = str (value_text .astype ('S32' )).replace ("b'" , "'" )
275283 value_text = value_text .replace ("' '" , ", " ).replace ("'" , "" )
@@ -402,6 +410,14 @@ def recurse(tree, node_id, criterion, parent=None, depth=0):
402410 return_string = True
403411 out_file = six .StringIO ()
404412
413+ if isinstance (precision , Integral ):
414+ if precision < 0 :
415+ raise ValueError ("'precision' should be greater or equal to 0."
416+ " Got {} instead." .format (precision ))
417+ else :
418+ raise ValueError ("'precision' should be an integer. Got {}"
419+ " instead." .format (type (precision )))
420+
405421 # Check length of feature_names before getting into the tree node
406422 # Raise error if length of feature_names does not match
407423 # n_features_ in the decision_tree
0 commit comments