Skip to content

Commit b9f280f

Browse files
committed
add learning in n iterations
1 parent 5e40c5b commit b9f280f

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

MLTests/PredictionCompareLearnTest.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,26 @@ public double GetLearnNetwork(double weight, double error, double errorUp, doubl
2222
return result;
2323
}
2424

25+
public double LearnWeightInNIteration(double weight, double input, double predictionGoal, double step, int iterations)
26+
{
27+
var result = weight;
28+
for(int iteration = 0; iteration < iterations; iteration++)
29+
{
30+
var prediction = input * result;
31+
var error = Math.Pow(prediction - predictionGoal, 2);
32+
33+
var predictionUp = input * (result + step);
34+
var errorUp = Math.Pow(predictionUp - predictionGoal , 2);
35+
36+
var predictionDn = input * (result - step);
37+
var errorDn = Math.Pow(predictionDn - predictionGoal , 2);
38+
39+
if (errorDn < errorUp) result = result - step;
40+
if (errorUp < errorDn) result = result + step;
41+
}
42+
return result;
43+
}
44+
2545
[Test]
2646
public void ShouldGetSimpleLearnError()
2747
{
@@ -66,5 +86,21 @@ public void ShouldLearnByHotColdMethod()
6686
var betterError = Math.Round(Math.Pow(betterPrediction - goal, 2), 3);
6787
betterError.Should().Be(0.004);
6888
}
89+
90+
[Test]
91+
public void ShouldIterativeLearn()
92+
{
93+
var weight = 0.5;
94+
var input = 0.5;
95+
var predictionGoal = 0.8;
96+
var step = 0.001;
97+
var iterations = 1101;
98+
99+
var goodWeight = LearnWeightInNIteration(weight, input, predictionGoal, step, iterations);
100+
var result = Math.Round(input * goodWeight, 2);
101+
result.Should().Be(0.8);
102+
}
103+
104+
69105
}
70106
}

0 commit comments

Comments
 (0)