Commit a699e10

update tensorboard
1 parent f8e5c60 commit a699e10

File tree

2 files changed (+224, -2 lines)


README.md

Lines changed: 2 additions & 2 deletions
```diff
@@ -55,7 +55,7 @@ Learn Deep Learning with PyTorch
     - Deep Convolutional GANs (DCGANs)
 
 - Chapter 7: PyTorch Advanced
-    - [tensorboard visualization]()
+    - [tensorboard visualization](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/tensorboard.ipynb)
     - Optimization algorithms
         - [SGD](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/optimizer/sgd.ipynb)
         - [Momentum](https://github.com/SherlockLiao/code-of-learn-deep-learning-with-pytorch/blob/master/chapter6_PyTorch-Advances/optimizer/momentum.ipynb)
@@ -72,7 +72,7 @@ Learn Deep Learning with PyTorch
 
 ### part2: Applications of deep learning
 - Chapter 8: Computer vision
-    - Fine-tuning: transfer learning via fine-tuning
+    - [Fine-tuning: transfer learning via fine-tuning]()
     - Semantic segmentation: pixel-level classification with FCN
     - Neural Transfer: style transfer with a convolutional network
     - Deep Dream: exploring the world as seen by a convnet
```
chapter6_PyTorch-Advances/tensorboard.ipynb

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@

# TensorBoard Visualization

[github](https://github.com/lanpa/tensorboard-pytorch)

```python
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import resnet
from torchvision import transforms as tfs
from datetime import datetime
from tensorboardX import SummaryWriter
```
```python
# data augmentation for the training set
def train_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

def test_tf(x):
    im_aug = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    x = im_aug(x)
    return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
valid_set = CIFAR10('./data', train=False, transform=test_tf)
valid_data = torch.utils.data.DataLoader(valid_set, batch_size=256, shuffle=False, num_workers=4)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
```
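Note that `resnet` above is imported from the repo's own `utils` module, which is not part of this commit. Purely as a hypothetical stand-in with the same `resnet(in_channel, num_classes)` call signature (the real helper may be a different architecture), something like the sketch below would let the cell run, assuming a torchvision version whose ResNet uses adaptive average pooling so the 96×96 crops above are accepted:

```python
# Hypothetical stand-in for `utils.resnet` -- NOT the repo's implementation.
from torch import nn
from torchvision.models import resnet18

def resnet(in_channel, num_classes):
    """ResNet-18 with its stem adapted to `in_channel` input channels."""
    model = resnet18(num_classes=num_classes)
    if in_channel != 3:
        # swap the first conv so inputs with a different channel count still work
        model.conv1 = nn.Conv2d(in_channel, 64, kernel_size=7, stride=2,
                                padding=3, bias=False)
    return model
```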
```python
writer = SummaryWriter()

def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().data[0]
    return num_correct / total

if torch.cuda.is_available():
    net = net.cuda()
prev_time = datetime.now()
for epoch in range(30):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
        if torch.cuda.is_available():
            im = Variable(im.cuda())  # (bs, 3, h, w)
            label = Variable(label.cuda())  # (bs, h, w)
        else:
            im = Variable(im)
            label = Variable(label)
        # forward
        output = net(im)
        loss = criterion(output, label)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        train_acc += get_acc(output, label)
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    valid_loss = 0
    valid_acc = 0
    net = net.eval()
    for im, label in valid_data:
        if torch.cuda.is_available():
            im = Variable(im.cuda(), volatile=True)
            label = Variable(label.cuda(), volatile=True)
        else:
            im = Variable(im, volatile=True)
            label = Variable(label, volatile=True)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.data[0]
        valid_acc += get_acc(output, label)
    epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
           train_acc / len(train_data), valid_loss / len(valid_data),
           valid_acc / len(valid_data)))
    prev_time = cur_time
    # ====================== use tensorboard ==================
    writer.add_scalars('Loss', {'train': train_loss / len(train_data),
                                'valid': valid_loss / len(valid_data)}, epoch)
    writer.add_scalars('Acc', {'train': train_acc / len(train_data),
                               'valid': valid_acc / len(valid_data)}, epoch)
    # =========================================================
    print(epoch_str + time_str)
```

Output:

```
Epoch 0. Train Loss: 1.877906, Train Acc: 0.315410, Valid Loss: 2.198587, Valid Acc: 0.293164, Time 00:00:26
Epoch 1. Train Loss: 1.398501, Train Acc: 0.498657, Valid Loss: 1.877540, Valid Acc: 0.400098, Time 00:00:27
Epoch 2. Train Loss: 1.141419, Train Acc: 0.597628, Valid Loss: 1.872355, Valid Acc: 0.446777, Time 00:00:27
Epoch 3. Train Loss: 0.980048, Train Acc: 0.658367, Valid Loss: 1.672951, Valid Acc: 0.475391, Time 00:00:27
Epoch 4. Train Loss: 0.871448, Train Acc: 0.695073, Valid Loss: 1.263234, Valid Acc: 0.578613, Time 00:00:28
Epoch 5. Train Loss: 0.794649, Train Acc: 0.723992, Valid Loss: 2.142715, Valid Acc: 0.466699, Time 00:00:27
Epoch 6. Train Loss: 0.736611, Train Acc: 0.741554, Valid Loss: 1.701331, Valid Acc: 0.500391, Time 00:00:27
Epoch 7. Train Loss: 0.695095, Train Acc: 0.756816, Valid Loss: 1.385478, Valid Acc: 0.597656, Time 00:00:28
Epoch 8. Train Loss: 0.652659, Train Acc: 0.773796, Valid Loss: 1.029726, Valid Acc: 0.676465, Time 00:00:27
Epoch 9. Train Loss: 0.623829, Train Acc: 0.784144, Valid Loss: 0.933388, Valid Acc: 0.682520, Time 00:00:27
Epoch 10. Train Loss: 0.581615, Train Acc: 0.798792, Valid Loss: 1.291557, Valid Acc: 0.635938, Time 00:00:27
Epoch 11. Train Loss: 0.559358, Train Acc: 0.805708, Valid Loss: 1.430408, Valid Acc: 0.586426, Time 00:00:28
Epoch 12. Train Loss: 0.534197, Train Acc: 0.816853, Valid Loss: 0.960802, Valid Acc: 0.704785, Time 00:00:27
Epoch 13. Train Loss: 0.512111, Train Acc: 0.822389, Valid Loss: 0.923353, Valid Acc: 0.716602, Time 00:00:27
Epoch 14. Train Loss: 0.494577, Train Acc: 0.828225, Valid Loss: 1.023517, Valid Acc: 0.687207, Time 00:00:27
Epoch 15. Train Loss: 0.473396, Train Acc: 0.835212, Valid Loss: 0.842679, Valid Acc: 0.727930, Time 00:00:27
Epoch 16. Train Loss: 0.459708, Train Acc: 0.840290, Valid Loss: 0.826854, Valid Acc: 0.726953, Time 00:00:28
Epoch 17. Train Loss: 0.433836, Train Acc: 0.847931, Valid Loss: 0.730658, Valid Acc: 0.764258, Time 00:00:27
Epoch 18. Train Loss: 0.422375, Train Acc: 0.854401, Valid Loss: 0.677953, Valid Acc: 0.778125, Time 00:00:27
Epoch 19. Train Loss: 0.410208, Train Acc: 0.857370, Valid Loss: 0.787286, Valid Acc: 0.754102, Time 00:00:27
Epoch 20. Train Loss: 0.395556, Train Acc: 0.862923, Valid Loss: 0.859754, Valid Acc: 0.738965, Time 00:00:27
Epoch 21. Train Loss: 0.382050, Train Acc: 0.866554, Valid Loss: 1.266704, Valid Acc: 0.651660, Time 00:00:27
Epoch 22. Train Loss: 0.368614, Train Acc: 0.871213, Valid Loss: 0.912465, Valid Acc: 0.738672, Time 00:00:27
Epoch 23. Train Loss: 0.358302, Train Acc: 0.873964, Valid Loss: 0.963238, Valid Acc: 0.706055, Time 00:00:27
Epoch 24. Train Loss: 0.347568, Train Acc: 0.879620, Valid Loss: 0.777171, Valid Acc: 0.751855, Time 00:00:27
Epoch 25. Train Loss: 0.339247, Train Acc: 0.882215, Valid Loss: 0.707863, Valid Acc: 0.777734, Time 00:00:27
Epoch 26. Train Loss: 0.329292, Train Acc: 0.885830, Valid Loss: 0.682976, Valid Acc: 0.790527, Time 00:00:27
Epoch 27. Train Loss: 0.313049, Train Acc: 0.890761, Valid Loss: 0.665912, Valid Acc: 0.795410, Time 00:00:27
Epoch 28. Train Loss: 0.305482, Train Acc: 0.891944, Valid Loss: 0.880263, Valid Acc: 0.743848, Time 00:00:27
Epoch 29. Train Loss: 0.301507, Train Acc: 0.895289, Valid Loss: 1.062325, Valid Acc: 0.708398, Time 00:00:27
```
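The two `writer.add_scalars` calls are what produce the grouped train/valid curves shown in the screenshot below. To view them, TensorBoard has to be pointed at the directory where `SummaryWriter` writes its event files. A minimal sketch, assuming tensorboardX's default `runs/` log directory and a locally installed `tensorboard` binary:

```python
# Minimal sketch: log a few dummy scalars the same way the training loop does,
# then inspect them in the browser.  Assumes the default ./runs log directory.
from tensorboardX import SummaryWriter

writer = SummaryWriter()  # event files land under ./runs/<run-name>/
for step in range(3):
    writer.add_scalars('Loss', {'train': 1.0 - 0.1 * step,
                                'valid': 1.2 - 0.1 * step}, step)
writer.close()

# Then, from a shell:
#   tensorboard --logdir runs
# and open http://localhost:6006 to see the grouped curves.
```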
![](https://ws1.sinaimg.cn/large/006tNc79ly1fms31s3i4yj31gc0qimy6.jpg)

(Notebook metadata: "Python 3" kernel, Python 3.6.2, nbformat 4.)
