Skip to content

Commit 14fc172

Browse files
committed
add numpy sample mul in mul out neural net
1 parent d2412b2 commit 14fc172

File tree

5 files changed

+72
-16
lines changed

5 files changed

+72
-16
lines changed

MLTests/MulInMulOut.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public MulInMulOut(double[][] weights)
99
}
1010
public double[] GetPrediction(double[] vec)
1111
{
12+
return VectorMath.VectMatMul(vec, _weights);
1213
}
1314
}
1415
}

MLTests/MulInMulOutTests.cs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
using NUnit.Framework;
1+
using FluentAssertions;
2+
using Numpy;
3+
using NUnit.Framework;
24

35
namespace MLTests
46
{
@@ -12,7 +14,65 @@ public void ShouldMulPredict()
1214
new double[] {0.1, 0.2, 0.0},
1315
new double[] {0.0, 1.3, 0.1},
1416
});
17+
var result = mmInOut.GetPrediction(new double[] { 8.5, 0.65, 1.2 });
18+
result.SequenceEqual(new double[] { 0.555, 0.98000000000000009, 0.96500000000000008 }).Should().Be(true);
1519
}
1620

21+
[Test]
22+
public void ShouldImposeNetworks()
23+
{
24+
var firstNetwork = new MulInMulOut(new double[][] {
25+
new double[] {0.1, 0.2, -0.1},
26+
new double[] {-0.1, 0.1, 0.9},
27+
new double[] {0.1, 0.4, 0.1},
28+
});
29+
var secondNetwork = new MulInMulOut(new double[][] {
30+
new double[] {0.3, 1.1, -0.3},
31+
new double[] {0.1, 0.2, 0.0},
32+
new double[] {0.0, 1.3, 0.1},
33+
});
34+
var input = new double[] { 8.5, 0.65, 1.2 };
35+
36+
var intermediateResult = firstNetwork.GetPrediction(input);
37+
var result = secondNetwork.GetPrediction(intermediateResult);
38+
39+
result.SequenceEqual(new double[] { 0.21350000000000002, 0.14500000000000002, 0.5065 }).Should().Be(true);
40+
}
41+
42+
[Test]
43+
public void ShouldImposeNetworksByNumpy()
44+
{
45+
/*
46+
var weightsA = np.array(new double[,] {
47+
{0.1, -0.1, 0.1},
48+
{0.2, 0.1, 0.4},
49+
{-0.1, 0.9, 0.1},
50+
});
51+
var weightsB = np.array(new double[,] {
52+
{0.3, 0.1, 0.0},
53+
{1.1, 0.2, 1.3},
54+
{-0.3, 0.0, 0.1},
55+
});
56+
*/
57+
58+
var weightsA = np.array(new double[,] {
59+
{0.1, 0.2, -0.1},
60+
{-0.1, 0.1, 0.9},
61+
{0.1, 0.4, 0.1},
62+
}).T;
63+
var weightsB = np.array(new double[,] {
64+
{0.3, 1.1, -0.3},
65+
{0.1, 0.2, 0.0},
66+
{0.0, 1.3, 0.1},
67+
}).T;
68+
69+
var input = np.array(new double[] { 8.5, 0.65, 1.2 });
70+
71+
var hid = input.dot(weightsA);
72+
var result = hid.dot(weightsB);
73+
result.GetData<double>().SequenceEqual(new double[] { 0.21350000000000002, 0.14500000000000002, 0.5065 }).Should().Be(true);
74+
75+
//(double[])result.SequenceEqual(new double[] { 0.21350000000000002, 0.14500000000000002, 0.5065 }).Should().Be(true);
76+
}
1777
}
1878
}

MLTests/MultipleInputNeuralTests.cs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
using FluentAssertions;
2-
using NUnit.Framework;
3-
using System;
4-
using System.Collections.Generic;
5-
using System.Linq;
6-
using System.Text;
7-
using System.Threading.Tasks;
82
using Numpy;
3+
using NUnit.Framework;
94

105
namespace MLTests
116
{
@@ -18,9 +13,9 @@ public void ShouldMultiplePredict()
1813
double[] PlayedWin = new double[] { 0.65, 0.8, 0.8, 0.9 };
1914
double[] PlayedFanats = new double[] { 1.2, 1.3, 0.5, 1.0 };
2015

21-
MultipleWeightsNeural neural = new MultipleWeightsNeural() { Weights = np.array(new double[] {0.1, 0.2, 0}) };
16+
MultipleWeightsNeural neural = new MultipleWeightsNeural() { Weights = np.array(new double[] { 0.1, 0.2, 0 }) };
2217

23-
double result = (double) neural.Prediction(np.array(new double[] { PlayedTime[0], PlayedWin[0], PlayedFanats[0] }));
18+
double result = (double)neural.Prediction(np.array(new double[] { PlayedTime[0], PlayedWin[0], PlayedFanats[0] }));
2419
result.Should().BeInRange(0.97, 0.99);
2520
}
2621
}

MLTests/SimplestNeuralTests.cs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
using FluentAssertions;
22
using NUnit.Framework;
3-
using System;
4-
using System.Collections.Generic;
5-
using System.Linq;
6-
using System.Text;
7-
using System.Threading.Tasks;
83

94
namespace MLTests
105
{
@@ -14,11 +9,11 @@ public class SimplestNeuralTests
149
public void ShouldSimplePredict()
1510
{
1611
// reverse list of number
17-
12+
1813
// {0, 1} принадлежит 0.0000001134
1914
//[8.5, 9.5, 10, 9]
2015
// 10 * 0.5 = 5
21-
SimpiestNeural neuron = new SimpiestNeural() { Weight = 0.1};
16+
SimpiestNeural neuron = new SimpiestNeural() { Weight = 0.1 };
2217
neuron.Prediction(8.5).Should().BeInRange(0.84, 0.86);
2318
//gameData.GetGameState()
2419
}

MLTests/VectorMath.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ public static double[] ElementwiseMultiplication(double[] vecA, double[] vecB)
1414
return vecA.Select((elementA, index) => elementA * vecB[index]).ToArray();
1515
}
1616

17+
public static double[] VectMatMul(double[] vecA, double[][] matrix)
18+
{
19+
return vecA.Select((_, index) => ElementwiseMultiplication(vecA, matrix[index]).Sum()).ToArray();
20+
}
21+
1722
public static double[] ElementMul(double element, double[] vec)
1823
{
1924
var result = new double[vec.Length];

0 commit comments

Comments
 (0)