Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 3a9d4f4

Browse files
author
xyliao
committed
finish fine tune trainer
1 parent 3534e1c commit 3a9d4f4

File tree

2 files changed

+210
-0
lines changed

2 files changed

+210
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# encoding: utf-8
2+
"""
3+
@author: xyliao
4+
@contact: xyliao1993@qq.com
5+
"""
6+
import warnings
7+
from pprint import pprint
8+
9+
10+
class DefaultConfig(object):
11+
model = 'resnet50'
12+
# Dataset.
13+
train_data_path = './hymenoptera_data/train/'
14+
test_data_path = './hymenoptera_data/val/'
15+
16+
# Store result and save models.
17+
# result_file = 'result.txt'
18+
save_file = './checkpoints/'
19+
save_freq = 30 # save model every N epochs
20+
save_best = True # If save best test metric model.
21+
22+
# Visualization results on tensorboard.
23+
# vis_dir = './vis/'
24+
plot_freq = 100 # plot in tensorboard every N iterations
25+
26+
# Model hyperparameters.
27+
use_gpu = True # use GPU or not
28+
ctx = 0 # running on which cuda device
29+
batch_size = 64 # batch size
30+
num_workers = 4 # how many workers for loading data
31+
max_epoch = 30
32+
lr = 1e-2 # initial learning rate
33+
momentum = 0
34+
weight_decay = 1e-4
35+
lr_decay = 0.95
36+
# lr_decay_freq = 10
37+
38+
def _parse(self, kwargs):
39+
for k, v in kwargs.items():
40+
if not hasattr(self, k):
41+
warnings.warn("Warning: opt has not attribut %s" % k)
42+
setattr(self, k, v)
43+
44+
print('=========user config==========')
45+
pprint(self._state_dict())
46+
print('============end===============')
47+
48+
def _state_dict(self):
49+
return {k: getattr(self, k) for k, _ in DefaultConfig.__dict__.items()
50+
if not k.startswith('_')}
51+
52+
53+
opt = DefaultConfig()
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# encoding: utf-8
2+
"""
3+
@author: xyliao
4+
@contact: xyliao1993@qq.com
5+
"""
6+
import copy
7+
8+
import torch
9+
from config import opt
10+
from mxtorch import meter
11+
from mxtorch import transforms as tfs
12+
from mxtorch.trainer import *
13+
from mxtorch.vision import model_zoo
14+
from torch import nn
15+
from torch.autograd import Variable
16+
from torch.utils.data import DataLoader
17+
from torchvision.datasets import ImageFolder
18+
from tqdm import tqdm
19+
20+
train_tf = tfs.Compose([
21+
tfs.RandomResizedCrop(224),
22+
tfs.RandomHorizontalFlip(),
23+
tfs.ToTensor(),
24+
tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
25+
])
26+
27+
28+
def test_tf(img):
29+
img = tfs.Resize(256)(img)
30+
img, _ = tfs.CenterCrop(224)(img)
31+
normalize = tfs.Compose([
32+
tfs.ToTensor(),
33+
tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
34+
])
35+
img = normalize(img)
36+
return img
37+
38+
39+
def get_train_data():
40+
train_set = ImageFolder(opt.train_data_path, train_tf)
41+
return DataLoader(train_set, opt.batch_size, True, num_workers=opt.num_workers)
42+
43+
44+
def get_test_data():
45+
test_set = ImageFolder(opt.test_data_path, test_tf)
46+
return DataLoader(test_set, opt.batch_size, True, num_workers=opt.num_workers)
47+
48+
49+
def get_model():
50+
model = model_zoo.resnet50(pretrained=True)
51+
model.fc = nn.Linear(2048, 2)
52+
if opt.use_gpu:
53+
model = model.cuda(opt.ctx)
54+
return model
55+
56+
57+
def get_loss(score, label):
58+
return nn.CrossEntropyLoss()(score, label)
59+
60+
61+
def get_optimizer(model):
62+
optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum,
63+
weight_decay=opt.weight_decay)
64+
return ScheduledOptim(optimizer)
65+
66+
67+
class FineTuneTrainer(Trainer):
68+
def __init__(self):
69+
model = get_model()
70+
criterion = get_loss
71+
optimizer = get_optimizer(model)
72+
super().__init__(model, criterion, optimizer)
73+
74+
self.metric_meter['loss'] = meter.AverageValueMeter()
75+
self.metric_meter['acc'] = meter.AverageValueMeter()
76+
77+
def train(self, train_data):
78+
self.model.train()
79+
for data in tqdm(train_data):
80+
img, label = data
81+
if opt.use_gpu:
82+
img = img.cuda(opt.ctx)
83+
label = label.cuda(opt.ctx)
84+
img = Variable(img)
85+
label = Variable(label)
86+
87+
# Forward.
88+
score = self.model(img)
89+
loss = self.criterion(score, label)
90+
91+
# Backward.
92+
self.optimizer.zero_grad()
93+
loss.backward()
94+
self.optimizer.step()
95+
96+
# Update meters.
97+
acc = (score.max(1)[1] == label).float().mean()
98+
self.metric_meter['loss'].add(loss.data[0])
99+
self.metric_meter['acc'].add(acc.data[0])
100+
101+
# Update to tensorboard.
102+
# if (self.n_iter + 1) % opt.plot_freq == 0:
103+
# self.writer.add_scalars('loss', {'train': self.metric_meter['loss'].value()[0]}, self.n_plot)
104+
# self.writer.add_scalars('acc', {'train': self.metric_meter['acc'].value()[0], self.n_plot})
105+
# self.n_plot += 1
106+
self.n_iter += 1
107+
108+
# Log the train metric dict to print result.
109+
self.metric_log['train loss'] = self.metric_meter['loss'].value()[0]
110+
self.metric_log['train acc'] = self.metric_meter['acc'].value()[0]
111+
112+
def test(self, test_data):
113+
self.model.eval()
114+
for data in tqdm(test_data):
115+
img, label = data
116+
if opt.use_gpu:
117+
img = img.cuda(opt.ctx)
118+
label = label.cuda(opt.ctx)
119+
img = Variable(img, volatile=True)
120+
label = Variable(label, volatile=True)
121+
122+
score = self.model(img)
123+
loss = self.criterion(score, label)
124+
acc = (score.max(1)[1] == label).float().mean()
125+
126+
self.metric_meter['loss'].add(loss.data[0])
127+
self.metric_meter['acc'].add(acc.data[0])
128+
129+
# Update to tensorboard.
130+
# self.writer.add_scalars('loss', {'test': self.metric_meter['loss'].value()[0]}, self.n_plot)
131+
# self.writer.add_scalars('acc', {'test': self.metric_meter['acc'].value()[0]}, self.n_plot)
132+
# self.n_plot += 1
133+
134+
# Log the test metric to dict.
135+
self.metric_log['test loss'] = self.metric_meter['loss'].value()[0]
136+
self.metric_log['test acc'] = self.metric_meter['acc'].value()[0]
137+
138+
def get_best_model(self):
139+
if self.metric_log['test loss'] < self.best_metric:
140+
self.best_model = copy.deepcopy(self.model.state_dict())
141+
self.best_metric = self.metric_log['test loss']
142+
143+
144+
def train(**kwargs):
145+
opt._parse(kwargs)
146+
147+
train_data = get_train_data()
148+
test_data = get_test_data()
149+
150+
fine_tune_trainer = FineTuneTrainer()
151+
fine_tune_trainer.fit(train_data, test_data)
152+
153+
154+
if __name__ == '__main__':
155+
import fire
156+
157+
fire.Fire()

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /