Custom segmentation dataset class for torchvision. Applies data augmentation to both images and segmentations.
Can be used with torchvision.transforms:

    from torchvision import transforms

    from utils import SegmentationDataset

    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            # resample/fillcolor are the argument names used by older
            # torchvision releases; newer releases call them interpolation/fill.
            transforms.RandomAffine(
                degrees=15,
                translate=(0.05, 0.05),
                scale=(0.95, 1.05),
                resample=2,
                fillcolor=0,
            ),
            transforms.ColorJitter(
                brightness=0.15, contrast=0.15, saturation=0.15, hue=0.05
            ),
            transforms.RandomVerticalFlip(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )

    dataset = SegmentationDataset(
        dir_images="./my_dataset/images/",
        dir_masks="./my_dataset/masks/",
        transform=transform,
    )

- Normalize, Lambda, Pad, ColorJitter and RandomErasing won't be applied to masks by default
- Images from: https://www.ntu.edu.sg/home/asjfcai/Benchmark_Website/benchmark_index.html
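
To illustrate the idea behind the first note, here is a minimal, dependency-free sketch of paired augmentation. It is not the code of SegmentationDataset: the toy classes, the `IMAGE_ONLY` set, and the `apply_pair` helper are hypothetical stand-ins that operate on nested lists instead of images. The sketch assumes that random geometric transforms must use the same random draw for the image and its mask, while the intensity-changing transforms listed above are skipped for the mask.

```python
import random

# Transform names that are skipped for masks (taken from the note above).
IMAGE_ONLY = {"Normalize", "Lambda", "Pad", "ColorJitter", "RandomErasing"}


class RandomHorizontalFlip:
    """Toy stand-in: flips a 2D list left-right with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, grid):
        if random.random() < self.p:
            return [row[::-1] for row in grid]
        return grid


class ColorJitter:
    """Toy stand-in: adds one random brightness offset to every pixel."""

    def __init__(self, brightness=10):
        self.brightness = brightness

    def __call__(self, grid):
        offset = random.randint(-self.brightness, self.brightness)
        return [[value + offset for value in row] for row in grid]


def apply_pair(image, mask, transforms, seed):
    """Apply each transform to the image, and to the mask unless image-only.

    The generator is re-seeded with the same value before each call, so a
    random geometric transform makes the same decision for image and mask.
    """
    for index, transform in enumerate(transforms):
        step_seed = seed + index
        random.seed(step_seed)
        image = transform(image)
        if type(transform).__name__ not in IMAGE_ONLY:
            random.seed(step_seed)
            mask = transform(mask)
    return image, mask


image = [[10, 20, 30], [40, 50, 60]]
mask = [[0, 0, 1], [1, 0, 0]]
pipeline = [ColorJitter(brightness=5), RandomHorizontalFlip()]
aug_image, aug_mask = apply_pair(image, mask, pipeline, seed=0)
```

Re-seeding per transform, rather than once per sample, keeps image and mask aligned even when an image-only transform consumes random numbers that the mask pass never draws.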