Skip to content

Commit 8a26b28

Browse files
krishnaw14norvig
authored andcommitted
Added function and test cases for cross-entropy loss (aimacode#853)
* Correction in the formula for mean square error * Added cross-entropy loss * Test case for cross-entropy loss * Decimal point mistake * Added spaces around = and ==
1 parent d1ea3fe commit 8a26b28

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

learning.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
def euclidean_distance(X, Y):
2222
return math.sqrt(sum((x - y)**2 for x, y in zip(X, Y)))
2323

24+
def cross_entropy_loss(X,Y):
25+
n=len(X)
26+
return (-1.0/n)*sum(x*math.log(y)+(1-x)*math.log(1-y) for x,y in zip(X,Y) )
27+
2428

2529
def rms_error(X, Y):
2630
return math.sqrt(ms_error(X, Y))

tests/test_learning.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ def test_euclidean():
1818
distance = euclidean_distance([0, 0, 0], [0, 0, 0])
1919
assert distance == 0
2020

21+
def test_cross_entropy():
22+
loss = cross_entropy_loss([1,0], [0.9, 0.3])
23+
assert round(loss,2) == 0.23
24+
25+
loss = cross_entropy_loss([1,0,0,1], [0.9,0.3,0.5,0.75])
26+
assert round(loss,2) == 0.36
27+
28+
loss = cross_entropy_loss([1,0,0,1,1,0,1,1], [0.9,0.3,0.5,0.75,0.85,0.14,0.93,0.79])
29+
assert round(loss,2) == 0.26
30+
2131

2232
def test_rms_error():
2333
assert rms_error([2, 2], [2, 2]) == 0

0 commit comments

Comments
 (0)