Skip to content

Commit 8df450e

Browse files
committed
Improved comment explaining numerical stability issue
1 parent affbb03 commit 8df450e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

code/network2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ class CrossEntropyCost:
4343
@staticmethod
4444
def fn(a, y):
4545
"""Return the cost associated with an output ``a`` and desired output
46-
``y``. Note that the np.nan_to_num ensures that if the output
47-
from the network is exactly right, then 0.0 will be returned,
48-
rather than nan.
46+
``y``. Note that np.nan_to_num is used to ensure numerical
47+
stability. In particular, if both ``a`` and ``y`` have a 1.0
48+
in the same slot, then the expression (1-y)*np.log(1-a)
49+
returns nan. The np.nan_to_num ensures that that is converted
50+
to the correct value (0.0).
4951
5052
"""
5153
return np.nan_to_num(np.sum(-y*np.log(a)-(1-y)*np.log(1-a)))

0 commit comments

Comments
 (0)