Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 689a58a

Browse files
author
xyliao
committed
update char rnn
1 parent 3575751 commit 689a58a

File tree

3 files changed

+55
-21
lines changed

3 files changed

+55
-21
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Subproject commit b679ba82feb3c50bfa763f2035a7c5cdfbe72952

‎chapter9_Computer-Vision/fine_tune/main.py‎

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,14 @@ def test_tf(img):
3838

3939
def get_train_data():
4040
train_set = ImageFolder(opt.train_data_path, train_tf)
41-
return DataLoader(train_set, opt.batch_size, True, num_workers=opt.num_workers)
41+
return DataLoader(
42+
train_set, opt.batch_size, True, num_workers=opt.num_workers)
4243

4344

4445
def get_test_data():
4546
test_set = ImageFolder(opt.test_data_path, test_tf)
46-
return DataLoader(test_set, opt.batch_size, True, num_workers=opt.num_workers)
47+
return DataLoader(
48+
test_set, opt.batch_size, True, num_workers=opt.num_workers)
4749

4850

4951
def get_model():
@@ -59,8 +61,11 @@ def get_loss(score, label):
5961

6062

6163
def get_optimizer(model):
62-
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum,
63-
weight_decay=opt.weight_decay)
64+
optimizer = torch.optim.SGD(
65+
model.parameters(),
66+
lr=opt.lr,
67+
momentum=opt.momentum,
68+
weight_decay=opt.weight_decay)
6469
return ScheduledOptim(optimizer)
6570

6671

@@ -75,6 +80,7 @@ def __init__(self):
7580
self.metric_meter['acc'] = meter.AverageValueMeter()
7681

7782
def train(self, kwargs):
83+
self.reset_meter()
7884
self.model.train()
7985
train_data = kwargs['train_data']
8086
for data in tqdm(train_data):
@@ -101,8 +107,12 @@ def train(self, kwargs):
101107

102108
# Update to tensorboard.
103109
if (self.n_iter + 1) % opt.plot_freq == 0:
104-
self.writer.add_scalars('loss', {'train': self.metric_meter['loss'].value()[0]}, self.n_plot)
105-
self.writer.add_scalars('acc', {'train': self.metric_meter['acc'].value()[0]}, self.n_plot)
110+
self.writer.add_scalars(
111+
'loss', {'train': self.metric_meter['loss'].value()[0]},
112+
self.n_plot)
113+
self.writer.add_scalars(
114+
'acc', {'train': self.metric_meter['acc'].value()[0]},
115+
self.n_plot)
106116
self.n_plot += 1
107117
self.n_iter += 1
108118

@@ -111,6 +121,7 @@ def train(self, kwargs):
111121
self.metric_log['train acc'] = self.metric_meter['acc'].value()[0]
112122

113123
def test(self, kwargs):
124+
self.reset_meter()
114125
self.model.eval()
115126
test_data = kwargs['test_data']
116127
for data in tqdm(test_data):
@@ -129,8 +140,11 @@ def test(self, kwargs):
129140
self.metric_meter['acc'].add(acc.data[0])
130141

131142
# Update to tensorboard.
132-
self.writer.add_scalars('loss', {'test': self.metric_meter['loss'].value()[0]}, self.n_plot)
133-
self.writer.add_scalars('acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
143+
self.writer.add_scalars('loss',
144+
{'test': self.metric_meter['loss'].value()[0]},
145+
self.n_plot)
146+
self.writer.add_scalars(
147+
'acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
134148
self.n_plot += 1
135149

136150
# Log the test metric to dict.

‎chapter9_Computer-Vision/segmentation/main.py‎

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626

2727

2828
def get_data(is_train):
29-
voc_data = VocSegDataset(opt.voc_root, is_train, opt.crop_size, img_transforms)
30-
return DataLoader(voc_data, opt.batch_size, True, num_workers=opt.num_workers)
29+
voc_data = VocSegDataset(opt.voc_root, is_train, opt.crop_size,
30+
img_transforms)
31+
return DataLoader(
32+
voc_data, opt.batch_size, True, num_workers=opt.num_workers)
3133

3234

3335
def get_model(num_classes):
@@ -38,7 +40,8 @@ def get_model(num_classes):
3840

3941

4042
def get_optimizer(model):
41-
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
43+
optimizer = torch.optim.SGD(
44+
model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
4245
return ScheduledOptim(optimizer)
4346

4447

@@ -64,6 +67,7 @@ def __init__(self):
6467
self.metric_meter[m] = meter.AverageValueMeter()
6568

6669
def train(self, kwargs):
70+
self.reset_meter()
6771
self.model.train()
6872
train_data = kwargs['train_data']
6973
for data in tqdm(train_data):
@@ -97,28 +101,37 @@ def train(self, kwargs):
97101

98102
if (self.n_iter + 1) % opt.plot_freq == 0:
99103
# Plot metrics curve in tensorboard.
100-
self.writer.add_scalars('loss', {'train': self.metric_meter['loss'].value()[0]}, self.n_plot)
101-
self.writer.add_scalars('acc', {'train': self.metric_meter['acc'].value()[0]}, self.n_plot)
102-
self.writer.add_scalars('iou', {'train': self.metric_meter['iou'].value()[0]}, self.n_plot)
104+
self.writer.add_scalars(
105+
'loss', {'train': self.metric_meter['loss'].value()[0]},
106+
self.n_plot)
107+
self.writer.add_scalars(
108+
'acc', {'train': self.metric_meter['acc'].value()[0]},
109+
self.n_plot)
110+
self.writer.add_scalars(
111+
'iou', {'train': self.metric_meter['iou'].value()[0]},
112+
self.n_plot)
103113

104114
# Show segmentation images.
105115
# Get prediction segmentation and ground truth segmentation.
106116
origin_image = inverse_normalization(imgs[0].cpu().data)
107117
pred_seg = cm[pred_labels[0]]
108118
gt_seg = cm[true_labels[0]]
109119

110-
self.writer.add_image('train ori_img', origin_image, self.n_plot)
120+
self.writer.add_image('train ori_img', origin_image,
121+
self.n_plot)
111122
self.writer.add_image('train gt', gt_seg, self.n_plot)
112123
self.writer.add_image('train pred', pred_seg, self.n_plot)
113124
self.n_plot += 1
114125

115126
self.n_iter += 1
116127

117128
self.metric_log['Train Loss'] = self.metric_meter['loss'].value()[0]
118-
self.metric_log['Train Mean Class Accuracy'] = self.metric_meter['acc'].value()[0]
129+
self.metric_log['Train Mean Class Accuracy'] = self.metric_meter[
130+
'acc'].value()[0]
119131
self.metric_log['Train Mean IoU'] = self.metric_meter['iou'].value()[0]
120132

121133
def test(self, kwargs):
134+
self.reset_meter()
122135
self.model.eval()
123136
test_data = kwargs['test_data']
124137
for data in tqdm(test_data):
@@ -146,9 +159,13 @@ def test(self, kwargs):
146159
self.metric_meter['iou'].add(eval_metrics['miou'])
147160

148161
# Plot metrics curve in tensorboard.
149-
self.writer.add_scalars('loss', {'test': self.metric_meter['loss'].value()[0]}, self.n_plot)
150-
self.writer.add_scalars('acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
151-
self.writer.add_scalars('iou', {'test': self.metric_meter['iou'].value()[0]}, self.n_plot)
162+
self.writer.add_scalars('loss',
163+
{'test': self.metric_meter['loss'].value()[0]},
164+
self.n_plot)
165+
self.writer.add_scalars(
166+
'acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
167+
self.writer.add_scalars(
168+
'iou', {'test': self.metric_meter['iou'].value()[0]}, self.n_plot)
152169

153170
origin_img = inverse_normalization(imgs[0].cpu().data)
154171
pred_seg = cm[pred_labels[0]]
@@ -160,7 +177,8 @@ def test(self, kwargs):
160177
self.n_plot += 1
161178

162179
self.metric_log['Test Loss'] = self.metric_meter['loss'].value()[0]
163-
self.metric_log['Test Mean Class Accuracy'] = self.metric_meter['acc'].value()[0]
180+
self.metric_log['Test Mean Class Accuracy'] = self.metric_meter[
181+
'acc'].value()[0]
164182
self.metric_log['Test Mean IoU'] = self.metric_meter['iou'].value()[0]
165183

166184
def get_best_model(self):
@@ -178,7 +196,8 @@ def train(**kwargs):
178196
fcn_trainer = FcnTrainer()
179197
train_data = get_data(is_train=True)
180198
test_data = get_data(is_train=False)
181-
fcn_trainer.fit(train_data=train_data, test_data=test_data, epochs=opt.max_epoch)
199+
fcn_trainer.fit(
200+
train_data=train_data, test_data=test_data, epochs=opt.max_epoch)
182201

183202

184203
if __name__ == '__main__':

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /