Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f134f59

Browse files
author
ZijunDeng
committed
Merge remote-tracking branch 'origin/master'
2 parents 7bf7113 + 890d0e5 commit f134f59

File tree

9 files changed

+479
-39
lines changed

9 files changed

+479
-39
lines changed

‎README.md‎

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@ This repository contains some models for semantic segmentation and the pipeline
33
implemented in PyTorch.
44

55
## Models
6-
1. Vanilla FCN: FCN32, FCN16, FCN8, in the versions of VGG, ResNet and DenseNet respectively.
7-
2. U-Net
8-
3. SegNet
9-
4. PSPNet
10-
5. GCN (global convolutional network)
6+
1. Vanilla FCN: FCN32, FCN16, FCN8, in the versions of VGG, ResNet and DenseNet respectively.
7+
([Fully convolutional networks for semantic segmentation](http://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Long_Fully_Convolutional_Networks_2015_CVPR_paper.pdf))
8+
2. U-Net ([U-net: Convolutional networks for biomedical image segmentation](https://arxiv.org/pdf/1505.04597))
9+
3. SegNet ([Segnet: A deep convolutional encoder-decoder architecture for image segmentation](https://arxiv.org/pdf/1511.00561))
10+
4. PSPNet ([Pyramid scene parsing network](https://arxiv.org/pdf/1612.01105))
11+
5. GCN ([Large Kernel Matters](https://arxiv.org/pdf/1703.02719))
12+
6. DUC, HDC ([Understanding Convolution for Semantic Segmentation](https://arxiv.org/pdf/1702.08502.pdf))
13+
7. Deformable Convolution Network (in PSPNet version) ([Deformable Convolutional Networks](https://arxiv.org/pdf/1703.06211))
1114

1215
## Visualization
1316
Visualization uses TensorBoard for PyTorch. See [here](https://github.com/lanpa/tensorboard-pytorch) for installation instructions.
@@ -25,5 +28,4 @@ I have borrowed some code from these nice repositories: [[1]](https://github.com
2528
1. DeepLab v3
2629
2. RefineNet
2730
3. CRFAsRNN
28-
4. Some evaluation criterion (e.g. mIOU)
29-
5. More dataset
31+
4. More datasets (e.g. ADE)

‎models/__init__.py‎

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from fcn16 import *
2-
from fcn32 import *
3-
from fcn8 import *
4-
from psp_net import PSPNet
5-
from seg_net import *
6-
from u_net import *
1+
from fcn16 import FCN16ResNet, FCN16DenseNet, FCN16VGG
2+
from fcn32 import FCN32ResNet, FCN32DenseNet, FCN32VGG
3+
from fcn8 import FCN8ResNet, FCN8DenseNet, FCN8VGG
4+
from psp_net import PSPNet, PSPNetDeform
5+
from seg_net import SegNet
6+
from u_net import UNet
7+
from gcn import GCN
8+
from duc_hdc import ResNetDUC, ResNetDUCHDC

‎models/duc_hdc.py‎

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import torch
2+
from torch import nn
3+
from torchvision import models
4+
5+
from .config import res152_path
6+
7+
8+
class DenseUpsamplingConvModule(nn.Module):
    """Dense upsampling convolution (DUC) head.

    A 3x3 convolution expands the input to ``down_factor ** 2 * num_classes``
    channels, followed by batch normalization and ReLU. ``nn.PixelShuffle``
    then rearranges those channels into a ``num_classes``-channel map whose
    height and width are ``down_factor`` times those of the input.

    Args:
        down_factor: upscale factor handed to ``nn.PixelShuffle``.
        in_dim: number of channels of the incoming feature map.
        num_classes: number of channels of the upsampled output.
    """

    def __init__(self, down_factor, in_dim, num_classes):
        super(DenseUpsamplingConvModule, self).__init__()
        # PixelShuffle consumes down_factor ** 2 channels per output channel.
        shuffled_channels = num_classes * down_factor ** 2
        self.conv = nn.Conv2d(in_dim, shuffled_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(shuffled_channels)
        self.relu = nn.ReLU()
        self.pixel_shuffle = nn.PixelShuffle(down_factor)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU -> pixel shuffle to ``x``."""
        activated = self.relu(self.bn(self.conv(x)))
        return self.pixel_shuffle(activated)
23+
24+
25+
class ResNetDUC(nn.Module):
    """ResNet-152 backbone with dilated layer3/layer4 and a DUC head.

    NOTE: the input image height and width should be multiples of 8.

    Args:
        num_classes: number of output channels produced by the DUC head.
        pretrained: when true, load backbone weights from ``res152_path``.
    """

    def __init__(self, num_classes, pretrained=True):
        super(ResNetDUC, self).__init__()
        backbone = models.resnet152()
        if pretrained:
            backbone.load_state_dict(torch.load(res152_path))
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Remove the striding of layer3/layer4 and dilate their 3x3 convs
        # instead (rate 2 for layer3, rate 4 for layer4); padding equals the
        # dilation rate so the spatial size is preserved.
        for stage, rate in ((self.layer3, 2), (self.layer4, 4)):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.dilation = (rate, rate)
                    module.padding = (rate, rate)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    module.stride = (1, 1)

        self.duc = DenseUpsamplingConvModule(8, 2048, num_classes)

    def forward(self, x):
        """Run the backbone stages in order, then the DUC head."""
        stages = (self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.duc)
        for stage in stages:
            x = stage(x)
        return x
63+
64+
65+
class ResNetDUCHDC(nn.Module):
    """ResNet-152 backbone with hybrid dilated convolution (HDC) and a DUC head.

    Unlike a single dilation rate per stage, the 3x3 convs of layer3 cycle
    through the rates ``[1, 2, 5, 9]`` and those of layer4 use ``[5, 9, 17]``.

    NOTE: the input image height and width should be multiples of 8.

    Args:
        num_classes: number of output channels produced by the DUC head.
        pretrained: when true, load backbone weights from ``res152_path``.
    """

    def __init__(self, num_classes, pretrained=True):
        super(ResNetDUCHDC, self).__init__()
        backbone = models.resnet152()
        if pretrained:
            backbone.load_state_dict(torch.load(res152_path))
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Remove the striding of layer3 and layer4.
        for stage in (self.layer3, self.layer4):
            for name, module in stage.named_modules():
                if 'conv2' in name or 'downsample.0' in name:
                    module.stride = (1, 1)

        # layer3: dilation rates repeat in groups of four blocks.
        layer3_group_config = [1, 2, 5, 9]
        for idx, block in enumerate(self.layer3):
            rate = layer3_group_config[idx % 4]
            block.conv2.dilation = (rate, rate)
            block.conv2.padding = (rate, rate)

        # layer4: one dilation rate per block.
        layer4_group_config = [5, 9, 17]
        for idx, block in enumerate(self.layer4):
            rate = layer4_group_config[idx]
            block.conv2.dilation = (rate, rate)
            block.conv2.padding = (rate, rate)

        self.duc = DenseUpsamplingConvModule(8, 2048, num_classes)

    def forward(self, x):
        """Run the backbone stages in order, then the DUC head."""
        stages = (self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.duc)
        for stage in stages:
            x = stage(x)
        return x

‎models/gcn.py‎

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ def forward(self, x):
3333

3434

3535
class BoundaryRefineModule(nn.Module):
36-
def __init__(self, in_dim):
36+
def __init__(self, dim):
3737
super(BoundaryRefineModule, self).__init__()
3838
self.relu = nn.ReLU()
39-
self.conv1 = nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1)
40-
self.conv2 = nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1)
39+
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
40+
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
4141

4242
def forward(self, x):
4343
residual = self.conv1(x)
@@ -50,8 +50,8 @@ def forward(self, x):
5050
class GCN(nn.Module):
5151
def __init__(self, num_classes, input_size, pretrained=True):
5252
super(GCN, self).__init__()
53-
resnet = models.resnet152()
5453
self.input_size = input_size
54+
resnet = models.resnet152()
5555
if pretrained:
5656
resnet.load_state_dict(torch.load(res152_path))
5757
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)

‎models/layer.py‎

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn.functional as F
4+
from torch import nn
5+
6+
7+
class Conv2dDeformable(nn.Module):
    """Wrap a regular ``nn.Conv2d`` with a learned per-pixel resampling step.

    A 3x3 ``offset_filter`` predicts, for every input channel, a horizontal
    and a vertical offset at each pixel. The input is bilinearly resampled
    at ``base grid + offset`` (normalized coordinates in ``[-1, 1]``) and the
    wrapped ``regular_filter`` is then applied to the resampled feature map.

    Args:
        regular_filter: the ``nn.Conv2d`` applied after resampling.
        cuda: kept for backward compatibility only and ignored; the sampling
            grid is always created on the device of the input tensor.
    """

    def __init__(self, regular_filter, cuda=True):
        super(Conv2dDeformable, self).__init__()
        assert isinstance(regular_filter, nn.Conv2d)
        self.regular_filter = regular_filter
        self.offset_filter = nn.Conv2d(regular_filter.in_channels, 2 * regular_filter.in_channels, kernel_size=3,
                                       padding=1, bias=False)
        # Near-zero offsets: the layer starts out close to the regular conv.
        self.offset_filter.weight.data.normal_(0, 0.0005)
        self.input_shape = None
        # The base grid is a cache, not a learnable weight: it is stored as
        # plain tensors so it never shows up in parameters() or state_dict().
        self.grid_w = None
        self.grid_h = None
        self._grid_key = None
        # NOTE: the flag is deliberately not stored as ``self.cuda``; doing so
        # would shadow ``nn.Module.cuda()`` on this module.

    def forward(self, x):
        """Resample ``x`` with the predicted offsets, then apply the wrapped conv.

        Args:
            x: tensor of shape (b, c, h, w) with ``c == regular_filter.in_channels``.

        Returns:
            The output of ``regular_filter`` on the resampled input.
        """
        x_shape = x.size()  # (b, c, h, w)
        channels, height, width = int(x_shape[1]), int(x_shape[2]), int(x_shape[3])

        offset = self.offset_filter(x)  # (b, 2*c, h, w)
        offset_w, offset_h = torch.split(offset, self.regular_filter.in_channels, 1)  # (b, c, h, w)
        offset_w = offset_w.contiguous().view(-1, height, width)  # (b*c, h, w)
        offset_h = offset_h.contiguous().view(-1, height, width)  # (b*c, h, w)

        # Rebuild the cached base grid when the spatial size, device or dtype
        # of the input changes.
        grid_key = (height, width, x.device, x.dtype)
        if self._grid_key != grid_key:
            self._grid_key = grid_key
            self.input_shape = x_shape
            xs = torch.linspace(-1, 1, width, device=x.device, dtype=x.dtype)
            ys = torch.linspace(-1, 1, height, device=x.device, dtype=x.dtype)
            self.grid_w = xs.view(1, width).expand(height, width)  # (h, w), varies along width
            self.grid_h = ys.view(height, 1).expand(height, width)  # (h, w), varies along height

        sample_x = offset_w + self.grid_w  # (b*c, h, w)
        sample_y = offset_h + self.grid_h  # (b*c, h, w)

        x = x.contiguous().view(-1, height, width).unsqueeze(1)  # (b*c, 1, h, w)
        # grid_sample reads the last grid dimension as (x, y), i.e. width
        # first. The grid endpoints -1/1 denote the centers of the corner
        # pixels, which is the align_corners=True convention.
        grid = torch.stack((sample_x, sample_y), 3)  # (b*c, h, w, 2)
        x = F.grid_sample(x, grid, align_corners=True)  # (b*c, 1, h, w)
        x = x.contiguous().view(-1, channels, height, width)  # (b, c, h, w)
        return self.regular_filter(x)

‎models/psp_net.py‎

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from torch import nn
66
from torchvision import models
77

8+
from layer import Conv2dDeformable
89
from utils.training import initialize_weights
910
from .config import res152_path
1011

@@ -20,7 +21,7 @@ def __init__(self, in_size, in_dim, reduction_dim, setting):
2021
nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
2122
nn.BatchNorm2d(reduction_dim, momentum=.95),
2223
nn.ReLU(),
23-
nn.UpsamplingBilinear2d(size=in_size)
24+
nn.Upsample(size=in_size, mode='bilinear')
2425
))
2526
self.features = nn.ModuleList(self.features)
2627

@@ -93,5 +94,72 @@ def forward(self, x):
9394
x = self.ppm(x)
9495
x = self.final(x)
9596
if self.training and self.use_aux:
96-
return F.upsample_bilinear(x, self.input_size), F.upsample_bilinear(aux, self.input_size)
97-
return F.upsample_bilinear(x, self.input_size)
97+
return F.upsample(x, self.input_size, mode='bilinear'), F.upsample(aux, self.input_size, mode='bilinear')
98+
return F.upsample(x, self.input_size, mode='bilinear')
99+
100+
101+
class PSPNetDeform(nn.Module):
    """PSPNet on a ResNet-152 backbone with deformable 3x3 convs in layer3/layer4.

    Args:
        num_classes: number of output channels of the classifier(s).
        input_size: (height, width) the outputs are upsampled to.
        pretrained: when true, load backbone weights from ``res152_path``.
        use_aux: when true, build an auxiliary classifier on the layer3
            output; it is only evaluated in training mode.
    """

    def __init__(self, num_classes, input_size, pretrained=True, use_aux=True):
        super(PSPNetDeform, self).__init__()
        self.input_size = input_size
        self.use_aux = use_aux
        backbone = models.resnet152()
        if pretrained:
            backbone.load_state_dict(torch.load(res152_path))
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Remove the striding of layer3 and layer4.
        for stage in (self.layer3, self.layer4):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.padding = (1, 1)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    module.stride = (1, 1)

        # Wrap every 3x3 conv of layer3, then layer4, in a deformable conv.
        for stage in (self.layer3, self.layer4):
            for block in stage:
                block.conv2 = Conv2dDeformable(block.conv2)

        # Feature maps entering the pyramid pooling modules are 1/8 of the input.
        pooled_size = (int(math.ceil(input_size[0] / 8.0)), int(math.ceil(input_size[1] / 8.0)))
        self.ppm = PyramidPoolingModule(pooled_size, 2048, 512, (1, 2, 3, 6))
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )
        if use_aux:
            self.aux_logits = nn.Sequential(
                PyramidPoolingModule(pooled_size, 1024, 256, (1, 2, 3, 6)),
                nn.Conv2d(2048, 256, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(256, momentum=.95),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Conv2d(256, num_classes, kernel_size=1)
            )

        initialize_weights(self.ppm, self.final)

    def forward(self, x):
        """Return upsampled logits; in training mode with ``use_aux`` also the aux logits."""
        with_aux = self.training and self.use_aux
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if with_aux:
            # The auxiliary classifier branches off the layer3 output.
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        main_out = F.upsample(x, self.input_size, mode='bilinear')
        if with_aux:
            return main_out, F.upsample(aux, self.input_size, mode='bilinear')
        return main_out

‎train_gcn.py‎

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,13 @@
3030
train_args = {
3131
'batch_size': 16,
3232
'epoch_num': 800, # I stop training only when val loss doesn't seem to decrease anymore, so just set a large value.
33-
'pretrained_lr': 1e-4, # used for the pretrained layers of model
34-
'new_lr': 1e-3, # used for the newly added layers of model
33+
'pretrained_lr': 1e-7, # used for the pretrained layers of model
34+
'new_lr': 1e-7, # used for the newly added layers of model
3535
'weight_decay': 5e-4,
36-
'snapshot': '', # empty string denotes initial training, otherwise it should be a string of snapshot name
36+
'snapshot': 'epoch_297_loss_0.8282_mean_iu_0.4390_lr_0.00000100.pth', # empty string denotes initial training, otherwise it should be a string of snapshot name
3737
'print_freq': 30,
3838
'input_size': (224, 448), # (height, width)
3939
}
40-
4140
val_args = {
4241
'batch_size': 8,
4342
'img_sample_rate': 0.15
@@ -97,23 +96,23 @@ def main():
9796
optimizer = optim.SGD([
9897
{'params': [param for name, param in net.named_parameters() if
9998
name[-4:] == 'bias' and ('gcm' in name or 'brm' in name)],
100-
'lr': train_args['new_lr']},
99+
'lr': 2*train_args['new_lr']},
101100
{'params': [param for name, param in net.named_parameters() if
102101
name[-4:] != 'bias' and ('gcm' in name or 'brm' in name)],
103102
'lr': train_args['new_lr'], 'weight_decay': train_args['weight_decay']},
104103
{'params': [param for name, param in net.named_parameters() if
105104
name[-4:] == 'bias' and not ('gcm' in name or 'brm' in name)],
106-
'lr': train_args['pretrained_lr']},
105+
'lr': 2*train_args['pretrained_lr']},
107106
{'params': [param for name, param in net.named_parameters() if
108107
name[-4:] != 'bias' and not ('gcm' in name or 'brm' in name)],
109108
'lr': train_args['pretrained_lr'], 'weight_decay': train_args['weight_decay']}
110109
], momentum=0.9, nesterov=True)
111110

112111
if len(train_args['snapshot']) > 0:
113112
optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
114-
optimizer.param_groups[0]['lr'] = train_args['new_lr']
113+
optimizer.param_groups[0]['lr'] = 2*train_args['new_lr']
115114
optimizer.param_groups[1]['lr'] = train_args['new_lr']
116-
optimizer.param_groups[2]['lr'] = train_args['pretrained_lr']
115+
optimizer.param_groups[2]['lr'] = 2*train_args['pretrained_lr']
117116
optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']
118117

119118
if not os.path.exists(ckpt_path):

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /