Skip to content

Commit 5e40c5b

Browse files
committed
add hot-cold learn method
1 parent 4d974b3 commit 5e40c5b

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

MLTests/PredictionCompareLearnTest.cs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@ namespace MLTests
55
{
66
public class PredictionCompareLearnTest
77
{
8+
public double GetPredictFromSimpleNetwork(double input, double weight)
9+
{
10+
var prediction = input * weight;
11+
return prediction;
12+
}
13+
14+
public double GetLearnNetwork(double weight, double error, double errorUp, double errorDn, double step)
15+
{
16+
var result = weight;
17+
if(error > errorUp || error > errorDn)
18+
{
19+
if (errorDn < errorUp) result = result - step;
20+
if (errorUp < errorDn) result = result + step;
21+
}
22+
return result;
23+
}
24+
825
[Test]
926
public void ShouldGetSimpleLearnError()
1027
{
@@ -18,5 +35,36 @@ public void ShouldGetSimpleLearnError()
1835
var error = Math.Round(Math.Pow(goal - prediction, 2), 2);
1936
error.Should().Be(0.3);
2037
}
38+
39+
[Test]
40+
public void ShouldLearnByHotColdMethod()
41+
{
42+
var weight = 0.1;
43+
44+
var toes = 8.5;
45+
var winOrLose = 1;
46+
47+
var input = toes;
48+
var goal = winOrLose;
49+
50+
var prediction = GetPredictFromSimpleNetwork(input, weight);
51+
var error = Math.Round(Math.Pow(prediction - goal, 2), 3);
52+
error.Should().Be(0.022);
53+
54+
var step = 0.01;
55+
var predictUp = GetPredictFromSimpleNetwork(input, weight + step);
56+
var errorUp = Math.Round(Math.Pow(predictUp - goal, 2), 3);
57+
errorUp.Should().Be(0.004);
58+
59+
step = 0.01;
60+
var predictDn = GetPredictFromSimpleNetwork(input, weight - step);
61+
var errorDn = Math.Round(Math.Pow(predictDn - goal, 2), 3);
62+
errorDn.Should().Be(0.055);
63+
64+
var betterWeight = GetLearnNetwork(weight, error, errorUp, predictDn, step);
65+
var betterPrediction = GetPredictFromSimpleNetwork(input, betterWeight);
66+
var betterError = Math.Round(Math.Pow(betterPrediction - goal, 2), 3);
67+
betterError.Should().Be(0.004);
68+
}
2169
}
2270
}

0 commit comments

Comments
 (0)