Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 5b1d191

Browse files
302 - remove .squeeze() from 1-D numpy array
prediction.data.numpy() is already 1-D, so the .squeeze() is unnecessary.
1 parent 906cf71 commit 5b1d191

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

‎tutorial-contents/302_classification.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ def forward(self, x):
5959
# plot and show learning process
6060
plt.cla()
6161
prediction = torch.max(out, 1)[1]
62-
pred_y = prediction.data.numpy().squeeze()
62+
pred_y = prediction.data.numpy()
6363
target_y = y.data.numpy()
6464
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, lw=0, cmap='RdYlGn')
6565
accuracy = float((pred_y == target_y).astype(int).sum()) / float(target_y.size)
6666
plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'})
6767
plt.pause(0.1)
6868

6969
plt.ioff()
70-
plt.show()
70+
plt.show()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /