Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 906cf71

Browse files
update for new version of torch
1 parent ce55cc9 commit 906cf71

1 file changed

Lines changed: 56 additions & 50 deletions

File tree

‎tutorial-contents/504_batch_normalization.py‎

Lines changed: 56 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from torch import nn
1212
from torch.nn import init
1313
import torch.utils.data as Data
14-
import torch.nn.functional as F
1514
import matplotlib.pyplot as plt
1615
import numpy as np
1716

@@ -24,7 +23,7 @@
2423
EPOCH = 12
2524
LR = 0.03
2625
N_HIDDEN = 8
27-
ACTIVATION = F.tanh
26+
ACTIVATION = torch.tanh
2827
B_INIT = -0.2 # use a bad bias constant initializer
2928

3029
# training data
@@ -48,6 +47,7 @@
4847
plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
4948
plt.legend(loc='upper left')
5049

50+
5151
class Net(nn.Module):
5252
def __init__(self, batch_normalization=False):
5353
super(Net, self).__init__()
@@ -89,20 +89,20 @@ def forward(self, x):
8989

9090
nets = [Net(batch_normalization=False), Net(batch_normalization=True)]
9191

92-
print(*nets) # print net architecture
92+
# print(*nets) # print net architecture
9393

9494
opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
9595

9696
loss_func = torch.nn.MSELoss()
9797

98-
f, axs = plt.subplots(4, N_HIDDEN+1, figsize=(10, 5))
99-
plt.ion() # something about plotting
100-
plt.show()
98+
10199
def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
102-
for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
100+
for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):
103101
[a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]
104-
if i == 0: p_range = (-7, 10);the_range = (-7, 10)
105-
else:p_range = (-4, 4);the_range = (-1, 1)
102+
if i == 0:
103+
p_range = (-7, 10);the_range = (-7, 10)
104+
else:
105+
p_range = (-4, 4);the_range = (-1, 1)
106106
ax_pa.set_title('L' + str(i))
107107
ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)
108108
ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')
@@ -111,44 +111,50 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
111111
axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')
112112
plt.pause(0.01)
113113

114-
# training
115-
losses = [[], []] # recode loss for two networks
116-
for epoch in range(EPOCH):
117-
print('Epoch: ', epoch)
118-
layer_inputs, pre_acts = [], []
119-
for net, l in zip(nets, losses):
120-
net.eval() # set eval mode to fix moving_mean and moving_var
121-
pred, layer_input, pre_act = net(test_x)
122-
l.append(loss_func(pred, test_y).data[0])
123-
layer_inputs.append(layer_input)
124-
pre_acts.append(pre_act)
125-
net.train() # free moving_mean and moving_var
126-
plot_histogram(*layer_inputs, *pre_acts) # plot histogram
127-
128-
for step, (b_x, b_y) in enumerate(train_loader):
129-
for net, opt in zip(nets, opts): # train for each network
130-
pred, _, _ = net(b_x)
131-
loss = loss_func(pred, b_y)
132-
opt.zero_grad()
133-
loss.backward()
134-
opt.step() # it will also learns the parameters in Batch Normalization
135-
136-
137-
plt.ioff()
138-
139-
# plot training loss
140-
plt.figure(2)
141-
plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
142-
plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
143-
plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
144-
145-
# evaluation
146-
# set net to eval mode to freeze the parameters in batch normalization layers
147-
[net.eval() for net in nets] # set eval mode to fix moving_mean and moving_var
148-
preds = [net(test_x)[0] for net in nets]
149-
plt.figure(3)
150-
plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
151-
plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
152-
plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
153-
plt.legend(loc='best')
154-
plt.show()
114+
115+
if __name__ == "__main__":
116+
f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))
117+
plt.ion() # enable interactive mode so the figure can be redrawn during training without blocking
118+
plt.show()
119+
120+
# training
121+
losses = [[], []] # record loss for two networks
122+
123+
for epoch in range(EPOCH):
124+
print('Epoch: ', epoch)
125+
layer_inputs, pre_acts = [], []
126+
for net, l in zip(nets, losses):
127+
net.eval() # set eval mode to fix moving_mean and moving_var
128+
pred, layer_input, pre_act = net(test_x)
129+
l.append(loss_func(pred, test_y).data.item())
130+
layer_inputs.append(layer_input)
131+
pre_acts.append(pre_act)
132+
net.train() # back to train mode so moving_mean and moving_var are updated again
133+
plot_histogram(*layer_inputs, *pre_acts) # plot histogram
134+
135+
for step, (b_x, b_y) in enumerate(train_loader):
136+
for net, opt in zip(nets, opts): # train for each network
137+
pred, _, _ = net(b_x)
138+
loss = loss_func(pred, b_y)
139+
opt.zero_grad()
140+
loss.backward()
141+
opt.step() # this also learns the parameters in Batch Normalization
142+
143+
plt.ioff()
144+
145+
# plot test loss (recorded once per epoch)
146+
plt.figure(2)
147+
plt.plot(losses[0], c='#FF9359', lw=3, label='Original')
148+
plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')
149+
plt.xlabel('step');plt.ylabel('test loss');plt.ylim((0, 2000));plt.legend(loc='best')
150+
151+
# evaluation
152+
# set net to eval mode to freeze the parameters in batch normalization layers
153+
[net.eval() for net in nets] # set eval mode to fix moving_mean and moving_var
154+
preds = [net(test_x)[0] for net in nets]
155+
plt.figure(3)
156+
plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')
157+
plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')
158+
plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')
159+
plt.legend(loc='best')
160+
plt.show()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /