Skip to content

Commit 1c391ef

Browse files
committed
add learn multiple out
1 parent 40e99fc commit 1c391ef

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

MLTests/GradientOneInMulOutTests.cs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
using FluentAssertions;
2+
using NUnit.Framework;
3+
4+
namespace MLTests
5+
{
6+
public class GradientOneInMulOutTests
7+
{
8+
private double[] ElementMul(double[] vecA, double element)
9+
{
10+
var result = new double[vecA.Length];
11+
for (int i = 0; i < vecA.Length; i++)
12+
{
13+
result[i] = vecA[i] * element;
14+
}
15+
return result;
16+
}
17+
18+
private double[] Mul(double[] vecA, double[] vecB)
19+
{
20+
var result = new double[vecA.Length];
21+
for (int i = 0; i < vecA.Length; i++)
22+
{
23+
result[i] = vecA[i] * vecB[i];
24+
}
25+
return result;
26+
}
27+
28+
private double[] Sub(double[] vecA, double[] vecB)
29+
{
30+
var result = new double[vecA.Length];
31+
for (int i = 0; i < vecA.Length; i++)
32+
{
33+
result[i] = vecA[i] - vecB[i];
34+
}
35+
return result;
36+
37+
}
38+
39+
private double[] SubA(double[] vecA, double[] vecB, double alpha)
40+
{
41+
var result = new double[vecA.Length];
42+
for (int i = 0; i < vecA.Length; i++)
43+
{
44+
result[i] = vecA[i] - vecB[i]*alpha;
45+
}
46+
return result;
47+
48+
}
49+
50+
private double[] GradientLearnByOneInputToMulOut(double[] weights, double input, double[] predictionGoal, double alpha, int iterations)
51+
{
52+
var result = weights.ToArray();
53+
for(int i = 0; i < iterations; i++)
54+
{
55+
var predictions = ElementMul(result, input);
56+
var deltas = Sub(predictions, predictionGoal);
57+
var deltasWeights = ElementMul(deltas, input);
58+
result = SubA(result, deltasWeights, alpha);
59+
}
60+
return result;
61+
}
62+
63+
[Test]
64+
public void ShouldGradientLearnMulOut()
65+
{
66+
var weights = new double[] { 0.3, 0.2, 0.9 };
67+
68+
var wlrec = new double[] { 0.65, 1.0, 1.0, 0.9 };
69+
70+
var hurt = new double[] {0.1, 0.0, 0.0, 0.1};
71+
var win = new double[] { 1, 1, 0, 1 };
72+
var sad = new double[] { 0.1, 0.0, 0.1, 0.2 };
73+
74+
var alpha = 0.2;
75+
var iterations = 3000;
76+
77+
var input = wlrec[0];
78+
var predictionGoal = new double[] { hurt[0], win[0], sad[0] };
79+
var weightsLearned = GradientLearnByOneInputToMulOut(weights, input, predictionGoal, alpha, iterations);
80+
var result = ElementMul(weightsLearned, input);
81+
Assert.That(predictionGoal[0], Is.EqualTo(Math.Round(result[0], 2)));
82+
Assert.That(predictionGoal[1], Is.EqualTo(Math.Round(result[1], 2)));
83+
Assert.That(predictionGoal[2], Is.EqualTo(Math.Round(result[2], 2)));
84+
}
85+
}
86+
}

0 commit comments

Comments
 (0)