Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 307ca25

Browse files
added mnist fc
1 parent 88c9e27 commit 307ca25

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed
File renamed without changes.

‎example_mnist_fc.py‎

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
3+
from network import Network
4+
from fc_layer import FCLayer
5+
from activation_layer import ActivationLayer
6+
from activations import tanh, tanh_prime
7+
from losses import mse, mse_prime
8+
9+
from keras.datasets import mnist
10+
from keras.utils import np_utils
11+
12+
# load MNIST from server
13+
(x_train, y_train), (x_test, y_test) = mnist.load_data()
14+
15+
# training data : 60000 samples
16+
# reshape and normalize input data
17+
x_train = x_train.reshape(x_train.shape[0], 1, 28*28)
18+
x_train = x_train.astype('float32')
19+
x_train /= 255
20+
# encode output which is a number in range [0,9] into a vector of size 10
21+
# e.g. number 3 will become [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
22+
y_train = np_utils.to_categorical(y_train)
23+
24+
# same for test data : 10000 samples
25+
x_test = x_test.reshape(x_test.shape[0], 1, 28*28)
26+
x_test = x_test.astype('float32')
27+
x_test /= 255
28+
y_test = np_utils.to_categorical(y_test)
29+
30+
# Network
31+
net = Network()
32+
net.add(FCLayer(28*28, 100)) # input_shape=(1, 28*28) ; output_shape=(1, 100)
33+
net.add(ActivationLayer(tanh, tanh_prime))
34+
net.add(FCLayer(100, 50)) # input_shape=(1, 100) ; output_shape=(1, 50)
35+
net.add(ActivationLayer(tanh, tanh_prime))
36+
net.add(FCLayer(50, 10)) # input_shape=(1, 50) ; output_shape=(1, 10)
37+
net.add(ActivationLayer(tanh, tanh_prime))
38+
39+
# train on 1000 samples
40+
# as we didn't implemented mini-batch GD, training will be pretty slow if we update at each iteration on 60000 samples...
41+
net.use(mse, mse_prime)
42+
net.fit(x_train[0:1000], y_train[0:1000], epochs=35, learning_rate=0.1)
43+
44+
# test on 3 samples
45+
out = net.predict(x_test[0:3])
46+
print("\n")
47+
print("predicted values : ")
48+
print(out, end="\n")
49+
print("true values : ")
50+
print(y_test[0:3])

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /