Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 0b3aa23

Browse files
Merge pull request MorvanZhou#60 from keineahnung2345/402-squeeze
402 - remove squeeze for 1-D array
2 parents 903820e + 7411f60 commit 0b3aa23

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

‎tutorial-contents/402_RNN_classifier.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
# convert test data into Variable, pick 2000 samples to speed up testing
4848
test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
4949
test_x = test_data.test_data.type(torch.FloatTensor)[:2000]/255. # shape (2000, 28, 28) value in range(0,1)
50-
test_y = test_data.test_labels.numpy().squeeze()[:2000] # covert to numpy array
50+
test_y = test_data.test_labels.numpy()[:2000] # covert to numpy array
5151

5252

5353
class RNN(nn.Module):
@@ -94,13 +94,13 @@ def forward(self, x):
9494

9595
if step % 50 == 0:
9696
test_output = rnn(test_x) # (samples, time_step, input_size)
97-
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
97+
pred_y = torch.max(test_output, 1)[1].data.numpy()
9898
accuracy = float((pred_y == test_y).astype(int).sum()) / float(test_y.size)
9999
print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy(), '| test accuracy: %.2f' % accuracy)
100100

101101
# print 10 predictions from test data
102102
test_output = rnn(test_x[:10].view(-1, 28, 28))
103-
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze()
103+
pred_y = torch.max(test_output, 1)[1].data.numpy()
104104
print(pred_y, 'prediction number')
105105
print(test_y[:10], 'real number')
106106

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /