# Source code for deepforest.main

# entry point for deepforest model
import importlib
import os
import warnings
from numbers import Number

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from torch import optim
from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision

from deepforest import distributed, predict, utilities
from deepforest.datasets import prediction, training
from deepforest.metrics import RecallPrecision

# PIL guards against "decompression bomb" images by refusing/warning on images
# larger than Image.MAX_IMAGE_PIXELS; None removes that limit so large rasters
# can be opened. Note this is a process-wide setting applied at import time.
setattr(Image, "MAX_IMAGE_PIXELS", None)


class deepforest(pl.LightningModule):
    """DeepForest model for tree crown detection in RGB images.

    Args:
        model: DeepForest model object. When None, a model is built from the
            config by ``create_model()``.
        transforms: Optional transforms forwarded to the training dataset.
        existing_train_dataloader: PyTorch dataloader for training data
        existing_val_dataloader: PyTorch dataloader for validation data
        config: DeepForest configuration object, dict, or config name
        config_args: Dictionary of config overrides
    """

    def __init__(
        self,
        model=None,
        transforms=None,
        existing_train_dataloader=None,
        existing_val_dataloader=None,
        config: str | dict | DictConfig | None = None,
        config_args: dict | None = None,
    ):
        super().__init__()

        if config is None:
            # Default config, with optional overrides
            config = utilities.load_config(overrides=config_args)
        elif isinstance(config, str):
            # Named config, with optional overrides
            config = utilities.load_config(config_name=config, overrides=config_args)
        elif isinstance(config, dict):
            # Checkpoint load: the saved config arrives as a plain dict
            config = OmegaConf.merge(config, config_args or {})
            config = utilities.load_config(overrides=config)
        elif "config_args" in config:
            # Hub overrides
            config = utilities.load_config(overrides=config["config_args"])
        elif config_args is not None:
            warnings.warn(
                f"Ignoring options as configuration object was provided: {config_args}",
                stacklevel=2,
            )

        self.config = config

        # release version id to flag if release is being used
        self.__release_version__ = None

        self.existing_train_dataloader = existing_train_dataloader
        self.existing_val_dataloader = existing_val_dataloader
        self.model = model
        self.original_batch_structure = []
        # Validation predictions; reset at the start of every validation epoch.
        self.predictions = []

        if self.model is None:
            self.create_model()

        # Create a default trainer.
        self.create_trainer()

        # Add user supplied transforms (None means no extra transforms).
        self.transforms = transforms

        self.save_hyperparameters(
            {"config": OmegaConf.to_container(self.config, resolve=True)}
        )

    def setup_metrics(self):
        """Create the validation metrics used for ``self.model.task``.

        box: IoU, mAP and precision/recall. point: count MAE and
        precision/recall. polygon: segmentation mAP. Nothing is created when no
        validation data (``config.validation.csv_file`` or
        ``existing_val_dataloader``) is configured.
        """
        # Guard against initialization before a validation csv_file is set
        if not self.config.validation.csv_file and self.existing_val_dataloader is None:
            return

        task = self.model.task
        if task == "box":
            self.iou_metric = IntersectionOverUnion(
                class_metrics=True, iou_threshold=self.config.validation.iou_threshold
            )
            self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")
            self.precision_recall_metric = RecallPrecision(
                iou_threshold=self.config.validation.iou_threshold,
                label_dict=self.label_dict,
                task=task,
            )
        elif task == "point":
            self.mae_metric = torchmetrics.MeanAbsoluteError()
            self.precision_recall_metric = RecallPrecision(
                distance_threshold=self.config.point.distance_threshold,
                label_dict=self.label_dict,
                task=task,
            )
        elif task == "polygon":
            self.mAP_metric = MeanAveragePrecision(
                iou_type="segm", backend="faster_coco_eval"
            )

    def load_model(self, model_name=None, revision=None):
        """Load a model that has already been pretrained for a specific task,
        like tree crown detection.

        Models (technically model weights) are distributed via Hugging Face and
        designated by the Hugging Face repository ID (model_name), which is in
        the form: 'organization/repository'. For a list of models distributed
        by the DeepForest team (and the associated model names) see the
        documentation:
        https://deepforest.readthedocs.io/en/stable/user_guide/02_prebuilt.html

        Args:
            model_name (str): A repository ID for huggingface in the form of
                organization/repository. Defaults to ``config.model.name``.
            revision (str): The model version ('main', 'v1.0.0', etc.).
                Defaults to ``config.model.revision``.

        Returns:
            None
        """
        if model_name is None:
            model_name = self.config.model.name
        if revision is None:
            revision = self.config.model.revision

        model_class = importlib.import_module(
            f"deepforest.models.{self.config.architecture}"
        )
        self.model = model_class.Model(config=self.config).create_model(
            pretrained=model_name, revision=revision
        )
        # Keep the config and label mapping in sync with the loaded weights.
        self.config.num_classes = self.model.num_classes
        self.set_labels(self.model.label_dict)

    def set_labels(self, label_dict):
        """Set new label mapping, updating both the label dictionary (str ->
        int) and its inverse (int -> str).

        Args:
            label_dict (dict): Dictionary mapping class names to numeric IDs.

        Raises:
            ValueError: If ``label_dict`` is None, its size differs from
                ``config.num_classes``, or it contains duplicate IDs.
        """
        if label_dict is None:
            raise ValueError(
                "Label dictionary not found. Check it was set in your config file or config_args."
            )

        if len(label_dict) != self.config.num_classes:
            raise ValueError(
                f"label_dict {label_dict} does not match requested number of "
                f"classes {self.config.num_classes}, please supply a label_dict argument "
                '{"label1":0, "label2":1, "label3":2 ... etc} '
                "for each label in the "
                "dataset"
            )

        # Check for duplicate values in label_dict:
        if len(set(label_dict.values())) != len(label_dict):
            raise ValueError("Found duplicate label IDs in label_dict.")

        self.label_dict = label_dict
        self.numeric_to_label_dict = {v: k for k, v in label_dict.items()}

    def create_model(self, initialize_model=False):
        """Initialize a deepforest architecture.

        This can be done in two ways: passed as the model argument to
        deepforest __init__(), or as a named architecture in
        config.architecture, which corresponds to a file in models/ and is a
        subclass of model.Model(). The config args in the .yaml are specified.

        Args:
            initialize_model: If True, build the architecture from the config
                even when ``config.model.name`` names pretrained weights.

        Returns:
            None
        """
        if self.config.model.name is None or initialize_model:
            model_class = importlib.import_module(
                f"deepforest.models.{self.config.architecture}"
            )
            self.model = model_class.Model(config=self.config).create_model()
            self.set_labels(self.config.label_dict)
        else:
            self.load_model()

    def create_trainer(self, logger=None, callbacks=None, **kwargs):
        """Create a PyTorch Lightning trainer from the config.

        Args:
            logger: Optional logger. Defaults to a ``CSVLogger`` writing to
                ``config.log_root``; pass ``False`` to disable logging.
            callbacks: Optional callback or list of callbacks. The list is
                copied, so the caller's list is never modified.
            **kwargs: Additional trainer arguments; these override the values
                derived from the config.
        """
        # Setup metrics which may have changed if the config was modified
        self.setup_metrics()

        # Copy the callbacks so appending the LR monitor never mutates the
        # caller's list (repeated calls would otherwise stack monitors).
        if callbacks is None:
            callbacks = []
        elif isinstance(callbacks, pl.Callback):
            callbacks = [callbacks]
        else:
            callbacks = list(callbacks)

        # Set default logger to use log_root from config instead of lightning_logs
        if logger is None:
            logger = CSVLogger(save_dir=self.config.log_root, name="")

        # If val data is passed, monitor learning rate and enable validation
        has_val = (
            self.config.validation.csv_file is not None
            or self.existing_val_dataloader is not None
        )
        if has_val:
            # LearningRateMonitor requires a logger; logger=False disables it.
            if logger is not False:
                callbacks.append(LearningRateMonitor(logging_interval="epoch"))
            limit_val_batches = 1.0
            num_sanity_val_steps = 2
        else:
            # Disable validation, don't use trainer defaults
            limit_val_batches = 0
            num_sanity_val_steps = 0

        # Lightning rejects a ModelCheckpoint callback (subclasses included)
        # when enable_checkpointing=False, so enable it whenever one is present.
        enable_checkpointing = any(
            isinstance(callback, pl.callbacks.ModelCheckpoint)
            or type(callback).__qualname__ == "ModelCheckpoint"
            for callback in callbacks
        )

        if torch.cuda.is_available():
            torch.set_float32_matmul_precision(self.config.matmul_precision)

        trainer_args = {
            "logger": logger,
            "max_epochs": self.config.train.epochs,
            "enable_checkpointing": enable_checkpointing,
            "devices": self.config.devices,
            "accelerator": self.config.accelerator,
            "num_nodes": self.config.num_nodes,
            "strategy": self.config.strategy,
            "sync_batchnorm": self.config.sync_batchnorm,
            "use_distributed_sampler": self.config.use_distributed_sampler,
            "fast_dev_run": self.config.train.fast_dev_run,
            "callbacks": callbacks,
            "limit_val_batches": limit_val_batches,
            "num_sanity_val_steps": num_sanity_val_steps,
            "default_root_dir": self.config.log_root,
        }
        # Only pass precision when configured; otherwise keep Lightning's default.
        if self.config.precision is not None:
            trainer_args["precision"] = self.config.precision

        # Update with kwargs to allow them to override config
        trainer_args.update(kwargs)

        self.trainer = pl.Trainer(**trainer_args)

    def _format_prediction_frame(self, prediction_result, metadata):
        """Format a raw model prediction with dataset metadata."""
        formatted_result = utilities.format_geometry(prediction_result)
        if formatted_result is None:
            return pd.DataFrame()

        window_bounds = metadata.get("window_bounds")
        if window_bounds is not None:
            formatted_result["window_xmin"] = window_bounds[0]
            formatted_result["window_ymin"] = window_bounds[1]

        formatted_result["image_path"] = metadata.get("image_path")
        return formatted_result.reset_index(drop=True)

    def _gather_prediction_frames(self, frames):
        """Gather prediction frames across ranks."""
        if not frames:
            return distributed.gather_dataframe(pd.DataFrame())

        local_frames = [frame for frame in frames if not frame.empty]
        if not local_frames:
            # Keep the column schema of the (empty) local results.
            return distributed.gather_dataframe(frames[0].iloc[0:0].copy())

        local_results = pd.concat(local_frames, ignore_index=True)
        return distributed.gather_dataframe(local_results)

    def _predict_dataset_frames(self, ds):
        """Run ``trainer.predict`` over ``ds`` and flatten the batched output."""
        dataloader = self.predict_dataloader(ds)
        batched_results = self.trainer.predict(self, dataloader)
        return predict._flatten_prediction_batches_(batched_results)

    def on_fit_start(self):
        """Fail fast when no training data has been configured."""
        if self.config.train.csv_file is None and self.existing_train_dataloader is None:
            raise AttributeError(
                "Cannot train without a train annotations file "
                "or existing_train_dataloader. Please set "
                "'config.train.csv_file' or pass "
                "existing_train_dataloader before "
                "calling deepforest.create_trainer()"
            )

    def on_save_checkpoint(self, checkpoint):
        """Store the current config and label mappings in the checkpoint."""
        # Update hparams in case they've changed since init
        checkpoint["hyper_parameters"]["config"] = OmegaConf.to_container(
            self.config, resolve=True
        )
        checkpoint["label_dict"] = self.label_dict
        checkpoint["numeric_to_label_dict"] = self.numeric_to_label_dict

        # Convert any remaining OmegaConf objects to plain containers.
        for key in checkpoint:
            if isinstance(checkpoint[key], DictConfig):
                checkpoint[key] = OmegaConf.to_container(checkpoint[key], resolve=True)

    def on_load_checkpoint(self, checkpoint):
        """Restore label mappings and backfill config keys missing from older
        checkpoints."""
        try:
            self.label_dict = checkpoint["label_dict"]
            self.numeric_to_label_dict = checkpoint["numeric_to_label_dict"]
        except KeyError:
            print(
                "No label_dict found in checkpoint, using default label_dict, "
                "please use deepforest.set_labels() to set the label_dict after loading the checkpoint."
            )

        # Pre 2.0 compatibility, the score_threshold used to be stored under retinanet.score_thresh
        try:
            self.config.score_thresh = self.config.retinanet.score_thresh
        except AttributeError:
            pass

        # Fill config keys absent from older checkpoints with the defaults;
        # the default config is loaded at most once.
        default_config = None
        for section_name, key in (
            ("validation", "lr_plateau_target"),
            ("train", "augmentations"),
            ("validation", "augmentations"),
        ):
            section = getattr(self.config, section_name)
            if hasattr(section, key):
                continue
            if default_config is None:
                default_config = utilities.load_config()
            setattr(section, key, getattr(getattr(default_config, section_name), key))

    def save_model(self, path):
        """Save the trainer checkpoint to a user defined path.

        Args:
            path: the path where the model checkpoint is written
        """
        self.trainer.save_checkpoint(path)

    def load_dataset(
        self,
        csv_file,
        root_dir=None,
        shuffle=True,
        transforms=None,
        augmentations=None,
        preload_images=False,
        validate_coordinates=True,
        batch_size=1,
    ):
        """Create a dataloader for inference or training.

        Csv file format is .csv file with the columns "image_path",
        "xmin","ymin","xmax","ymax" for the image name and bounding box
        position. Image_path is the relative filename, not absolute path,
        which is in the root_dir directory. One bounding box per line.

        Args:
            csv_file: path to csv file
            root_dir: directory of images (passed through to the dataset)
            shuffle: whether the dataloader shuffles samples
            transforms: Albumentations transforms
            augmentations: augmentation configuration (str, list, or dict)
            preload_images: if True, preload the images into memory
            validate_coordinates: if True, check annotation coordinates fall within image bounds
            batch_size: batch size

        Returns:
            torch.utils.data.DataLoader wrapping the task's dataset

        Raises:
            ValueError: If the model task is unknown or the dataset is empty.
        """
        dataset_classes = {
            "box": training.BoxDataset,
            "point": training.PointDataset,
            "polygon": training.PolygonDataset,
        }
        dataset_class = dataset_classes.get(self.model.task)
        if dataset_class is None:
            raise ValueError(
                f"Invalid task type: {self.model.task}, expected 'box', 'point' or 'polygon'"
            )

        ds = dataset_class(
            csv_file=csv_file,
            root_dir=root_dir,
            transforms=transforms,
            label_dict=self.label_dict,
            augmentations=augmentations,
            preload_images=preload_images,
            validate_coordinates=validate_coordinates,
        )
        if len(ds) == 0:
            raise ValueError(
                f"Dataset from {csv_file} is empty. Check CSV for valid entries and columns."
            )

        return torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=ds.collate_fn,
            num_workers=self.config.workers,
        )

    def train_dataloader(self):
        """Train loader using the configurations.

        Returns:
            loader: ``existing_train_dataloader`` if one was supplied,
            otherwise a dataloader built from ``config.train``.
        """
        # Compare with None: truthiness would call DataLoader.__len__, which
        # fails for iterable datasets and skips zero-length loaders.
        if self.existing_train_dataloader is not None:
            return self.existing_train_dataloader

        return self.load_dataset(
            csv_file=self.config.train.csv_file,
            root_dir=self.config.train.root_dir,
            augmentations=self.config.train.augmentations,
            preload_images=self.config.train.preload_images,
            validate_coordinates=self.config.train.validate_coordinates,
            shuffle=True,
            transforms=self.transforms,
            batch_size=self.config.batch_size,
        )

    def val_dataloader(self):
        """Create a val data loader only if specified in config.

        Returns:
            a dataloader or an empty iterable.
        """
        if self.existing_val_dataloader is not None:
            return self.existing_val_dataloader

        # The preferred route for skipping validation is now (pl-2.0) an empty list,
        # see https://github.com/Lightning-AI/lightning/issues/17154
        if self.config.validation.csv_file is None:
            return []

        return self.load_dataset(
            csv_file=self.config.validation.csv_file,
            root_dir=self.config.validation.root_dir,
            augmentations=self.config.validation.augmentations,
            shuffle=False,
            preload_images=self.config.validation.preload_images,
            validate_coordinates=self.config.validation.validate_coordinates,
            batch_size=self.config.batch_size,
        )

    def predict_dataloader(self, ds, batch_size=None):
        """Create a PyTorch dataloader for prediction.

        Args:
            ds (torchvision.datasets.Dataset): A torchvision dataset to be wrapped
                into a dataloader using config args.
            batch_size (int, optional): Defaults to ``config.batch_size``.

        Returns:
            torch.utils.data.DataLoader: A dataloader object that can be used for prediction.
        """
        if batch_size is None:
            batch_size = self.config.batch_size

        sampler = None
        # With fewer samples than ranks, pin each rank to at most its own
        # sample index instead of relying on default distributed sampling.
        if (
            self.config.use_distributed_sampler
            and distributed.is_distributed()
            and len(ds) < distributed.get_world_size()
        ):
            rank = distributed.get_rank()
            local_indices = [rank] if rank < len(ds) else []
            sampler = distributed.FixedOrderSampler(local_indices)

        return torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=False,
            sampler=sampler,
            num_workers=self.config.workers,
            collate_fn=ds.collate_fn,
            pin_memory=self.config.predict.pin_memory,
        )

    def predict_image(self, image: np.ndarray | None = None, path: str | None = None):
        """Predict a single image with a deepforest model.

        Args:
            image: a float32 numpy array of a RGB with channels last format
            path: optional path to read image from disk instead of passing image arg

        Returns:
            result: A pandas dataframe of predictions (Default), or None if
            there were no predictions.

        Raises:
            TypeError: If no numpy image is available (neither ``image`` nor ``path``).
        """
        # Ensure we are in eval mode
        self.model.eval()

        if path:
            image = np.array(Image.open(path).convert("RGB")).astype("float32")

        # sanity checks on input images
        if not isinstance(image, np.ndarray):
            raise TypeError(
                f"Input image is of type {type(image)}, expected numpy, if reading "
                "from PIL, wrap in "
                "np.array(image).astype(float32)"
            )

        if image.dtype != "float32":
            warnings.warn(
                f"Image type is {image.dtype}, transforming to float32. "
                f"This assumes that the range of pixel values is 0-255, as "
                f"opposed to 0-1. To suppress this warning, transform image "
                f"(image.astype('float32'))",
                stacklevel=2,
            )
            image = image.astype("float32")

        result = predict._predict_image_(
            model=self.model,
            image=image,
            path=path,
            iou_threshold=self.config.nms_thresh,
            nms_distance_thresh=self.config.point.nms_distance_thresh,
        )

        # If there were no predictions, return None
        if result is None:
            return None

        result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x])

        if path is None:
            warnings.warn(
                "An image was passed directly to predict_image, the result.root_dir attribute "
                "will be None in the output dataframe, to use visualize.plot_results, "
                "please assign results.root_dir = <directory name>",
                stacklevel=2,
            )
        else:
            root_dir = os.path.dirname(path)
            result = utilities.read_file(result, root_dir=root_dir)

        return result

    def predict_file(
        self,
        csv_file,
        root_dir,
        crop_model=None,
    ):
        """Create a dataset and predict entire annotation file.

        CSV file format is .csv file with the columns "image_path",
        "xmin","ymin","xmax","ymax" for the image name and bounding box
        position. Image_path is the relative filename, not absolute path,
        which is in the root_dir directory. One bounding box per line.

        Args:
            csv_file: path to csv file
            root_dir: directory of images
            crop_model: a deepforest.model.CropModel object to predict on crops

        Returns:
            df: pandas dataframe with bounding boxes, label and scores for each image in the csv file
        """
        ds = prediction.FromCSVFile(
            csv_file=csv_file, root_dir=root_dir, return_metadata=True
        )
        dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size)
        results = predict._dataloader_wrapper_(
            model=self,
            crop_model=crop_model,
            trainer=self.trainer,
            dataloader=dataloader,
            root_dir=root_dir,
        )
        results.root_dir = root_dir
        return results

    def predict_tile(
        self,
        path=None,
        image=None,
        patch_size=400,
        patch_overlap=0.05,
        iou_threshold=0.15,
        dataloader_strategy="single",
        crop_model=None,
    ):
        """For images too large to input into the model, predict_tile cuts the
        image into overlapping windows, predicts trees on each window and
        reassembles them into a single array.

        Args:
            path: Path or list of paths to images on disk. If a single string is
                provided, it will be converted to a list.
            image (array): Numpy image array in BGR channel order following openCV convention.
                Not possible in combination with dataloader_strategy='batch'.
            patch_size: patch size for each window
            patch_overlap: patch overlap among windows
            iou_threshold: Minimum iou overlap among predictions between windows to be suppressed
            dataloader_strategy: "single", "batch", or "window".
                - "single" loads the entire image into memory and passes individual windows
                  to GPU and cannot be parallelized.
                - "batch" loads the entire image into GPU memory and creates views of an image
                  as batch, requires the entire tile to fit into GPU memory. CPU
                  parallelization is possible for loading images.
                - "window" loads only the desired window of the image from the raster dataset.
                  Most memory efficient option, but cannot parallelize across windows.
            crop_model: a deepforest.model.CropModel object to predict on crops.
                Requires images given via ``path``.

        Returns:
            Predictions converted by ``utilities.__pandas_to_geodataframe__``
            (with a ``root_dir`` attribute), or None if nothing was predicted.

        Raises:
            ValueError: For an unknown strategy, missing path/image input, or
                ``config.workers > 0`` with the "window" strategy.
        """
        self.model.eval()
        self.model.nms_thresh = self.config.nms_thresh

        # Validate arguments before doing any work.
        if dataloader_strategy not in ("single", "batch", "window"):
            raise ValueError(f"Invalid dataloader_strategy: {dataloader_strategy}")
        if dataloader_strategy == "single" and path is None and image is None:
            raise ValueError(
                "Either path or image must be provided for single tile prediction"
            )
        if dataloader_strategy == "batch" and path is None:
            raise ValueError(
                "path argument must be provided when using dataloader_strategy='batch'"
            )
        if dataloader_strategy == "window":
            if path is None:
                raise ValueError(
                    "path argument must be provided when using dataloader_strategy='window'"
                )
            # Check for workers config when using out of memory dataset
            if self.config.workers > 0:
                raise ValueError(
                    "workers must be 0 when using out-of-memory dataset "
                    "(dataloader_strategy='window'). Set config['workers']=0 and recreate "
                    "trainer self.create_trainer()."
                )

        # Convert single path to list for consistent handling
        if isinstance(path, str):
            paths = [path]
        elif path is None:
            paths = [None]
        else:
            paths = path

        frames = []
        if dataloader_strategy == "batch":
            self.original_batch_structure.clear()
            ds = prediction.MultiImage(
                paths=paths,
                patch_overlap=patch_overlap,
                patch_size=patch_size,
                return_metadata=True,
            )
            frames.append(self._predict_dataset_frames(ds))
        else:
            for image_path in paths:
                if dataloader_strategy == "single":
                    ds = prediction.SingleImage(
                        path=image_path,
                        image=image,
                        patch_overlap=patch_overlap,
                        patch_size=patch_size,
                        return_metadata=True,
                    )
                else:
                    ds = prediction.TiledRaster(
                        path=image_path,
                        patch_overlap=patch_overlap,
                        patch_size=patch_size,
                        return_metadata=True,
                    )
                frames.append(self._predict_dataset_frames(ds))

        results = self._gather_prediction_frames(frames) if frames else pd.DataFrame()

        if results.empty:
            warnings.warn("No predictions made, returning None", stacklevel=2)
            return None

        # Perform mosaic for each image_path, or all if image_path is None
        nms_distance_thresh = self.config.point.nms_distance_thresh
        if results["image_path"].isnull().all():
            mosaic_frames = [
                predict.mosaic(
                    results,
                    iou_threshold=iou_threshold,
                    nms_distance_thresh=nms_distance_thresh,
                )
            ]
        else:
            mosaic_frames = []
            for image_path in results["image_path"].unique():
                image_mosaic = predict.mosaic(
                    results[results["image_path"] == image_path],
                    iou_threshold=iou_threshold,
                    nms_distance_thresh=nms_distance_thresh,
                )
                image_mosaic["image_path"] = image_path
                mosaic_frames.append(image_mosaic)

        mosaic_results = pd.concat(mosaic_frames)
        mosaic_results["label"] = mosaic_results.label.apply(
            lambda x: self.numeric_to_label_dict.get(x, x)
        )

        if paths[0] is not None:
            root_dir = os.path.dirname(paths[0])
        else:
            print(
                "No image path provided, root_dir of the output results dataframe will be None, since either "
                "images were directly provided or there were multiple image paths"
            )
            root_dir = None

        if crop_model is not None:
            cropmodel_frames = []
            for image_path in paths:
                # Crops are read relative to the image's directory, so an
                # in-memory image (no path) cannot be used here.
                if image_path is None:
                    raise ValueError(
                        "crop_model requires images given via the path argument; "
                        "it cannot be combined with an in-memory image."
                    )
                image_result = mosaic_results[
                    mosaic_results.image_path == os.path.basename(image_path)
                ]
                if image_result.empty:
                    continue
                image_result.root_dir = os.path.dirname(image_path)
                cropmodel_frames.append(
                    predict._crop_models_wrapper_(crop_model, self.trainer, image_result)
                )
            cropmodel_results = pd.concat(cropmodel_frames)
        else:
            cropmodel_results = mosaic_results

        formatted_results = utilities.__pandas_to_geodataframe__(cropmodel_results)
        formatted_results.root_dir = root_dir

        return formatted_results

    def training_step(self, batch, batch_idx):
        """Train on a loaded dataset, logging each loss term and their sum."""
        # Confirm model is in train mode
        self.model.train()

        images, targets, _image_names = batch
        loss_dict = self.model.forward(images, targets)

        # sum of regression and classification loss
        losses = sum(loss_dict.values())

        sync_dist = distributed.should_sync(self.trainer)
        batch_size = len(images)
        for key, value in loss_dict.items():
            self.log(
                f"train_{key}",
                value.detach(),
                on_epoch=True,
                batch_size=batch_size,
                sync_dist=sync_dist,
            )
        self.log(
            "train_loss",
            losses.detach(),
            on_epoch=True,
            batch_size=batch_size,
            sync_dist=sync_dist,
        )

        return losses

    def validation_step(self, batch, batch_idx):
        """Evaluate a batch: log losses and update the task's metrics."""
        images, targets, image_names = batch

        # Set model to train mode to return loss, but disable optimization.
        # Torchvision does not return loss in eval mode.
        self.model.train()
        with torch.no_grad():
            loss_dict = self.model.forward(images, targets)

        # sum of regression and classification loss
        losses = sum(loss_dict.values())

        sync_dist = distributed.should_sync(self.trainer)
        batch_size = len(images)
        for key, value in loss_dict.items():
            self.log(
                f"val_{key}",
                value.detach(),
                on_epoch=True,
                batch_size=batch_size,
                sync_dist=sync_dist,
            )
        self.log(
            "val_loss",
            losses.detach(),
            on_epoch=True,
            batch_size=batch_size,
            sync_dist=sync_dist,
        )

        # In eval mode, return predictions to calculate prediction metrics
        self.model.eval()
        with torch.no_grad():
            preds = self.model.forward(images, targets)

        # Compute precision, recall and empty frame metrics.
        if self.model.task in ("box", "point"):
            self.precision_recall_metric.update(preds, targets, image_names)

        if self.model.task == "box":
            # Filter out empty frames (no boxes, or all-zero boxes) for IoU/mAP.
            non_empty_pred = []
            non_empty_target = []
            for pred, target in zip(preds, targets, strict=True):
                if not (target["boxes"].numel() == 0 or torch.all(target["boxes"] == 0)):
                    non_empty_pred.append(pred)
                    non_empty_target.append(target)
            self.iou_metric.update(non_empty_pred, non_empty_target)
            self.mAP_metric.update(non_empty_pred, non_empty_target)

        elif self.model.task == "point":
            device = targets[0]["points"].device if targets else torch.device("cpu")
            pred_counts = torch.tensor(
                [float(len(p["points"])) for p in preds],
                dtype=torch.float32,
                device=device,
            )
            true_counts = torch.tensor(
                [float(len(t["points"])) for t in targets],
                dtype=torch.float32,
                device=device,
            )
            self.mae_metric.update(pred_counts, true_counts)

        elif self.model.task == "polygon":
            non_empty_pred = []
            non_empty_target = []
            for pred, target in zip(preds, targets, strict=True):
                if "masks" in pred:
                    masks = pred["masks"].squeeze(1)
                    # Binarize soft masks before handing them to the metric.
                    if masks.dtype.is_floating_point:
                        masks = masks > 0.5
                    pred["masks"] = masks.to(torch.uint8)
                if "masks" in target:
                    target["masks"] = target["masks"].to(torch.uint8)
                if len(target["labels"]) > 0:
                    non_empty_pred.append(pred)
                    non_empty_target.append(target)
            self.mAP_metric.update(non_empty_pred, non_empty_target)

        # Log the predictions if you want to use them for evaluation logs
        for i, result in enumerate(preds):
            formatted_result = utilities.format_geometry(result)
            if formatted_result is not None:
                formatted_result["image_path"] = image_names[i]
                self.predictions.append(formatted_result)

        return losses

    def on_validation_epoch_start(self):
        """Start each validation epoch with an empty prediction list."""
        self.predictions = []

    def _compute_map_metrics(self) -> dict:
        """Compute mAP, dropping the non-scalar ``classes`` entry."""
        # Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning
        output = self.mAP_metric.compute()
        return {key: value for key, value in output.items() if key != "classes"}

    def _compute_epoch_metrics(self) -> dict:
        """Compute metrics and return a Lightning-loggable dictionary.

        This function is called automatically at the end of validation.
        """
        metrics = {}
        task = self.model.task
        if task == "box":
            # IoU and mAP
            if len(self.iou_metric.groundtruth_labels) > 0:
                metrics.update(self.iou_metric.compute())
            metrics.update(self._compute_map_metrics())
            metrics.update(self.precision_recall_metric.compute())
        elif task == "point":
            metrics["val_mae"] = self.mae_metric.compute()
            metrics.update(self.precision_recall_metric.compute())
        elif task == "polygon":
            # The segmentation mAP is updated every validation step; report it.
            metrics.update(self._compute_map_metrics())

        return metrics

    def _prepare_metrics_for_sync(self, metrics: dict) -> dict:
        """Move scalar metrics onto the module device for NCCL reduction."""
        synced_metrics = {}
        for key, value in metrics.items():
            if isinstance(value, torch.Tensor):
                synced_metrics[key] = value.to(self.device)
            elif isinstance(value, Number):
                synced_metrics[key] = torch.tensor(value, device=self.device)
            else:
                synced_metrics[key] = value
        return synced_metrics

    def _reset_validation_metrics(self):
        """Reset the task's metric objects.

        Lightning only resets metrics that are logged directly, so these must
        be reset manually.
        """
        task = self.model.task
        if task == "box":
            self.iou_metric.reset()
            self.mAP_metric.reset()
            self.precision_recall_metric.reset()
        elif task == "point":
            self.mae_metric.reset()
            self.precision_recall_metric.reset()
        elif task == "polygon":
            self.mAP_metric.reset()

    def on_validation_epoch_end(self):
        """Compute metrics and predictions at the end of the validation epoch."""
        if self.trainer.sanity_checking:
            # Discard state accumulated on the sanity-check batches so it does
            # not leak into the first real validation epoch's metrics.
            self._reset_validation_metrics()
            return

        gathered_predictions = self._gather_prediction_frames(self.predictions)
        self.predictions = (
            [gathered_predictions] if not gathered_predictions.empty else []
        )

        # Log epoch metrics
        if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0:
            metrics = self._compute_epoch_metrics()
            should_sync = distributed.should_sync(self.trainer)
            if should_sync:
                metrics = self._prepare_metrics_for_sync(metrics)
            self.log_dict(metrics, sync_dist=should_sync)

        self._reset_validation_metrics()

    def predict_step(self, batch, batch_idx):
        """Predict a batch of images with the deepforest model.

        Args:
            batch: Either a batch of images with shape (B, C, H, W), or a dict
                with "images" and optional "metadata" / "batch_indices" keys
                (as produced by the prediction datasets used by predict_tile).
            batch_idx (int): The index of the batch.

        Returns:
            The raw model predictions when no metadata is given, otherwise a
            list with one formatted prediction dataframe per image.
        """
        if isinstance(batch, dict):
            images = batch["images"]
            metadata = batch.get("metadata")
            batch_indices = batch.get("batch_indices")
            if batch_indices is not None:
                self.original_batch_structure.append(batch_indices)
        else:
            images = batch
            metadata = None

        self.model.eval()
        with torch.no_grad():
            preds = self.model.forward(images)

        if metadata is None:
            return preds

        return [
            self._format_prediction_frame(prediction_result, sample_metadata)
            for prediction_result, sample_metadata in zip(preds, metadata, strict=True)
        ]

    def predict_batch(self, images, preprocess_fn=None):
        """Predict a batch of images with the deepforest model.

        Args:
            images (torch.Tensor or np.ndarray): A batch of images with shape (B, C, H, W).
            preprocess_fn (callable, optional): A function to preprocess images before prediction.
                If None, assumes images are preprocessed.

        Returns:
            List[pd.DataFrame]: A list of dataframes with predictions for each
            image; images without predictions are skipped.
        """
        self.model.eval()

        # convert to tensor if input is array
        if isinstance(images, np.ndarray):
            images = torch.tensor(images, device=self.device)

        # apply preprocessing if available
        if preprocess_fn:
            images = preprocess_fn(images)

        # using Pytorch Lightning's predict_step
        with torch.no_grad():
            predictions = self.predict_step(images, 0)

        # convert predictions to dataframes
        results = []
        for pred in predictions:
            # Point models return "points" rather than "boxes".
            geometry_key = "boxes" if "boxes" in pred else "points"
            if len(pred[geometry_key]) == 0:
                continue
            geom_type = utilities.determine_geometry_type(pred)
            results.append(utilities.format_geometry(pred, geom_type=geom_type))

        return results

    def configure_optimizers(self):
        """Build the optimizer and learning-rate scheduler from ``config.train``.

        Returns:
            A Lightning dict with ``optimizer``, ``lr_scheduler`` and
            ``monitor`` (``config.validation.lr_plateau_target``) when
            validation data is configured, otherwise the optimizer alone.

        Raises:
            ValueError: If the optimizer or scheduler type is unknown.
        """
        opt_cfg = self.config.train.optimizer
        lr = self.config.train.lr

        if opt_cfg.type == "Adam":
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=lr,
                betas=tuple(opt_cfg.betas),
                weight_decay=opt_cfg.weight_decay,
            )
        elif opt_cfg.type == "AdamW":
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=lr,
                betas=tuple(opt_cfg.betas),
                weight_decay=opt_cfg.weight_decay,
            )
        elif opt_cfg.type == "SGD":
            optimizer = optim.SGD(
                self.model.parameters(),
                lr=lr,
                momentum=opt_cfg.momentum,
                weight_decay=opt_cfg.weight_decay,
            )
        else:
            raise ValueError(
                f"Unknown optimizer type '{opt_cfg.type}'. Choose from: SGD, Adam, AdamW."
            )

        scheduler_config = self.config.train.scheduler
        scheduler_type = scheduler_config.type
        params = scheduler_config.params

        # Assume the lambda is a function of epoch.
        # NOTE(security): params.lr_lambda is an arbitrary Python expression
        # from the config, run through eval(); only load trusted configs.
        def lr_lambda(epoch):
            return eval(params.lr_lambda)

        if scheduler_type is None or scheduler_type == "constantLR":
            scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=1.0, total_iters=0
            )
        elif scheduler_type == "cosine":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=params.T_max, eta_min=params.eta_min
            )
        elif scheduler_type == "lambdaLR":
            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        elif scheduler_type == "multiplicativeLR":
            scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
                optimizer, lr_lambda=lr_lambda
            )
        elif scheduler_type == "stepLR":
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=params.step_size, gamma=params.gamma
            )
        elif scheduler_type == "multistepLR":
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=params.milestones, gamma=params.gamma
            )
        elif scheduler_type == "exponentialLR":
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, gamma=params.gamma
            )
        elif scheduler_type in ("ReduceLROnPlateau", "reduceLROnPlateau"):
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                mode=params["mode"],
                factor=params["factor"],
                patience=params["patience"],
                threshold=params["threshold"],
                threshold_mode=params["threshold_mode"],
                cooldown=params["cooldown"],
                min_lr=params["min_lr"],
                eps=params["eps"],
            )
        else:
            raise ValueError(
                f"Unknown scheduler type '{scheduler_type}'. Choose from: "
                "constantLR, cosine, lambdaLR, multiplicativeLR, stepLR, multistepLR, "
                "exponentialLR, ReduceLROnPlateau."
            )

        # Monitor learning rate if val data is used. Uses the same validation
        # test as create_trainer, so an existing_val_dataloader counts too.
        has_val = (
            self.config.validation.csv_file is not None
            or self.existing_val_dataloader is not None
        )
        if has_val:
            return {
                "optimizer": optimizer,
                "lr_scheduler": scheduler,
                "monitor": self.config.validation.lr_plateau_target,
            }
        # NOTE(review): without validation data the scheduler built above is
        # not returned, so Lightning never steps it.
        return optimizer

    def evaluate(
        self,
        csv_file,
        iou_threshold=None,
        root_dir=None,
    ):
        """Compute intersection-over-union and precision/recall for a given
        iou_threshold.

        .. deprecated:: 2.0.0
            This method is deprecated. Users should use `trainer.validate()` instead
            to get evaluation statistics during training. This method will be removed
            in a future version.

        Note that this permanently sets ``config.validation.csv_file`` (and
        ``root_dir`` / ``iou_threshold`` when given) and recreates the trainer.

        Args:
            csv_file: location of a csv file with columns "image_path","xmin","ymin","xmax","ymax","label"
            iou_threshold: optional override for ``config.validation.iou_threshold``
            root_dir: optional override for ``config.validation.root_dir``

        Returns:
            dict: Results dictionary containing precision, recall and other metrics
        """
        warnings.warn(
            "deepforest.evaluate() is deprecated and will be removed in a future version. "
            "Please use trainer.validate() instead to get evaluation statistics during training.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Set input csv file to validation csv file
        self.config.validation.csv_file = csv_file
        if root_dir is not None:
            self.config.validation.root_dir = root_dir
        if iou_threshold is not None:
            self.config.validation.iou_threshold = iou_threshold
        self.config.validation.val_accuracy_interval = 1

        self.create_trainer()
        validation_results = self.trainer.validate(self)

        # Gather predictions from all ranks in multi-GPU settings.
        # NOTE(review): on_validation_epoch_end already gathers self.predictions
        # via distributed.gather_dataframe; if that returns the full frame on
        # every rank, this second gather duplicates rows — confirm.
        if self.trainer.world_size > 1:
            all_predictions = [None] * self.trainer.world_size
            torch.distributed.all_gather_object(all_predictions, self.predictions)
            self.predictions = [pred for preds in all_predictions for pred in preds]

        # Concat prediction dataframes and convert numeric labels to strings
        if len(self.predictions) > 0:
            self.predictions = pd.concat(self.predictions, ignore_index=True)
            if "label" in self.predictions.columns:
                self.predictions["label"] = self.predictions["label"].map(
                    lambda x: self.numeric_to_label_dict.get(int(x), x)
                    if pd.notna(x)
                    else x
                )
        else:
            self.predictions = pd.DataFrame()

        results = {}
        if isinstance(validation_results, list):
            if validation_results and isinstance(validation_results[0], dict):
                results.update(validation_results[0])
        elif isinstance(validation_results, dict):
            results.update(validation_results)

        results["predictions"] = self.predictions

        if self.model.task in ("box", "point"):
            results["results"] = self.precision_recall_metric.get_results()

        return results