import torch
import torchvision
from PIL import Image
import numpy as np


def voc_label_indices(colormap, colormap2label):
    """
    Map a colormap (PIL image) to a 2-D tensor of class indices, using
    colormap2label as the lookup table (a uint8 tensor of length 256 ** 3).
    """
    colormap = np.array(colormap.convert("RGB")).astype('int32')
    # Flatten each RGB triple into a single integer key for the lookup table.
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]
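# A minimal usage sketch (label_img is a hypothetical annotation image; the
# full lookup table is built in VOC2012SegDataIter below):
#
#   colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
#   colormap2label[(128 * 256 + 0) * 256 + 0] = 1  # 'aeroplane' -> class 1
#   indices = voc_label_indices(label_img, colormap2label)  # (H, W) tensor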
def read_voc_images(root="./dataset/VOCdevkit/VOC2012",
                    is_train=True, max_num=None):
    # The train/val split is given as plain-text lists of image stems.
    txt_fname = '%s/ImageSets/Segmentation/%s' % (
        root, 'train.txt' if is_train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    if max_num is not None:
        images = images[:min(max_num, len(images))]
    features, labels = [None] * len(images), [None] * len(images)
    for i, fname in enumerate(images):
        features[i] = Image.open(
            '%s/JPEGImages/%s.jpg' % (root, fname)).convert("RGB")
        labels[i] = Image.open(
            '%s/SegmentationClass/%s.png' % (root, fname)).convert("RGB")
    return features, labels  # lists of PIL images
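# Usage sketch, assuming the VOC2012 archive has been extracted under the
# default root (any other location must be passed explicitly):
#
#   features, labels = read_voc_images(is_train=True, max_num=100)
#   features[0].size  # (width, height) of the first training image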
def voc_rand_crop(feature, label, height, width):
    """
    Randomly crop feature and label (both PIL images) with the same window,
    so that pixels and their class annotations stay aligned.
    """
    i, j, h, w = torchvision.transforms.RandomCrop.get_params(
        feature, output_size=(height, width))

    feature = torchvision.transforms.functional.crop(feature, i, j, h, w)
    label = torchvision.transforms.functional.crop(label, i, j, h, w)

    return feature, label
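# Example: crop a feature/label pair to 320x480 with one shared random
# window (features and labels as returned by read_voc_images):
#
#   feature, label = voc_rand_crop(features[0], labels[0], 320, 480)
#   assert feature.size == label.size == (480, 320)  # PIL size is (w, h)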
class VOCSegDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir, colormap2label, max_num=None):
        """
        crop_size: (h, w)
        """
        # ImageNet channel statistics, the usual normalization for
        # ImageNet-pretrained backbones.
        self.rgb_mean = np.array([0.485, 0.456, 0.406])
        self.rgb_std = np.array([0.229, 0.224, 0.225])
        self.tsf = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=self.rgb_mean,
                                             std=self.rgb_std)
        ])

        self.crop_size = crop_size  # (h, w)
        features, labels = read_voc_images(root=voc_dir,
                                           is_train=is_train,
                                           max_num=max_num)
        self.features = self.filter(features)  # PIL images
        self.labels = self.filter(labels)  # PIL images
        self.colormap2label = colormap2label
        print('read ' + str(len(self.features)) + ' valid examples')

    def filter(self, imgs):
        # Drop images smaller than the crop size; PIL's size is (width, height).
        return [img for img in imgs if (
            img.size[1] >= self.crop_size[0] and
            img.size[0] >= self.crop_size[1])]

    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx],
                                       *self.crop_size)

        # Normalize the image; map the label colormap to class indices.
        return (self.tsf(feature),
                voc_label_indices(label, self.colormap2label))

    def __len__(self):
        return len(self.features)
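# Sketch of standalone use (colormap2label built as in VOC2012SegDataIter
# below; the path is wherever the VOC2012 root lives on disk):
#
#   voc_train = VOCSegDataset(True, (320, 480),
#                             "./dataset/VOCdevkit/VOC2012", colormap2label)
#   feature, label = voc_train[0]  # (3, 320, 480) float32, (320, 480) uint8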
def VOC2012SegDataIter(batch_size=64, crop_size=(320, 480), num_workers=4, max_num=None):
    # RGB value of each class in the VOC2012 annotations; VOC_COLORMAP[i]
    # is the color of class VOC_CLASSES[i].
    VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                    [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                    [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                    [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                    [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                    [0, 64, 128]]
    VOC_CLASSES = ['background', 'aeroplane', 'bicycle', 'bird', 'boat',
                   'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
                   'diningtable', 'dog', 'horse', 'motorbike', 'person',
                   'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

    # Build the flat RGB -> class-index lookup table used by voc_label_indices.
    colormap2label = torch.zeros(256 ** 3, dtype=torch.uint8)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[(colormap[0] * 256 + colormap[1]) * 256 + colormap[2]] = i

    voc_train = VOCSegDataset(True, crop_size, "../dataset/VOCdevkit/VOC2012",
                              colormap2label, max_num)
    voc_val = VOCSegDataset(False, crop_size, "../dataset/VOCdevkit/VOC2012",
                            colormap2label, max_num)
    # drop_last=True so every batch has exactly batch_size examples.
    train_iter = torch.utils.data.DataLoader(voc_train, batch_size, shuffle=True,
                                             drop_last=True, num_workers=num_workers)
    val_iter = torch.utils.data.DataLoader(voc_val, batch_size, drop_last=True,
                                           num_workers=num_workers)
    return train_iter, val_iter
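
if __name__ == "__main__":
    # Smoke test: print the shapes of one training batch. The batch size and
    # max_num are arbitrary small values for a quick check, and the VOC2012
    # data is assumed to sit at the relative path used above.
    train_iter, val_iter = VOC2012SegDataIter(batch_size=4, max_num=20)
    for X, y in train_iter:
        print(X.shape)  # expected: torch.Size([4, 3, 320, 480])
        print(y.shape)  # expected: torch.Size([4, 320, 480])
        break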