Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1a4dc0d

Browse files
authored
Create NumberAI.cpp
1 parent 00ec09e commit 1a4dc0d

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

‎NumberAI.cpp

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
#pragma GCC optimize(2)
2+
#include<iostream>
3+
#include<cmath>
4+
#include<algorithm>
5+
#include<graphics.h>
6+
using namespace std;
7+
// Euler's number. Declared as a typed constant instead of a macro so that the
// one-letter name `e` is not textually substituted into every identifier named
// `e` in the code (and headers) that follow.
const double e=2.718281828459045;
// Samples per label visited by AI::train(); sample ids run from 1 to `number`.
const int number=300;
// Number of training rounds in main(); also used as the plot window width.
const int train_time=1250;
// Divisor applied to learn_rate at each scheduled decay point in main().
const int Smaller=20;
// Image store, indexed as [num][cnt][pixel].
float ima[30][6000][784];
// Step size used by every weight and bias update.
float learn_rate=0.01;
// Number of nodes in each hidden layer.
const int n=20;
15+
#include"headfile/Dataloader.hpp"
16+
// Logistic sigmoid: maps any real x into the interval [0, 1].
//
// Computed with std::exp instead of pow() on a hand-typed approximation of
// Euler's number, which is both more accurate and cheaper. For large negative
// x the exponential overflows to +inf and the result is exactly 0.
inline float sigmoid(float x){
    return 1.0f/(1.0f+std::exp(-x));
}
19+
// Derivative of the logistic sigmoid, expressed through the sigmoid's OUTPUT.
//
// x must already be a sigmoid value s = sigmoid(z); the function returns
// s * (1 - s), which equals d(sigmoid)/dz at that z. Passing the raw
// pre-activation z here gives a wrong result.
inline float sigmoid_derivative(float x){
    return x*(1.0f-x);
}
22+
// Half squared error for a single output: 0.5 * (predicted - actual)^2.
//
// The result is symmetric in its two arguments. A float literal is used so the
// arithmetic stays in single precision instead of being promoted to double.
inline float mse_loss(float predicted, float actual) {
    const float diff = predicted - actual;
    return 0.5f * diff * diff;
}
25+
class NODE{
26+
public:
27+
NODE * from[1000];
28+
float w[1000];
29+
float b;
30+
int cnt;
31+
float value;
32+
void band(NODE * node){
33+
from[cnt++]=node;
34+
}
35+
void init(){
36+
for(int i=0;i<cnt;i++) w[i]=1.0*(rand()%200-100)/100;
37+
b=1.0*(rand()%200-100)/100;
38+
}
39+
void run(){
40+
value=0;
41+
for(int i=0;i<cnt;i++){
42+
value+=from[i]->get_value()*w[i];
43+
}
44+
value+=b;
45+
value=sigmoid(value);
46+
}
47+
float get_value(){
48+
return value;
49+
}
50+
void set(float v){
51+
value=v;
52+
}
53+
};
54+
class AI{
55+
public:
56+
NODE Input[784];
57+
NODE Hidden[2][n]; // 两个隐藏层,每个层20个节点
58+
NODE Output[10];
59+
void init(){
60+
// 初始化第一个隐藏层
61+
for(int i=0; i<n; i++){
62+
for(int j=0; j<784; j++){
63+
Hidden[0][i].band(&Input[j]);
64+
}
65+
Hidden[0][i].init();
66+
}
67+
// 初始化第二个隐藏层
68+
for(int i=0; i<n; i++){
69+
for(int j=0; j<n; j++){
70+
Hidden[1][i].band(&Hidden[0][j]);
71+
}
72+
Hidden[1][i].init();
73+
}
74+
// 初始化输出层
75+
for(int i=0; i<10; i++){
76+
for(int j=0; j<n; j++){
77+
Output[i].band(&Hidden[1][j]);
78+
}
79+
Output[i].init();
80+
}
81+
}
82+
void run(int num, int id){
83+
for(int i=0; i<784; i++){
84+
Input[i].set(ima[num][id][i]);
85+
}
86+
for(int i=0; i<n; i++){
87+
Hidden[0][i].run();
88+
}
89+
for(int i=0; i<n; i++){
90+
Hidden[1][i].run();
91+
}
92+
for(int i=0; i<10; i++){
93+
Output[i].run();
94+
}
95+
}
96+
float train(){
97+
float loss_sum=0;
98+
for (int num = 0; num < 10; num++) {
99+
for (int id = 1; id <= number; id++) {
100+
run(num, id); // 运行网络
101+
for (int i = 0; i < 10; i++) {
102+
register float predicted = Output[i].get_value();
103+
register float actual = i == num ? 1 : 0;
104+
register float loss = mse_loss(predicted, actual);
105+
loss_sum+=loss;
106+
// 反向传播 输出层
107+
register float output_grad = predicted - actual;
108+
for (int j = 0; j < n; j++) {
109+
Output[i].w[j] -= learn_rate * output_grad * Hidden[1][j].get_value();
110+
}
111+
Output[i].b -= learn_rate * output_grad;
112+
113+
// 反向传播 第二个隐藏层
114+
register float hidden2_grad = 0;
115+
for (int j = 0; j < 10; j++) {
116+
hidden2_grad += Output[j].w[i] * (Output[j].get_value() - (j == num ? 1 : 0));
117+
}
118+
hidden2_grad *= sigmoid_derivative(Hidden[1][i].get_value()); // sigmoid函数导数
119+
for (int j = 0; j < n; j++) {
120+
Hidden[1][i].w[j] -= learn_rate * hidden2_grad * Hidden[0][j].get_value();
121+
}
122+
Hidden[1][i].b -= learn_rate * hidden2_grad;
123+
124+
// 反向传播 第一个隐藏层
125+
register float hidden1_grad = 0;
126+
for (int j = 0; j < n; j++) {
127+
hidden1_grad += Hidden[1][j].w[i] * hidden2_grad;
128+
}
129+
hidden1_grad *= sigmoid_derivative(Hidden[0][i].get_value()); // sigmoid函数导数
130+
for (int j = 0; j < 784; j++) {
131+
Hidden[0][i].w[j] -= learn_rate * hidden1_grad * Input[j].get_value();
132+
}
133+
Hidden[0][i].b -= learn_rate * hidden1_grad;
134+
}
135+
}
136+
}
137+
return loss_sum;
138+
}
139+
};
140+
AI ai;
141+
int main(){
142+
initgraph(train_time,800);
143+
load_data();
144+
ai.init();
145+
for(int n=1;n<=1;n++){
146+
cout<<endl<<"------------训练中"<<n<<"------------" <<endl;
147+
int last=800;
148+
int last2=800;
149+
for(int i=1;i<train_time;i++){
150+
if(i==train_time/6) learn_rate/=Smaller;
151+
if(i==train_time/6*2) learn_rate/=Smaller;
152+
if(i==train_time/6*3) learn_rate/=Smaller;
153+
if(i==train_time/6*4) learn_rate/=Smaller;
154+
if(i==train_time/6*5) learn_rate/=Smaller;
155+
int loss=ai.train();
156+
setcolor(EGERGB(255,255,255));
157+
/*if(loss>last){
158+
setcolor(EGERGB(255,0,0));
159+
xyprintf(i,800-loss,"%d",i);
160+
}*/
161+
line(i,800-loss,i-1,800-last);
162+
last=loss;
163+
164+
float loss2=0;
165+
for(int i=10;i<20;i++){
166+
for(int j=1;j<=10;j++){
167+
ai.run(i,j);
168+
for(int tmp=1;tmp<10;tmp++){
169+
loss2+=mse_loss(tmp==i-10?1:0,ai.Output[tmp].get_value());
170+
}
171+
}
172+
}
173+
loss2*=18;
174+
setcolor(EGERGB(0,255,0));
175+
if(loss2>last2) setcolor(EGERGB(255,0,0));
176+
line(i,800-loss2,i-1,800-last2);
177+
last2=loss2;
178+
printf("round:%d loss:%d val:%d \r",i,loss,(int)loss2);
179+
180+
//Sleep(1);
181+
}
182+
cout<<endl<<"------------训练结束------------" <<endl;
183+
cout<<endl<<"------------结果------------" <<endl;
184+
for(int i=0;i<10;i++){
185+
cout<<"number"<<i<<endl;
186+
cout<<"predict:";
187+
for(int j=1;j<=number;j++){
188+
ai.run(i,j);
189+
float maxx=ai.Output[0].get_value();
190+
int res=0;
191+
for(int tmp=1;tmp<10;tmp++){
192+
//printf("%.3f ",ai.Output[tmp].get_value());
193+
if(ai.Output[tmp].get_value()>maxx){
194+
res=tmp;
195+
maxx=ai.Output[tmp].get_value();
196+
}
197+
}
198+
//cout<<endl;
199+
cout<<res<<" ";
200+
}
201+
cout<<endl;
202+
}
203+
learn_rate/=2.0;
204+
}
205+
for(int i=10;i<20;i++){
206+
cout<<"number"<<i-10<<endl;
207+
cout<<"predict:";
208+
for(int j=1;j<=10;j++){
209+
ai.run(i,j);
210+
float maxx=ai.Output[0].get_value();
211+
int res=0;
212+
for(int tmp=1;tmp<10;tmp++){
213+
//printf("%.3f ",ai.Output[tmp].get_value());
214+
if(ai.Output[tmp].get_value()>maxx){
215+
res=tmp;
216+
maxx=ai.Output[tmp].get_value();
217+
}
218+
}
219+
//cout<<endl;
220+
cout<<res<<" ";
221+
}
222+
cout<<endl;
223+
}
224+
getch();
225+
}

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /