# Model - common class
from deepforest.models import *
import torch
from pytorch_lightning import LightningModule, Trainer
import os
import torchmetrics
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import numpy as np
import rasterio
from torch.utils.data import Dataset
import torch.nn.functional as F
import cv2
class Model():
    """An architecture-agnostic class that controls the basic train, eval and
    predict functions. A model should optionally allow a backbone for
    pretraining. To add new architectures, simply create a new module in
    models/ and write a create_model. Then add the result to the if else
    statement below.

    Args:
        num_classes (int): number of classes in the model
        nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
        score_thresh (float): minimum prediction score to keep during prediction [0,1]

    Returns:
        model: a pytorch nn module
    """

    def __init__(self, config):
        # Check for required properties and formats
        self.config = config

        # Check input output format:
        self.check_model()

    def create_model(self):
        """Convert a deepforest config file into a model.

        An architecture should have a list of nested arguments in config
        that match this function.

        Raises:
            ValueError: always, unless a subclass overrides this method and
                returns a pytorch nn module.
        """
        raise ValueError(
            "The create_model class method needs to be implemented. Take in args and return a pytorch nn module."
        )

    def check_model(self):
        """Ensure that model follows deepforest guidelines, see ##### If fails,
        raise ValueError."""
        # This assumes model creation is not expensive
        test_model = self.create_model()
        test_model.eval()

        # Create a dummy batch of 3 band data. This is only a sanity check,
        # so run inference without building an autograd graph.
        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        with torch.no_grad():
            predictions = test_model(x)

        # Model takes in a batch of images
        assert len(predictions) == 2

        # Returns a list equal to number of images with proper keys per image
        model_keys = sorted(predictions[1].keys())
        assert model_keys == ['boxes', 'labels', 'scores']
def simple_resnet_50(num_classes=2, weights=models.ResNet50_Weights.DEFAULT):
    """Create a ResNet-50 classifier with a task-specific fully connected head.

    Args:
        num_classes (int): number of output classes for the final linear layer.
        weights: torchvision weights enum (or None) forwarded to
            ``models.resnet50``. Defaults to the pretrained ImageNet weights,
            matching the previous hard-coded behavior; pass ``None`` to build
            an untrained network without downloading weights.

    Returns:
        torch.nn.Module: the modified ResNet-50 model.
    """
    m = models.resnet50(weights=weights)

    # Replace the 1000-class ImageNet head with one sized for this task.
    num_ftrs = m.fc.in_features
    m.fc = torch.nn.Linear(num_ftrs, num_classes)

    return m
class CropModel(LightningModule):
    """A PyTorch Lightning module for classifying image crops from object
    detection models.

    This class provides a flexible architecture for training classification models on cropped
    regions identified by object detection models. It supports using either a default ResNet-50
    model or a custom provided model.

    Args:
        num_classes (int): Number of classes for classification
        batch_size (int, optional): Batch size for training. Defaults to 4.
        num_workers (int, optional): Number of worker processes for data loading. Defaults to 0.
        lr (float, optional): Learning rate for optimization. Defaults to 0.0001.
        model (nn.Module, optional): Custom PyTorch model to use. If None, uses ResNet-50. Defaults to None.
        label_dict (dict, optional): Mapping of class labels to numeric indices. Defaults to None.

    Attributes:
        model (nn.Module): The classification model (ResNet-50 or custom)
        accuracy (torchmetrics.Accuracy): Per-class accuracy metric
        total_accuracy (torchmetrics.Accuracy): Overall accuracy metric
        precision_metric (torchmetrics.Precision): Precision metric
        metrics (torchmetrics.MetricCollection): Collection of all metrics
        batch_size (int): Batch size for training
        num_workers (int): Number of data loading workers
        lr (float): Learning rate
        label_dict (dict): Label to index mapping {"Bird": 0, "Mammal": 1}
        numeric_to_label_dict (dict): Index to label mapping {0: "Bird", 1: "Mammal"},
            or None when no label_dict was supplied.
    """

    def __init__(self,
                 num_classes,
                 batch_size=4,
                 num_workers=0,
                 lr=0.0001,
                 model=None,
                 label_dict=None):
        super().__init__()

        # Model: fall back to a pretrained ResNet-50 when none is supplied.
        self.num_classes = num_classes
        if model is None:
            self.model = simple_resnet_50(num_classes=num_classes)
        else:
            self.model = model

        # Metrics
        self.accuracy = torchmetrics.Accuracy(average='none',
                                              num_classes=num_classes,
                                              task="multiclass")
        self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
                                                    task="multiclass")
        self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
                                                       task="multiclass")
        self.metrics = torchmetrics.MetricCollection({
            "Class Accuracy": self.accuracy,
            "Accuracy": self.total_accuracy,
            "Precision": self.precision_metric
        })

        # Training Hyperparameters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

        # Label dict: always define the reverse mapping so downstream code can
        # rely on the attribute existing even when label_dict is not given.
        self.label_dict = label_dict
        if label_dict is not None:
            self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
        else:
            self.numeric_to_label_dict = None

    def create_trainer(self, **kwargs):
        """Create a pytorch lightning trainer object."""
        self.trainer = Trainer(**kwargs)

    def load_from_disk(self, train_dir, val_dir):
        """Load train/validation ImageFolder datasets from directories.

        Args:
            train_dir (str): root directory of training crops (one subfolder per class).
            val_dir (str): root directory of validation crops (one subfolder per class).
        """
        self.train_ds = ImageFolder(root=train_dir,
                                    transform=self.get_transform(augment=True))
        self.val_ds = ImageFolder(root=val_dir,
                                  transform=self.get_transform(augment=False))

    def write_crops(self, root_dir, images, boxes, labels, savedir):
        """Write crops to disk.

        Args:
            root_dir (str): The root directory where the images are located.
            images (list): A list of image filenames.
            boxes (list): A list of bounding box coordinates in the format [xmin, ymin, xmax, ymax].
            labels (list): A list of labels corresponding to each bounding box.
            savedir (str): The directory where the cropped images will be saved.

        Returns:
            None
        """
        # Create a directory for each label
        for label in labels:
            os.makedirs(os.path.join(savedir, label), exist_ok=True)

        # Use rasterio to read the image
        for index, box in enumerate(boxes):
            xmin, ymin, xmax, ymax = box
            label = labels[index]
            image = images[index]
            with rasterio.open(os.path.join(root_dir, image)) as src:
                # Crop the image using the bounding box coordinates
                # (window is ((row_start, row_stop), (col_start, col_stop))).
                img = src.read(window=((ymin, ymax), (xmin, xmax)))

            # Save the cropped image as a PNG file using opencv; rasterio
            # returns (bands, rows, cols), opencv expects (rows, cols, bands).
            img_path = os.path.join(savedir, label, f"crop_{index}.png")
            img = np.rollaxis(img, 0, 3)
            cv2.imwrite(img_path, img)

    def normalize(self):
        """Return the standard ImageNet normalization transform."""
        return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    def forward(self, x):
        """Return raw class logits for a batch of images.

        NOTE(fix): the previous implementation applied a sigmoid here, but
        ``training_step`` passes this output to ``F.cross_entropy``, which
        expects unnormalized logits (it applies log-softmax internally).
        Squashing the outputs first distorted the loss. ``forward`` now
        returns logits; ``predict_step`` converts them to probabilities.
        """
        return self.model(x)

    def train_dataloader(self):
        """Train data loader."""
        return torch.utils.data.DataLoader(self.train_ds,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=self.num_workers)

    def predict_dataloader(self, ds):
        """Prediction data loader."""
        return torch.utils.data.DataLoader(ds,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=self.num_workers)

    def val_dataloader(self):
        """Validation data loader.

        Validation data is not shuffled: shuffling adds nondeterminism for no
        benefit during evaluation (the previous shuffle=True was a defect).
        """
        return torch.utils.data.DataLoader(self.val_ds,
                                           batch_size=self.batch_size,
                                           shuffle=False,
                                           num_workers=self.num_workers)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy training loss on logits."""
        x, y = batch
        outputs = self.forward(x)
        loss = F.cross_entropy(outputs, y)
        self.log("train_loss", loss)

        return loss

    def predict_step(self, batch, batch_idx):
        """Return per-class probabilities (softmax over logits) for a batch."""
        outputs = self.forward(batch)
        yhat = F.softmax(outputs, 1)

        return yhat

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and log each metric once per epoch."""
        x, y = batch
        outputs = self(x)
        loss = F.cross_entropy(outputs, y)
        self.log("val_loss", loss)

        metric_dict = self.metrics(outputs, y)
        # NOTE(fix): the original nested two identical loops over
        # metric_dict, logging every metric len(metric_dict) times;
        # a single loop is sufficient.
        for key, value in metric_dict.items():
            if isinstance(value, torch.Tensor) and value.numel() > 1:
                # Per-class metrics return a vector; log one scalar per class.
                for i, v in enumerate(value):
                    self.log(f"{key}_{i}", v, on_step=False, on_epoch=True)
            else:
                self.log(key, value, on_step=False, on_epoch=True)

        return loss

    def dataset_confusion(self, loader):
        """Create a confusion matrix from a data loader.

        Args:
            loader: a dataloader yielding (images, integer labels) batches.

        Returns:
            tuple: (true_class, predicted_class) numpy arrays — one-hot true
            labels and per-class predicted probabilities.
        """
        true_class = []
        predicted_class = []
        self.eval()
        # Inference only: disable autograd to avoid building graphs and
        # holding activations in memory.
        with torch.no_grad():
            for batch in loader:
                x, y = batch
                true_class.append(
                    F.one_hot(y, num_classes=self.num_classes).numpy())
                # forward returns logits; convert to probabilities so the
                # output matches what predict_step produces.
                prediction = F.softmax(self(x), dim=1)
                predicted_class.append(prediction.numpy())

        true_class = np.concatenate(true_class)
        predicted_class = np.concatenate(predicted_class)

        return true_class, predicted_class