# Source code for deepforest.model

# Model - common class
from deepforest.models import *
import torch
from pytorch_lightning import LightningModule, Trainer
import os
import torchmetrics
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
import numpy as np
import rasterio
from torch.utils.data import Dataset
import torch.nn.functional as F
import cv2


class Model:
    """An architecture-agnostic class that controls the basic train, eval and predict functions.

    A model should optionally allow a backbone for pretraining. To add new
    architectures, simply create a new module in models/ and write a
    create_model. Then add the result to the if else statement below.

    Args:
        num_classes (int): number of classes in the model
        nms_thresh (float): non-max suppression threshold for intersection-over-union [0,1]
        score_thresh (float): minimum prediction score to keep during prediction [0,1]

    Returns:
        model: a pytorch nn module
    """

    def __init__(self, config):
        # Check for required properties and formats
        self.config = config

        # Check input output format:
        self.check_model()

    def create_model(self):
        """This function converts a deepforest config file into a model.

        An architecture should have a list of nested arguments in config that
        match this function.

        Raises:
            ValueError: always, until a subclass overrides this method.
        """
        raise ValueError(
            "The create_model class method needs to be implemented. Take in args and return a pytorch nn module."
        )

    def check_model(self):
        """Ensure that model follows deepforest guidelines.

        Runs a dummy two-image batch through the model and verifies the
        output contract: one prediction dict per image with 'boxes',
        'labels' and 'scores' keys.

        Raises:
            ValueError: if the model violates the expected interface.
        """
        # This assumes model creation is not expensive
        test_model = self.create_model()
        test_model.eval()

        # Create a dummy batch of 3 band data.
        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

        predictions = test_model(x)

        # Raise ValueError (as the docstring promises) instead of bare
        # asserts, which are silently stripped when Python runs with -O.
        if len(predictions) != 2:
            raise ValueError(
                "Model must return one prediction dict per input image, "
                "got {} for a batch of 2.".format(len(predictions)))

        model_keys = sorted(predictions[1].keys())
        if model_keys != ['boxes', 'labels', 'scores']:
            raise ValueError(
                "Each prediction must have 'boxes', 'labels' and 'scores' "
                "keys, got {}".format(model_keys))
def simple_resnet_50(num_classes=2):
    """Build a ResNet-50 classifier with a freshly initialized head.

    Loads ImageNet-pretrained weights for the backbone, then swaps the final
    fully-connected layer for one sized to ``num_classes``.

    Args:
        num_classes (int): number of output classes. Defaults to 2.

    Returns:
        torch.nn.Module: the modified ResNet-50.
    """
    backbone = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    backbone.fc = torch.nn.Linear(backbone.fc.in_features, num_classes)

    return backbone
class CropModel(LightningModule):
    """A PyTorch Lightning module for classifying image crops from object detection models.

    This class provides a flexible architecture for training classification
    models on cropped regions identified by object detection models. It
    supports using either a default ResNet-50 model or a custom provided
    model.

    Args:
        num_classes (int): Number of classes for classification
        batch_size (int, optional): Batch size for training. Defaults to 4.
        num_workers (int, optional): Number of worker processes for data loading. Defaults to 0.
        lr (float, optional): Learning rate for optimization. Defaults to 0.0001.
        model (nn.Module, optional): Custom PyTorch model to use. If None, uses ResNet-50. Defaults to None.
        label_dict (dict, optional): Mapping of class labels to numeric indices. Defaults to None.

    Attributes:
        model (nn.Module): The classification model (ResNet-50 or custom)
        accuracy (torchmetrics.Accuracy): Per-class accuracy metric
        total_accuracy (torchmetrics.Accuracy): Overall accuracy metric
        precision_metric (torchmetrics.Precision): Precision metric
        metrics (torchmetrics.MetricCollection): Collection of all metrics
        batch_size (int): Batch size for training
        num_workers (int): Number of data loading workers
        lr (float): Learning rate
        label_dict (dict): Label to index mapping {"Bird": 0, "Mammal": 1}
        numeric_to_label_dict (dict): Index to label mapping {0: "Bird", 1: "Mammal"},
            or None when no label_dict was supplied.
    """

    def __init__(self,
                 num_classes,
                 batch_size=4,
                 num_workers=0,
                 lr=0.0001,
                 model=None,
                 label_dict=None):
        super().__init__()

        # Model: fall back to an ImageNet-pretrained ResNet-50 head.
        self.num_classes = num_classes
        if model is None:
            self.model = simple_resnet_50(num_classes=num_classes)
        else:
            self.model = model

        # Metrics
        self.accuracy = torchmetrics.Accuracy(average='none',
                                              num_classes=num_classes,
                                              task="multiclass")
        self.total_accuracy = torchmetrics.Accuracy(num_classes=num_classes,
                                                    task="multiclass")
        self.precision_metric = torchmetrics.Precision(num_classes=num_classes,
                                                       task="multiclass")
        self.metrics = torchmetrics.MetricCollection({
            "Class Accuracy": self.accuracy,
            "Accuracy": self.total_accuracy,
            "Precision": self.precision_metric
        })

        # Training Hyperparameters
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr

        # Label dict. Always define the reverse mapping so later attribute
        # access cannot raise AttributeError when no label_dict was given.
        self.label_dict = label_dict
        if label_dict is not None:
            self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}
        else:
            self.numeric_to_label_dict = None
    def create_trainer(self, **kwargs):
        """Create a pytorch lightning trainer object.

        Args:
            **kwargs: forwarded verbatim to ``pytorch_lightning.Trainer``
                (e.g. max_epochs, accelerator, logger).

        Returns:
            None: the trainer is stored on ``self.trainer``.
        """
        self.trainer = Trainer(**kwargs)
[docs] def load_from_disk(self, train_dir, val_dir): self.train_ds = ImageFolder(root=train_dir, transform=self.get_transform(augment=True)) self.val_ds = ImageFolder(root=val_dir, transform=self.get_transform(augment=False))
[docs] def get_transform(self, augment): """Returns the data transformation pipeline for the model. Args: augment (bool): Flag indicating whether to apply data augmentation. Returns: torchvision.transforms.Compose: The composed data transformation pipeline. """ data_transforms = [] data_transforms.append(transforms.ToTensor()) data_transforms.append(self.normalize()) data_transforms.append(transforms.Resize([224, 224])) if augment: data_transforms.append(transforms.RandomHorizontalFlip(0.5)) return transforms.Compose(data_transforms)
[docs] def write_crops(self, root_dir, images, boxes, labels, savedir): """Write crops to disk. Args: root_dir (str): The root directory where the images are located. images (list): A list of image filenames. boxes (list): A list of bounding box coordinates in the format [xmin, ymin, xmax, ymax]. labels (list): A list of labels corresponding to each bounding box. savedir (str): The directory where the cropped images will be saved. Returns: None """ # Create a directory for each label for label in labels: os.makedirs(os.path.join(savedir, label), exist_ok=True) # Use rasterio to read the image for index, box in enumerate(boxes): xmin, ymin, xmax, ymax = box label = labels[index] image = images[index] with rasterio.open(os.path.join(root_dir, image)) as src: # Crop the image using the bounding box coordinates img = src.read(window=((ymin, ymax), (xmin, xmax))) # Save the cropped image as a PNG file using opencv img_path = os.path.join(savedir, label, f"crop_{index}.png") img = np.rollaxis(img, 0, 3) cv2.imwrite(img_path, img)
[docs] def normalize(self): return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
[docs] def forward(self, x): output = self.model(x) output = F.sigmoid(output) return output
[docs] def train_dataloader(self): """Train data loader.""" train_loader = torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) return train_loader
[docs] def predict_dataloader(self, ds): """Prediction data loader.""" loader = torch.utils.data.DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers) return loader
[docs] def val_dataloader(self): """Validation data loader.""" val_loader = torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers) return val_loader
[docs] def training_step(self, batch, batch_idx): x, y = batch outputs = self.forward(x) loss = F.cross_entropy(outputs, y) self.log("train_loss", loss) return loss
[docs] def predict_step(self, batch, batch_idx): outputs = self.forward(batch) yhat = F.softmax(outputs, 1) return yhat
[docs] def validation_step(self, batch, batch_idx): x, y = batch outputs = self(x) loss = F.cross_entropy(outputs, y) self.log("val_loss", loss) metric_dict = self.metrics(outputs, y) for key, value in metric_dict.items(): for key, value in metric_dict.items(): if isinstance(value, torch.Tensor) and value.numel() > 1: for i, v in enumerate(value): self.log(f"{key}_{i}", v, on_step=False, on_epoch=True) else: self.log(key, value, on_step=False, on_epoch=True) return loss
[docs] def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08) # Monitor rate is val data is used return {'optimizer': optimizer, 'lr_scheduler': scheduler, "monitor": 'val_loss'}
[docs] def dataset_confusion(self, loader): """Create a confusion matrix from a data loader.""" true_class = [] predicted_class = [] self.eval() for batch in loader: x, y = batch true_class.append(F.one_hot(y, num_classes=self.num_classes).detach().numpy()) prediction = self(x) predicted_class.append(prediction.detach().numpy()) true_class = np.concatenate(true_class) predicted_class = np.concatenate(predicted_class) return true_class, predicted_class