Skip to content

Commit 8ccf060

Browse files
committed
add alpha koef meaning
1 parent d75ff5f commit 8ccf060

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using FluentAssertions;
2+
using NUnit.Framework;
3+
4+
namespace MLTests
5+
{
6+
public class GradientLearnWithAlphaTests
7+
{
8+
private double LearnWeightByGradientAlpha(double weight, double input, double predictionGoal, double alpha, int iterations)
9+
{
10+
var result = weight;
11+
for (int iteration = 0; iteration < iterations; iteration++)
12+
{
13+
var prediction = input * result;
14+
var error = Math.Pow(prediction - predictionGoal, 2);
15+
var delta = prediction - predictionGoal;
16+
var weightDelta = delta * input;
17+
result = result - alpha * weightDelta;
18+
}
19+
return result;
20+
}
21+
22+
private double FunctionLearnWeight(double weight, double input, double predictionGoal, double alpha, int iterations)
23+
{
24+
var result = weight;
25+
for(int iteration = 0; iteration < iterations; iteration++)
26+
{
27+
var prediction = result * input;
28+
var error = Math.Pow(prediction - predictionGoal, 2);
29+
var derivative = (prediction - predictionGoal) * input;
30+
result = result - alpha * derivative;
31+
}
32+
return result;
33+
}
34+
35+
[Test]
36+
public void ShouldNotBreakLearnWeight()
37+
{
38+
var (weight, input, predictionGoal, alpha, iterations) = (0.0, 2.0, 0.8, 0.1, 10);
39+
var learnedWeight = FunctionLearnWeight(weight, input, predictionGoal, alpha, iterations);
40+
var result = Math.Round(learnedWeight * input, 2);
41+
result.Should().Be(0.8);
42+
}
43+
}
44+
}

0 commit comments

Comments
 (0)