# Source code for the `data` module.
# DataProcessor
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from skimage.io import imread
from PIL import Image
import numpy
import torch
import glob
import random
from cv2 import convertScaleAbs
class DataProcessor(Dataset):
    """Dataset over the image (or mask) files in one directory.

    Each item is the file at that (sorted) position, read as a stack of frames;
    every frame is converted to uint8, wrapped in a PIL image and passed through
    ``self.transform``. The item is the resulting list, one entry per frame.
    """

    def __init__(self, image_dir, dir_type="image", target_size=(512, 512), image_suffix="tif", transform=None):
        """
        :param image_dir: Path to images directory. Either images or masks dir.
            May be a ``str`` or a path-like object such as ``pathlib.Path``.
        :param dir_type: One of "image" or "mask". Defaults to "image".
        :param target_size: Target size to resize to. Defaults to (512, 512).
            This was found to be the best for data preservation.
        :param image_suffix: File suffix of the images, without the dot. Defaults to "tif".
        :param transform: Optional callable applied to every frame (a uint8 ``PIL.Image``).
            When omitted, frames are resized to ``target_size`` and converted with
            ``ToTensor()``; masks are resized with nearest-neighbour interpolation.
        :return: An object of class DataProcessor ---> torch.utils.data.Dataset
        """
        self.image_dir = image_dir
        self.image_suffix = image_suffix
        self.dir_type = dir_type
        self.target_size = target_size
        # f-string rather than "+" so path-like objects are accepted as well as str.
        self.image_list = sorted(glob.glob(f"{image_dir}/*.{image_suffix}"))
        # __getitem__ depends on self.transform, so always provide one.
        self.transform = transform if transform is not None else self._default_transform()

    def _default_transform(self):
        """Build the default per-frame transform: resize to ``target_size``, then ``ToTensor()``."""
        # InterpolationMode only exists in newer torchvision releases; fall back to
        # the PIL constants that older releases accept.
        modes = getattr(transforms, "InterpolationMode", None)
        if self.dir_type == "mask":
            # Nearest-neighbour keeps mask values discrete; bilinear would blend
            # neighbouring labels along object edges.
            interpolation = modes.NEAREST if modes is not None else Image.NEAREST
        else:
            # Bilinear is Resize's own default, so images resize as before.
            interpolation = modes.BILINEAR if modes is not None else Image.BILINEAR
        return transforms.Compose([
            transforms.Resize(self.target_size, interpolation=interpolation),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of matching files found in ``image_dir``."""
        return len(self.image_list)

    def to_uint8(self, image):
        """Convert one frame to uint8 with OpenCV's ``convertScaleAbs``.

        Masks are scaled by 255/65535 (so a value of 65535 maps to 255); images are
        passed with the default scale of 1, so values above 255 saturate at 255.

        NOTE(review): the loading comments say masks "do not work well with alpha
        scaling", yet masks are the branch that is scaled here - confirm which is
        intended for this data.
        """
        if self.dir_type == "mask":
            uint8_img = convertScaleAbs(image, alpha=255.0 / 65535.0)
        else:
            uint8_img = convertScaleAbs(image)
        return uint8_img

    @staticmethod
    def _as_frames(array):
        """Return ``array`` as a stack of frames; a 2-D array becomes a 1-frame stack."""
        # A single-frame file may come back as a plain (H, W) array; iterating it
        # directly would yield 1-D rows instead of frames.
        # NOTE(review): a single multi-channel image (H, W, C) is still treated as a stack.
        if array.ndim == 2:
            return array[numpy.newaxis, ...]
        return array

    def __getitem__(self, img_index):
        """
        :param img_index: Index of the image file (an int, or a tensor converted via ``.tolist()``).
        :return: List with one transformed frame per frame of the file
            (tensors when the default transform is used).
        """
        if torch.is_tensor(img_index):
            img_index = img_index.tolist()
        path = self.image_list[img_index]
        if self.image_suffix == "tif":
            final_images = imread(path, plugin="pil")
        else:
            final_images = imread(path)
        # The files can hold image stacks, so every frame is handled separately.
        # uint16 doesn't work with resizing --> convert each frame to uint8 first
        # (see to_uint8), then wrap it in a PIL image for the transform.
        return [
            self.transform(Image.fromarray(self.to_uint8(frame)))
            for frame in self._as_frames(final_images)
        ]