Skip to content

Commit e9dea9a

Browse files
committed
add Deep Learning digital recognizer
1 parent 9e709a4 commit e9dea9a

File tree

4 files changed

+90
-2
lines changed

4 files changed

+90
-2
lines changed

MLTests/MLTests.csproj

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<OutputType>Library</OutputType>
4+
<OutputType>WinExe</OutputType>
55
<TargetFramework>net6.0</TargetFramework>
66
<ImplicitUsings>enable</ImplicitUsings>
77
<Nullable>enable</Nullable>
8+
<PlatformTarget>AnyCPU</PlatformTarget>
9+
</PropertyGroup>
10+
11+
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
12+
<Optimize>False</Optimize>
813
</PropertyGroup>
914

1015
<ItemGroup>

MLTests/NDLDigitalRecognizer.cs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using Keras.Datasets;
2+
using Numpy;
3+
4+
namespace MLTests
5+
{
6+
public class NDLDigitalRecognizer
7+
{
8+
private NDarray _weights01;
9+
private NDarray _weights12;
10+
private string _dataPath;
11+
12+
private static NDarray Relu(NDarray layer) => layer * (layer > 0);
13+
private static NDarray ReluToDerivative(NDarray layer) => np.where(layer > 0, np.array(1), np.array(0));
14+
15+
public NDLDigitalRecognizer(string dataPath)
16+
{
17+
_weights01 = np.array(0f);
18+
_weights12 = np.array(0f);
19+
_dataPath = dataPath;
20+
}
21+
22+
public void TrainOnMNISTData()
23+
{
24+
var ((x_train, y_train), (x_test, y_test)) = MNIST.LoadData(_dataPath);
25+
26+
var (images, labels) = (x_train["0:1000"].reshape(1000, 28 * 28) / 255, y_train["0:1000"]);
27+
var one_hot_labels = np.zeros(labels.len, 10);
28+
for (var id = 0; id < labels.len; id++)
29+
{
30+
var label = labels[id];
31+
one_hot_labels[id][label] = np.array(1);
32+
}
33+
labels = one_hot_labels;
34+
35+
np.random.seed(1);
36+
37+
var (alpha, iterations, hidden_size, pixels_per_image, num_labels) = (0.005f, 30, 40, 784, 10);
38+
39+
_weights01 = 0.2f * np.random.rand(pixels_per_image, hidden_size) - 0.1f;
40+
_weights12 = 0.2f * np.random.rand(hidden_size, num_labels) - 0.1f;
41+
42+
for (var i = 0; i < iterations; i++)
43+
{
44+
for (var imageId = 0; imageId < images.len; ++imageId)
45+
{
46+
var layer_0 = images[$"{imageId}:{imageId + 1}"];
47+
var layer_1 = Relu(np.dot(layer_0, _weights01));
48+
var layer_2 = np.dot(layer_1, _weights12);
49+
50+
var layer_2_delta = labels[$"{imageId}:{imageId + 1}"] - layer_2;
51+
var layer_1_delta = layer_2_delta.dot(_weights12.T) * ReluToDerivative(layer_1);
52+
53+
_weights12 += alpha * layer_1.T.dot(layer_2_delta);
54+
_weights01 += alpha * layer_0.T.dot(layer_1_delta);
55+
}
56+
}
57+
}
58+
59+
public int PredictNumberFromImageTest(int testId)
60+
{
61+
var ((x_train, y_train), (x_test, y_test)) = MNIST.LoadData(_dataPath);
62+
63+
var test_images = x_test.reshape(x_test.len, 28 * 28) / 255;
64+
var test_labels = np.zeros(y_test.len, 10);
65+
for (var id = 0; id < y_test.len; id++)
66+
{
67+
var test = y_test[id];
68+
test_labels[id][test] = np.array(1);
69+
}
70+
71+
var layer_0 = test_images[$"{testId}:{testId + 1}"];
72+
var layer_1 = Relu(np.dot(layer_0, _weights01));
73+
var layer_2 = np.dot(layer_1, _weights12);
74+
var result = layer_2.GetData<double>();
75+
var max = result.Max();
76+
return result.Select((e, index) => new {e, index}).OrderBy(e => e.e).Last().index;
77+
}
78+
}
79+
}

MLTests/Program.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
using MLTests;
22

33
//var vector = ImageToVector.GetVectorOnPath(@"./data/0/1.jpg");
4-
Console.WriteLine("ML");
4+
var recognizer = new NDLDigitalRecognizer(Path.GetFullPath(@"D:/projects.active/LearnML/data/mnist.npz"));
5+
recognizer.TrainOnMNISTData();
6+
var result = recognizer.PredictNumberFromImageTest(0);
7+
Console.WriteLine(result);
8+
Console.ReadLine();

data/mnist.npz

11 MB
Binary file not shown.

0 commit comments

Comments
 (0)