Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit dcbd424

Browse files
Merge pull request MorvanZhou#12 from YJH-666/master
mnist dataset download setting
2 parents 2ae0c0f + 895ad9a commit dcbd424

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

‎tutorial-contents/401_CNN.py‎

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
torchvision
88
matplotlib
99
"""
10+
# library
11+
# standard library
12+
import os
13+
14+
# third-party library
1015
import torch
1116
import torch.nn as nn
1217
from torch.autograd import Variable
@@ -20,16 +25,20 @@
2025
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
2126
BATCH_SIZE = 50
2227
LR = 0.001 # learning rate
23-
DOWNLOAD_MNIST = True# set to False if you have downloaded
28+
DOWNLOAD_MNIST = False
2429

2530

2631
# Mnist digits dataset
32+
if not(os.path.exists('./mnist/')) or not os.listdir('./mnist/'):
33+
# not mnist dir or mnist is empyt dir
34+
DOWNLOAD_MNIST = True
35+
2736
train_data = torchvision.datasets.MNIST(
2837
root='./mnist/',
2938
train=True, # this is training data
3039
transform=torchvision.transforms.ToTensor(), # Converts a PIL.Image or numpy.ndarray to
3140
# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
32-
download=DOWNLOAD_MNIST,# download it if you don't have it
41+
download=DOWNLOAD_MNIST,
3342
)
3443

3544
# plot one example

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /