"""Dataset model.
https://pytorch.org/docs/stable/torchvision/models.html#object-detection-instance-segmentation-and-person-keypoint-detection
During training, the model expects both the input tensors, as well as a
targets (list of dictionary), containing:
boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2]
format, with values between 0 and H and 0 and W
labels (Int64Tensor[N]): the class label for each ground-truth box
https://colab.research.google.com/github/benihime91/pytorch_retinanet/blob/master/demo.ipynb#scrollTo=0zNGhr6D7xGN
"""
import os
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import typing
from PIL import Image
import rasterio as rio
from deepforest import preprocess
from rasterio.windows import Window
from torchvision import transforms
import slidingwindow
import warnings
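
# NOTE: `get_transform` is referenced by TreeDataset below but is not defined in
# this excerpt. The sketch that follows is an assumption reconstructed from its
# usage (the transform is called as transform(image=..., bboxes=...,
# category_ids=...) with [xmin, ymin, xmax, ymax] boxes); it is illustrative,
# not necessarily the library's exact implementation.
def get_transform(augment):
    """Albumentations transform for images and pascal_voc bounding boxes."""
    augmentations = [A.HorizontalFlip(p=0.5)] if augment else []
    return A.Compose(
        augmentations + [ToTensorV2()],
        bbox_params=A.BboxParams(format="pascal_voc",
                                 label_fields=["category_ids"]))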
class TreeDataset(Dataset):
def __init__(self,
csv_file,
root_dir,
transforms=None,
label_dict={"Tree": 0},
train=True,
preload_images=False):
"""
Args:
csv_file (string): Path to a single csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
label_dict: a dictionary where keys are labels from the csv column and values are numeric labels "Tree" -> 0
Returns:
If train, path, image, targets else image
"""
self.annotations = pd.read_csv(csv_file)
self.root_dir = root_dir
if transforms is None:
self.transform = get_transform(augment=train)
else:
self.transform = transforms
self.image_names = self.annotations.image_path.unique()
self.label_dict = label_dict
self.train = train
self.image_converter = A.Compose([ToTensorV2()])
self.preload_images = preload_images
        # Preload all images into memory if desired
        if self.preload_images:
            print("Preloading images into memory")
self.image_dict = {}
for idx, x in enumerate(self.image_names):
img_name = os.path.join(self.root_dir, x)
image = np.array(Image.open(img_name).convert("RGB")) / 255
self.image_dict[idx] = image.astype("float32")
def __len__(self):
return len(self.image_names)
def __getitem__(self, idx):
# Read image if not in memory
if self.preload_images:
image = self.image_dict[idx]
else:
img_name = os.path.join(self.root_dir, self.image_names[idx])
image = np.array(Image.open(img_name).convert("RGB")) / 255
image = image.astype("float32")
if self.train:
# select annotations
image_annotations = self.annotations[self.annotations.image_path ==
self.image_names[idx]]
targets = {}
targets["boxes"] = image_annotations[["xmin", "ymin", "xmax",
"ymax"]].values.astype("float32")
# Labels need to be encoded
targets["labels"] = image_annotations.label.apply(
lambda x: self.label_dict[x]).values.astype(np.int64)
# If image has no annotations, don't augment
if np.sum(targets["boxes"]) == 0:
boxes = torch.zeros((0, 4), dtype=torch.float32)
labels = torch.zeros(0, dtype=torch.int64)
                # Move channels first (HWC -> CHW)
image = np.rollaxis(image, 2, 0)
image = torch.from_numpy(image).float()
targets = {"boxes": boxes, "labels": labels}
return self.image_names[idx], image, targets
augmented = self.transform(image=image,
bboxes=targets["boxes"],
category_ids=targets["labels"])
image = augmented["image"]
boxes = np.array(augmented["bboxes"])
boxes = torch.from_numpy(boxes).float()
labels = np.array(augmented["category_ids"])
labels = torch.from_numpy(labels)
targets = {"boxes": boxes, "labels": labels}
return self.image_names[idx], image, targets
else:
            # Mimic the tensor conversion of the train path, without targets
converted = self.image_converter(image=image)
return converted["image"]
class TileDataset(Dataset):
def __init__(self,
tile: typing.Optional[np.ndarray],
preload_images: bool = False,
patch_size: int = 400,
patch_overlap: float = 0.05):
"""
Args:
tile: an in memory numpy array.
patch_size (int): The size for the crops used to cut the input raster into smaller pieces. This is given in pixels, not any geographic unit.
patch_overlap (float): The horizontal and vertical overlap among patches
preload_images (bool): If true, the entire dataset is loaded into memory. This is useful for small datasets, but not recommended for large datasets since both the tile and the crops are stored in memory.
Returns:
ds: a pytorch dataset
"""
        if tile.shape[2] != 3:
            raise ValueError(
                "Only three-band rasters are accepted. Channels should be the "
                "final dimension. Input tile has shape {}. Check for a "
                "transparent alpha channel and remove it if present.".format(
                    tile.shape))
self.image = tile
self.preload_images = preload_images
self.windows = preprocess.compute_windows(self.image, patch_size, patch_overlap)
if self.preload_images:
self.crops = []
for window in self.windows:
crop = self.image[window.indices()]
crop = preprocess.preprocess_image(crop)
self.crops.append(crop)
def __len__(self):
return len(self.windows)
def __getitem__(self, idx):
# Read image if not in memory
if self.preload_images:
crop = self.crops[idx]
else:
crop = self.image[self.windows[idx].indices()]
crop = preprocess.preprocess_image(crop)
return crop
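
# Example usage (illustrative): cut an in-memory RGB array into overlapping crops.
#   tile = np.zeros((1000, 1000, 3), dtype="uint8")
#   ds = TileDataset(tile=tile, patch_size=400, patch_overlap=0.05)
#   crop = ds[0]  # a single preprocessed crop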
class RasterDataset:
"""Dataset for predicting on raster windows.
Args:
raster_path (str): Path to raster file
patch_size (int): Size of windows to predict on
patch_overlap (float): Overlap between windows as fraction (0-1)
Returns:
A dataset of raster windows
"""
def __init__(self, raster_path, patch_size, patch_overlap):
self.raster_path = raster_path
self.patch_size = patch_size
self.patch_overlap = patch_overlap
        # Get raster shape without keeping the file open
        with rio.open(raster_path) as src:
            # rasterio reports shape as (height, width)
            height = src.shape[0]
            width = src.shape[1]
            # Check that the raster is internally tiled
if not src.is_tiled:
raise ValueError(
"Out-of-memory dataset is selected, but raster is not tiled, "
"leading to entire raster being read into memory and defeating "
"the purpose of an out-of-memory dataset. "
"\nPlease run: "
"\ngdal_translate -of GTiff -co TILED=YES <input> <output> "
"to create a tiled raster")
        # Generate sliding windows
        self.windows = slidingwindow.generateForSize(
            width,
            height,
            dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
            maxWindowSize=patch_size,
            overlapPercent=patch_overlap)
self.n_windows = len(self.windows)
def __len__(self):
return self.n_windows
def __getitem__(self, idx):
"""Get a window of the raster.
Args:
idx (int): Index of window to get
Returns:
crop (torch.Tensor): A tensor of shape (3, height, width)
"""
window = self.windows[idx]
# Open, read window, and close for each operation
with rio.open(self.raster_path) as src:
window_data = src.read(window=Window(window.x, window.y, window.w, window.h))
# Convert to torch tensor and rearrange dimensions
window_data = torch.from_numpy(window_data).float() # Convert to torch tensor
window_data = window_data / 255.0 # Normalize
return window_data # Already in (C, H, W) format from rasterio
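
# Example usage (illustrative; "tiled.tif" is a hypothetical tiled GeoTIFF).
# batch_size=1 is a safe default, since window shapes can vary at raster edges:
#   ds = RasterDataset("tiled.tif", patch_size=400, patch_overlap=0.05)
#   loader = torch.utils.data.DataLoader(ds, batch_size=1, num_workers=2)
#   for batch in loader:
#       pass  # batch is a float tensor of shape (1, 3, h, w), scaled to [0, 1]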
resnet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
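
# NOTE: `bounding_box_transform` is referenced by BoundingBoxDataset below but
# is not defined in this excerpt. A minimal sketch built around the
# `resnet_normalize` constant above; the 224x224 resize target is an assumption.
def bounding_box_transform(augment=False):
    data_transforms = [transforms.ToTensor(), resnet_normalize,
                       transforms.Resize([224, 224])]
    if augment:
        data_transforms.append(transforms.RandomHorizontalFlip(0.5))
    return transforms.Compose(data_transforms)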
class BoundingBoxDataset(Dataset):
"""An in memory dataset for bounding box predictions.
Args:
df: a pandas dataframe with image_path and xmin,xmax,ymin,ymax columns
transform: a function to apply to the image
root_dir: the directory where the image is stored
Returns:
rgb: a tensor of shape (3, height, width)
"""
def __init__(self, df, root_dir, transform=None, augment=False):
self.df = df
if transform is None:
self.transform = bounding_box_transform(augment=augment)
else:
self.transform = transform
unique_image = self.df['image_path'].unique()
        assert len(unique_image) == 1, (
            "There should be only one unique image for this class object")
# Open the image using rasterio
self.src = rio.open(os.path.join(root_dir, unique_image[0]))
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
row = self.df.iloc[idx]
xmin = row['xmin']
xmax = row['xmax']
ymin = row['ymin']
ymax = row['ymax']
        # Read the RGB data; rasterio Window takes (col_off, row_off, width, height)
        box = self.src.read(window=Window(xmin, ymin, xmax - xmin, ymax - ymin))
        # Move channels last (CHW -> HWC) for the transform
        box = np.rollaxis(box, 0, 3)
if self.transform:
image = self.transform(box)
else:
image = box
return image
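
# Example usage (illustrative; the dataframe contents and paths are hypothetical):
#   df = pd.DataFrame({"image_path": ["tile.tif", "tile.tif"],
#                      "xmin": [0, 50], "ymin": [0, 50],
#                      "xmax": [50, 100], "ymax": [50, 100]})
#   ds = BoundingBoxDataset(df, root_dir="images/")
#   crop = ds[0]  # a normalized (3, height, width) tensor for one box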