Skip to content

Commit 40e99fc

Browse files
committed
Add learning with one frozen ("freezy") weight
1 parent 356dcff commit 40e99fc

File tree

1 file changed

+61
-2
lines changed

1 file changed

+61
-2
lines changed

MLTests/GradientMulInOneOutTests.cs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ private double WeightsSum(double[] vecA, double[] vecB)
1515
return result;
1616
}
1717

18-
private double[] WeightsSub(double[] vecA, double sub)
18+
private double[] WeightSub(double[] vecA, double sub)
1919
{
2020
var result = new double[vecA.Length];
2121
for (int i = 0; i < vecA.Length; i++)
@@ -26,6 +26,17 @@ private double[] WeightsSub(double[] vecA, double sub)
2626

2727
}
2828

29+
// Returns a new vector where each component of vecB, scaled by alpha,
// has been subtracted from the matching component of vecA.
// Used as the gradient-descent weight update: weights - alpha * gradient.
// Assumes vecB has at least vecA.Length elements — TODO confirm with callers.
private double[] WeightsAlphaSub(double[] vecA, double[] vecB, double alpha)
{
    var updated = new double[vecA.Length];
    for (var index = 0; index < updated.Length; ++index)
    {
        updated[index] = vecA[index] - alpha * vecB[index];
    }
    return updated;
}
39+
2940
private double[] WeightsMul(double[] vecA, double[] vecB)
3041
{
3142
var result = new double[vecA.Length];
@@ -45,7 +56,32 @@ private double[] GetLearnByGradientMul(double[] weights, double[] input, double
4556
var error = Math.Pow(prediction - predictionGoal, 2);
4657
var delta = prediction - predictionGoal;
4758
var deltaWeight = delta * prediction;
48-
result = WeightsSub(result, alpha * deltaWeight);
59+
result = WeightSub(result, alpha * deltaWeight);
60+
}
61+
return result;
62+
}
63+
64+
// Scales every entry of vecA by the given scalar and returns the scaled copy;
// the input array itself is never modified.
private double[] ElementMul(double[] vecA, double element)
{
    var scaled = new double[vecA.Length];
    for (var index = 0; index < scaled.Length; ++index)
    {
        scaled[index] = element * vecA[index];
    }
    return scaled;
}
73+
74+
// Runs gradient descent on a single-output linear model (prediction = input · weights)
// while keeping the weight at freezyIndex frozen at its initial value.
//   weights        - starting weights; not mutated, a copy is trained and returned
//   freezyIndex    - index of the weight excluded from learning (its gradient is zeroed)
//   input          - input vector, assumed to be the same length as weights
//   predictionGoal - target output value
//   alpha          - learning rate
//   iterations     - number of gradient steps to perform
// Returns the trained weight vector.
private double[] GetLearnByGradientFreezyOne(double[] weights, int freezyIndex, double[] input, double predictionGoal, double alpha, double iterations)
{
    var result = weights.ToArray();
    for (int iteration = 0; iteration < iterations; ++iteration)
    {
        var prediction = WeightsSum(input, result);
        // Gradient of the squared error w.r.t. weight i is proportional to
        // (prediction - goal) * input_i; the constant factor folds into alpha.
        // (The squared error itself was computed here before but never used,
        // so that dead Math.Pow call has been removed.)
        var delta = prediction - predictionGoal;
        var deltaWeights = ElementMul(input, delta);
        // Freeze one weight by cancelling its gradient contribution.
        deltaWeights[freezyIndex] = 0;
        result = WeightsAlphaSub(result, deltaWeights, alpha);
    }
    return result;
}
@@ -72,5 +108,28 @@ public void ShouldGradientLearnByMultipleInputOneOutput()
72108

73109
//var learnWeights =
74110
}
111+
112+
// Verifies that gradient descent still drives the prediction to the target
// even when the first weight (toes) is frozen during training.
[Test]
public void ShouldGradientLearnByMultipleWhenFreezyOneWeight()
{
    // Arrange: initial weights and hyper-parameters.
    var weights = new double[] { 0.1, 0.2, -0.1 };
    var alpha = 0.5;
    var iterations = 20;

    // One observation per game: average toes, win/loss record, fan count.
    var toes = new double[] { 8.5, 9.5, 9.9, 9.0 };
    var wlrec = new double[] { 0.65, 0.8, 0.8, 0.9 };
    var nfans = new double[] { 1.2, 1.3, 0.5, 1.0 };

    var winOrLoseData = new double[] { 1, 1, 0, -1 };

    // Act: train on the first game only, freezing weight 0.
    var predictionGoal = winOrLoseData[0];
    var input = new double[] { toes[0], wlrec[0], nfans[0] };

    var result = GetLearnByGradientFreezyOne(weights, 0, input, predictionGoal, alpha, iterations);

    // Assert: floating-point results must be compared with a tolerance,
    // never with exact equality.
    Assert.That(WeightsSum(result, input), Is.EqualTo(1.0).Within(1e-9));
}
75134
}
76135
}

0 commit comments

Comments
 (0)