
Commit 9fbf3da

add FCN
1 parent 3cd59a5 commit 9fbf3da

File tree

5 files changed: +270 −1 lines changed

FCN/TinyFCN.py

Lines changed: 74 additions & 0 deletions

from tqdm import tqdm

from FCN.VOC2012Dataset import VOC2012SegDataIter
import torch
from torch import nn, optim
import numpy as np
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 21


def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return a (in_channels, out_channels, k, k) weight tensor that makes a
    ConvTranspose2d perform bilinear upsampling."""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    # place the 2-D bilinear filter on the channel diagonal only, so each
    # output channel is upsampled from the matching input channel
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)


if __name__ == '__main__':
    train_iter, val_iter = VOC2012SegDataIter(16, (320, 480), 2, 200)
    resnet18 = models.resnet18(pretrained=True)
    # children() yields only the top-level blocks; modules() (used in the
    # original) also recurses into sub-layers and would duplicate them
    resnet18_modules = [layer for layer in resnet18.children()]
    net = nn.Sequential()
    # drop the final global average pool and fully connected layer
    for i, layer in enumerate(resnet18_modules[:-2]):
        net.add_module(str(i), layer)

    # 1x1 conv projects 512 feature channels onto the 21 VOC classes; the
    # transposed conv (kernel 64, stride 32, padding 16) undoes the trunk's
    # 32x downsampling: (H/32 - 1)*32 - 2*16 + 64 = H
    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)
    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    lossFN = nn.CrossEntropyLoss()

    num_epochs = 10
    for epoch in range(num_epochs):
        sum_loss = 0
        sum_acc = 0
        batch_count = 0
        n = 0
        for X, y in tqdm(train_iter):
            X = X.to(device)
            y = y.to(device).long()  # CrossEntropyLoss expects int64 targets

            y_pred = net(X)  # (N, 21, H, W) logits; y is (N, H, W) class indices
            loss = lossFN(y_pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.cpu().item()
            sum_acc += (y_pred.argmax(dim=1) == y).sum().cpu().item()
            n += y.numel()  # count pixels so acc below is pixel accuracy
            batch_count += 1
        print("epoch %d: loss=%.4f \t acc=%.4f" % (epoch + 1, sum_loss / batch_count, sum_acc / n))
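As a quick sanity check on the bilinear initialization (my sketch, not part of the commit): a ConvTranspose2d whose weight comes from bilinear_kernel should reproduce bilinear interpolation. With kernel_size=4, padding=1, stride=2 the output size is (H − 1)·2 − 2·1 + 4 = 2H, the same arithmetic that lets the 64/16/32 head above undo a 32x downsampling. Interior pixels should match F.interpolate exactly; the one-pixel border differs because the transposed conv zero-pads while interpolate clamps at the image edge. The import assumes the repo root is on sys.path.

import torch
from torch import nn
from torch.nn import functional as F
from FCN.TinyFCN import bilinear_kernel  # assumption: repo root on sys.path

x = torch.rand(1, 3, 8, 8)
up = nn.ConvTranspose2d(3, 3, kernel_size=4, padding=1, stride=2, bias=False)
up.weight.data.copy_(bilinear_kernel(3, 3, 4))
with torch.no_grad():
    y_conv = up(x)
y_ref = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
print(y_conv.shape)  # torch.Size([1, 3, 16, 16])
# interior pixels agree; only the one-pixel border differs
print(torch.allclose(y_conv[..., 1:-1, 1:-1], y_ref[..., 1:-1, 1:-1], atol=1e-6))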

FCN/VOC2012Dataset.py

Lines changed: 104 additions & 0 deletions

import torch
import torchvision
from PIL import Image
import numpy as np


def voc_label_indices(colormap, colormap2label):
    """
    Map a colormap (PIL image) to a class-index map using the
    colormap2label lookup table (uint8 tensor).
    """
    colormap = np.array(colormap.convert("RGB")).astype('int32')
    # pack each RGB triple into a single base-256 integer
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]


def read_voc_images(root="./dataset/VOCdevkit/VOC2012",
                    is_train=True, max_num=None):
    txt_fname = '%s/ImageSets/Segmentation/%s' % (
        root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    if max_num is not None:
        images = images[:min(max_num, len(images))]
    features, labels = [None] * len(images), [None] * len(images)
    for i, fname in enumerate(images):
        features[i] = Image.open('%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB")
        labels[i] = Image.open('%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB")
    return features, labels  # PIL images


def voc_rand_crop(feature, label, height, width):
    """
    Random-crop feature (PIL image) and label (PIL image) with the same
    window, so image and mask stay aligned.
    """
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        feature, output_size=(height, width))

    feature = torchvision.transforms.functional.crop(feature, i, j, h, w)
    label = torchvision.transforms.functional.crop(label, i, j, h, w)

    return feature, label


class VOCSegDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):
        """
        crop_size: (h, w)
        """
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])
        self.tsf = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.rgb_mean,
                                             std=self.rgb_std)
        ])

        self.crop_size = crop_size  # (h, w)
        features, labels = read_voc_images(root=voc_dir,
                                           is_train=is_train,
                                           max_num=max_num)
        self.features = self.filter(features)  # PIL images
        self.labels = self.filter(labels)      # PIL images
        self.colormap2label = colormap2label
        print('read ' + str(len(self.features)) + ' valid examples')

    def filter(self, imgs):
        # keep only images at least as large as the crop; PIL size is (w, h)
        return [img for img in imgs if (
            img.size[1] >= self.crop_size[0] and
            img.size[0] >= self.crop_size[1])]

    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)

        return (self.tsf(feature),
                voc_label_indices(label, self.colormap2label))

    def __len__(self):
        return len(self.features)


def VOC2012SegDataIter(batch_size=64, crop_size=(320, 480), num_workers=4, max_num=None):
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                    [0, 64, 128]]
    VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

    # lookup table: packed RGB value -> class index
    colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i

    voc_train = VOCSegDataset(True, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    voc_val = VOCSegDataset(False, crop_size, "../dataset/VOCdevkit/VOC2012", colormap2label, max_num)
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True, drop_last=True,
                                             num_workers=num_workers)
    val_iter = torch.utils.data.DataLoader(voc_val, batch_size, drop_last=True, num_workers=num_workers)
    return train_iter, val_iter
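To make the lookup in voc_label_indices concrete, here is a small illustration (my sketch, not part of the commit): each RGB triple is packed into one base-256 integer, and colormap2label maps that integer to a class index, so a whole mask is converted with a single fancy-indexing operation.

import numpy as np
import torch

# lookup table for the first three VOC classes only
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0]]  # background, aeroplane, bicycle
colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
for i, cm in enumerate(VOC_COLORMAP):
    colormap2label[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i

# a 1x3 "image": one pixel of each class
img = np.array([[[0, 0, 0], [128, 0, 0], [0, 128, 0]]], dtype='int32')
idx = (img[:, :, 0] * 256 + img[:, :, 1]) * 256 + img[:, :, 2]
print(colormap2label[idx])  # tensor([[0, 1, 2]], dtype=torch.uint8)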

FCN/__init__.py

Lines changed: 91 additions & 0 deletions

from tqdm import tqdm

from FCN.VOC2012Dataset import VOC2012SegDataIter
import torch
from torch import nn, optim
from torch.nn import functional as F
import numpy as np
from torchvision import models

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

num_classes = 21


def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Return a (in_channels, out_channels, k, k) weight tensor that makes a
    ConvTranspose2d perform bilinear upsampling (same as in TinyFCN.py)."""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * \
           (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype='float32')
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.tensor(weight)


def cross_entropy2d(input, target, weight=None, size_average=True):
    # input: (n, c, h, w), target: (n, h, w); pixels with target < 0 are ignored
    n, c, h, w = input.size()
    # log_p: (n, c, h, w)
    log_p = F.log_softmax(input, dim=1)
    # move the class axis last, then keep only pixels with a valid label
    log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous()
    log_p = log_p[target.view(n, h, w, 1).repeat(1, 1, 1, c) >= 0]
    log_p = log_p.view(-1, c)
    # target: (n*h*w,)
    mask = target >= 0
    target = target[mask]
    loss = F.nll_loss(log_p, target, weight=weight, reduction='sum')
    if size_average:
        loss /= mask.data.sum()
    return loss


if __name__ == '__main__':
    batch_size = 4
    train_iter, val_iter = VOC2012SegDataIter(batch_size, (320, 480), 2, 200)
    resnet18 = models.resnet18(pretrained=True)
    # children() yields only the top-level blocks of the network
    resnet18_modules = [layer for layer in resnet18.children()]
    net = nn.Sequential()
    # drop the final global average pool and fully connected layer
    for i, layer in enumerate(resnet18_modules[:-2]):
        net.add_module(str(i), layer)

    net.add_module("LinearTranspose", nn.Conv2d(512, num_classes, kernel_size=1))
    net.add_module("ConvTranspose2d",
                   nn.ConvTranspose2d(num_classes, num_classes, kernel_size=64, padding=16, stride=32))

    # fixed bilinear init for the upsampling layer, Xavier for the 1x1 projection
    net[-1].weight = nn.Parameter(bilinear_kernel(num_classes, num_classes, 64), True)
    net[-2].weight = nn.init.xavier_uniform_(net[-2].weight)

    net = net.to(device)
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    lossFN = nn.CrossEntropyLoss()

    num_epochs = 10
    for epoch in range(num_epochs):
        sum_loss = 0
        sum_acc = 0
        batch_count = 0
        n = 0
        for X, y in tqdm(train_iter):
            X = X.to(device)
            y = y.to(device).long()  # CrossEntropyLoss expects int64 targets
            # flatten the spatial dims: logits (N, 21, H*W) vs labels (N, H*W)
            y_pred = net(X).reshape((batch_size, 21, -1))
            y = y.reshape(batch_size, -1)
            loss = lossFN(y_pred, y)
            # loss = cross_entropy2d(y_pred, y)  # would need the un-reshaped 4-D logits

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sum_loss += loss.cpu().item()
            sum_acc += (y_pred.argmax(dim=1) == y).sum().cpu().item()
            n += y.numel()  # count pixels so acc below is pixel accuracy
            batch_count += 1
        print("epoch %d: loss=%.4f \t acc=%.4f" % (epoch + 1, sum_loss / batch_count, sum_acc / n))

SSD/ssd.pt

0 Bytes
Binary file not shown.

SSD/train.py

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,6 @@
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-
 def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
     # print(cls_preds.shape, cls_labels.shape, bbox_preds.shape, bbox_labels.shape, bbox_masks.shape)
     cls_preds = cls_preds.to(device)
@@ -58,6 +57,7 @@ def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
     # compute the loss from the class and offset predictions and their labels
     cls_preds = cls_preds.reshape(-1, 2)
     cls_labels = cls_labels.reshape(-1)
+    print(cls_preds.shape, cls_labels.shape)
     cls_loss, bbox_loss = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks)
     l = cls_loss + bbox_loss
     optimizer.zero_grad()
