# entry point for deepforest model
import importlib
import os
import warnings
from numbers import Number
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from torch import optim
from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision
from deepforest import distributed, predict, utilities
from deepforest.datasets import prediction, training
from deepforest.metrics import RecallPrecision
# Disable Pillow's maximum-pixel ("decompression bomb") safety check so that
# very large images can be opened without Pillow warning or raising.
Image.MAX_IMAGE_PIXELS = None
class deepforest(pl.LightningModule):
    """DeepForest model for tree crown detection in RGB images.

    Args:
        model: Optional model object. When None, ``create_model()`` builds or
            loads one from the configuration.
        transforms: Optional transforms handed to the training dataset.
        existing_train_dataloader: PyTorch dataloader for training data. Used
            instead of ``config.train.csv_file`` when provided.
        existing_val_dataloader: PyTorch dataloader for validation data. Used
            instead of ``config.validation.csv_file`` when provided.
        config: DeepForest configuration object (DictConfig), plain dict, or
            the name of a configuration.
        config_args: Dictionary of config overrides.
    """
def __init__(
    self,
    model=None,
    transforms=None,
    existing_train_dataloader=None,
    existing_val_dataloader=None,
    config: str | dict | DictConfig | None = None,
    config_args: dict | None = None,
):
    """Resolve the configuration, then create the model and a default trainer.

    Config resolution:
      * ``None``: default config, with ``config_args`` applied as overrides.
      * ``str``: the named config, with ``config_args`` applied as overrides.
      * ``dict`` (e.g. restored from a checkpoint): merged with
        ``config_args`` and loaded as overrides.
      * a config holding a ``config_args`` key (hub overrides): those stored
        overrides are loaded.
      * any other config object is used unchanged. A non-None
        ``config_args`` is then ignored with a warning.
    """
    super().__init__()

    if config is None:
        resolved_config = utilities.load_config(overrides=config_args)
    elif isinstance(config, str):
        resolved_config = utilities.load_config(
            config_name=config, overrides=config_args
        )
    elif isinstance(config, dict):
        merged_overrides = OmegaConf.merge(config, config_args or {})
        resolved_config = utilities.load_config(overrides=merged_overrides)
    elif "config_args" in config:
        resolved_config = utilities.load_config(overrides=config["config_args"])
    else:
        if config_args is not None:
            warnings.warn(
                f"Ignoring options as configuration object was provided: {config_args}",
                stacklevel=2,
            )
        resolved_config = config
    self.config = resolved_config

    # Release version id; flags whether a published release is being used.
    self.__release_version__ = None

    self.existing_train_dataloader = existing_train_dataloader
    self.existing_val_dataloader = existing_val_dataloader
    self.model = model
    self.original_batch_structure = []

    if self.model is None:
        self.create_model()

    # Every instance starts with a default trainer built from the config.
    self.create_trainer()

    # User-supplied transforms (None when not given).
    self.transforms = transforms

    self.save_hyperparameters(
        {"config": OmegaConf.to_container(self.config, resolve=True)}
    )
def setup_metrics(self):
    """Create the validation metric objects for ``self.model.task``.

    Nothing is created until validation data is configured, either through
    ``config.validation.csv_file`` or an existing validation dataloader.

    - "box": IoU, mAP and recall/precision metrics.
    - "point": count MAE and recall/precision metrics.
    - "polygon": mask (``segm``) mAP.
    """
    validation_cfg = self.config.validation
    if not validation_cfg.csv_file and self.existing_val_dataloader is None:
        return

    task = self.model.task
    if task == "box":
        iou_threshold = validation_cfg.iou_threshold
        self.iou_metric = IntersectionOverUnion(
            class_metrics=True, iou_threshold=iou_threshold
        )
        self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")
        self.precision_recall_metric = RecallPrecision(
            iou_threshold=iou_threshold,
            label_dict=self.label_dict,
            task=task,
        )
    elif task == "point":
        self.mae_metric = torchmetrics.MeanAbsoluteError()
        self.precision_recall_metric = RecallPrecision(
            distance_threshold=self.config.point.distance_threshold,
            label_dict=self.label_dict,
            task=task,
        )
    elif task == "polygon":
        self.mAP_metric = MeanAveragePrecision(
            iou_type="segm", backend="faster_coco_eval"
        )
def load_model(self, model_name=None, revision=None):
    """Load pretrained weights for a task such as tree crown detection.

    Weights are distributed via Hugging Face and identified by a repository
    ID of the form 'organization/repository'. For the models published by
    the DeepForest team, see:
    https://deepforest.readthedocs.io/en/stable/user_guide/02_prebuilt.html

    After loading, ``config.num_classes`` and the label mappings are updated
    from the loaded model.

    Args:
        model_name (str): Hugging Face repository ID. Defaults to
            ``config.model.name``.
        revision (str): Model version ('main', 'v1.0.0', etc.). Defaults to
            ``config.model.revision``.

    Returns:
        None
    """
    repo_id = self.config.model.name if model_name is None else model_name
    model_revision = self.config.model.revision if revision is None else revision

    architecture = importlib.import_module(
        f"deepforest.models.{self.config.architecture}"
    )
    builder = architecture.Model(config=self.config)
    self.model = builder.create_model(pretrained=repo_id, revision=model_revision)

    # Keep the config and label mapping consistent with the loaded weights.
    self.config.num_classes = self.model.num_classes
    self.set_labels(self.model.label_dict)
def set_labels(self, label_dict):
    """Install a new label mapping and its inverse.

    Sets ``self.label_dict`` (name -> id) and ``self.numeric_to_label_dict``
    (id -> name).

    Args:
        label_dict (dict): Mapping of class names to numeric IDs. It must
            have exactly ``config.num_classes`` entries with unique IDs.

    Raises:
        ValueError: if ``label_dict`` is None, has the wrong number of
            entries, or contains duplicate IDs.
    """
    if label_dict is None:
        raise ValueError(
            "Label dictionary not found. Check it was set in your config file or config_args."
        )

    expected_classes = self.config.num_classes
    if len(label_dict) != expected_classes:
        raise ValueError(
            f"label_dict {label_dict} does not match requested number of "
            f"classes {expected_classes}, please supply a label_dict argument "
            '{"label1":0, "label2":1, "label3":2 ... etc} '
            "for each label in the "
            "dataset"
        )

    label_ids = list(label_dict.values())
    if len(set(label_ids)) != len(label_ids):
        raise ValueError("Found duplicate label IDs in label_dict.")

    self.label_dict = label_dict
    self.numeric_to_label_dict = dict(zip(label_ids, label_dict.keys()))
def create_model(self, initialize_model=False):
    """Create or load the underlying model.

    If ``config.model.name`` is set and ``initialize_model`` is False, the
    named pretrained model is loaded with ``load_model()``. Otherwise a fresh
    model is built from ``deepforest.models.<config.architecture>``, a module
    exposing a ``Model`` class. Labels are then set from
    ``config.label_dict``. A model can also be supplied directly through the
    ``model`` argument of ``__init__``.

    Args:
        initialize_model (bool): Build a fresh architecture even when a
            pretrained model name is configured.

    Returns:
        None
    """
    if self.config.model.name is not None and not initialize_model:
        self.load_model()
        return

    architecture = importlib.import_module(
        f"deepforest.models.{self.config.architecture}"
    )
    self.model = architecture.Model(config=self.config).create_model()
    self.set_labels(self.config.label_dict)
def create_trainer(self, logger=None, callbacks=None, **kwargs):
    """Create a PyTorch Lightning trainer from the current config.

    Metrics are rebuilt first because the config may have changed since the
    module was created. When validation data is configured, validation is
    enabled with two sanity-check steps. Otherwise validation is disabled
    explicitly instead of using the trainer defaults.

    Args:
        logger: Optional Lightning logger. When None, a ``CSVLogger`` writing
            to ``config.log_root`` is used. Pass ``False`` to disable logging.
        callbacks: Optional list of callbacks. The list is copied, so the
            caller's list is never modified.
        **kwargs: Additional ``pl.Trainer`` arguments. These override values
            derived from the config.
    """
    self.setup_metrics()

    # Copy so that callbacks added below never leak into the caller's list
    # (and do not pile up when create_trainer is called repeatedly).
    callbacks = [] if callbacks is None else list(callbacks)

    # Default logger writes to log_root instead of ./lightning_logs.
    if logger is None:
        logger = CSVLogger(save_dir=self.config.log_root, name="")

    has_val = (
        self.config.validation.csv_file is not None
        or self.existing_val_dataloader is not None
    )
    if has_val:
        # LearningRateMonitor raises when the trainer has no logger, so only
        # add it when logging is enabled (False / [] disable logging).
        logging_enabled = logger is not False and logger != []
        if logging_enabled:
            callbacks.append(LearningRateMonitor(logging_interval="epoch"))
        limit_val_batches = 1.0
        num_sanity_val_steps = 2
    else:
        # Disable validation explicitly rather than relying on trainer defaults.
        limit_val_batches = 0
        num_sanity_val_steps = 0

    # Checkpointing is enabled only when the caller supplied a ModelCheckpoint.
    enable_checkpointing = any(
        type(callback).__qualname__ == "ModelCheckpoint" for callback in callbacks
    )

    if torch.cuda.is_available():
        torch.set_float32_matmul_precision(self.config.matmul_precision)

    trainer_args = {
        "logger": logger,
        "max_epochs": self.config.train.epochs,
        "enable_checkpointing": enable_checkpointing,
        "devices": self.config.devices,
        "accelerator": self.config.accelerator,
        "num_nodes": self.config.num_nodes,
        "strategy": self.config.strategy,
        "sync_batchnorm": self.config.sync_batchnorm,
        "use_distributed_sampler": self.config.use_distributed_sampler,
        "fast_dev_run": self.config.train.fast_dev_run,
        "callbacks": callbacks,
        "limit_val_batches": limit_val_batches,
        "num_sanity_val_steps": num_sanity_val_steps,
        "default_root_dir": self.config.log_root,
    }
    # Only forward precision when configured; otherwise keep Lightning's default.
    if self.config.precision is not None:
        trainer_args["precision"] = self.config.precision

    # kwargs take priority over config-derived values.
    trainer_args.update(kwargs)
    self.trainer = pl.Trainer(**trainer_args)
def _format_prediction_frame(self, prediction_result, metadata):
    """Turn one raw model prediction into a DataFrame annotated with metadata.

    Returns an empty DataFrame when ``utilities.format_geometry`` yields
    None. If ``metadata`` contains ``window_bounds``, its first two values are
    stored as ``window_xmin`` and ``window_ymin``. ``image_path`` is copied
    from ``metadata`` (None when absent), and the index is reset.
    """
    frame = utilities.format_geometry(prediction_result)
    if frame is None:
        return pd.DataFrame()

    bounds = metadata.get("window_bounds")
    if bounds is not None:
        frame["window_xmin"] = bounds[0]
        frame["window_ymin"] = bounds[1]

    frame["image_path"] = metadata.get("image_path")
    return frame.reset_index(drop=True)
def _gather_prediction_frames(self, frames):
    """Concatenate this rank's prediction frames and gather them across ranks.

    Empty frames are dropped before concatenation. If every frame is empty,
    a zero-row copy of the first frame is gathered so its columns survive.
    With no frames at all, an empty DataFrame is gathered.
    """
    if not frames:
        local_results = pd.DataFrame()
    else:
        populated = [frame for frame in frames if not frame.empty]
        if populated:
            local_results = pd.concat(populated, ignore_index=True)
        else:
            local_results = frames[0].iloc[0:0].copy()
    return distributed.gather_dataframe(local_results)
def on_fit_start(self):
    """Fail early when fitting starts without any source of training data.

    Raises:
        AttributeError: if neither ``config.train.csv_file`` nor
            ``existing_train_dataloader`` is set.
    """
    if self.config.train.csv_file is None and self.existing_train_dataloader is None:
        raise AttributeError(
            "Cannot train without a train annotations file "
            "or existing_train_dataloader. Please set "
            "'config.train.csv_file' or pass "
            "existing_train_dataloader before "
            "calling deepforest.create_trainer()"
        )
def on_save_checkpoint(self, checkpoint):
    """Write the current config and label mappings into ``checkpoint``.

    The config is serialized again because it may have changed since
    ``save_hyperparameters`` ran in ``__init__``. Any top-level DictConfig
    values are converted to plain containers.
    """
    checkpoint["hyper_parameters"]["config"] = OmegaConf.to_container(
        self.config, resolve=True
    )
    checkpoint["label_dict"] = self.label_dict
    checkpoint["numeric_to_label_dict"] = self.numeric_to_label_dict

    for key, value in list(checkpoint.items()):
        if isinstance(value, DictConfig):
            checkpoint[key] = OmegaConf.to_container(value, resolve=True)
def on_load_checkpoint(self, checkpoint):
    """Restore label mappings and upgrade configs saved by older versions.

    - Restores ``label_dict`` and ``numeric_to_label_dict`` when the
      checkpoint stores them. Otherwise a hint to call ``set_labels()`` is
      printed.
    - Pre-2.0 compatibility: the score threshold used to be stored under
      ``retinanet.score_thresh``. When present, it is copied to
      ``config.score_thresh``.
    - Config keys that older checkpoints may lack are back-filled from the
      default config.
    """
    try:
        self.label_dict = checkpoint["label_dict"]
        self.numeric_to_label_dict = checkpoint["numeric_to_label_dict"]
    except KeyError:
        print(
            "No label_dict found in checkpoint, using default label_dict, "
            "please use deepforest.set_labels() to set the label_dict after loading the checkpoint."
        )

    try:
        self.config.score_thresh = self.config.retinanet.score_thresh
    except AttributeError:
        pass

    # Back-fill missing keys. The default config is parsed at most once, and
    # only if something is actually missing (previously up to three times).
    default_config = None
    for section_name, key in (
        ("validation", "lr_plateau_target"),
        ("train", "augmentations"),
        ("validation", "augmentations"),
    ):
        section = getattr(self.config, section_name)
        if hasattr(section, key):
            continue
        if default_config is None:
            default_config = utilities.load_config()
        setattr(section, key, getattr(getattr(default_config, section_name), key))
def save_model(self, path):
    """Save a checkpoint of the current trainer state so it can be reloaded later.

    Args:
        path: Destination file path for the checkpoint.
    """
    trainer = self.trainer
    trainer.save_checkpoint(path)
def load_dataset(
    self,
    csv_file,
    root_dir=None,
    shuffle=True,
    transforms=None,
    augmentations=None,
    preload_images=False,
    validate_coordinates=True,
    batch_size=1,
):
    """Build a dataset for the model's task and wrap it in a DataLoader.

    The dataset class is picked from ``self.model.task``:

    - "box": ``training.BoxDataset``
    - "point": ``training.PointDataset``
    - "polygon": ``training.PolygonDataset``

    For box data the CSV has the columns "image_path", "xmin", "ymin",
    "xmax", "ymax", one annotation per row. ``image_path`` is a file name
    relative to ``root_dir``, not an absolute path.

    Args:
        csv_file: Path to the annotation CSV.
        root_dir: Directory containing the images, passed to the dataset
            unchanged.
        shuffle: Whether the DataLoader shuffles samples.
        transforms: Albumentations transforms.
        augmentations: Augmentation configuration (str, list, or dict).
        preload_images: If True, preload the images into memory.
        validate_coordinates: If True, check that annotation coordinates fall
            within image bounds.
        batch_size: DataLoader batch size.

    Returns:
        A ``torch.utils.data.DataLoader`` that uses the dataset's
        ``collate_fn`` and ``config.workers`` workers.

    Raises:
        ValueError: if the task is unknown or the dataset is empty.
    """
    dataset_class_names = {
        "box": "BoxDataset",
        "point": "PointDataset",
        "polygon": "PolygonDataset",
    }
    task = self.model.task
    if task not in dataset_class_names:
        raise ValueError(
            f"Invalid task type: {task}, expected 'box', 'point' or 'polygon'"
        )

    dataset_class = getattr(training, dataset_class_names[task])
    ds = dataset_class(
        csv_file=csv_file,
        root_dir=root_dir,
        transforms=transforms,
        label_dict=self.label_dict,
        augmentations=augmentations,
        preload_images=preload_images,
        validate_coordinates=validate_coordinates,
    )

    if len(ds) == 0:
        raise ValueError(
            f"Dataset from {csv_file} is empty. Check CSV for valid entries and columns."
        )

    return torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=ds.collate_fn,
        num_workers=self.config.workers,
    )
def train_dataloader(self):
    """Return the training dataloader.

    A user-supplied ``existing_train_dataloader`` takes precedence.
    Otherwise a shuffled loader is built from ``config.train`` using
    ``self.transforms`` and ``config.batch_size``.

    Returns:
        The training dataloader.
    """
    # Compare against None instead of relying on truthiness: bool() on a
    # DataLoader calls len(), which is 0 for an empty loader and raises
    # TypeError for IterableDataset-backed loaders that have no __len__.
    if self.existing_train_dataloader is not None:
        return self.existing_train_dataloader

    train_cfg = self.config.train
    return self.load_dataset(
        csv_file=train_cfg.csv_file,
        root_dir=train_cfg.root_dir,
        augmentations=train_cfg.augmentations,
        preload_images=train_cfg.preload_images,
        validate_coordinates=train_cfg.validate_coordinates,
        shuffle=True,
        transforms=self.transforms,
        batch_size=self.config.batch_size,
    )
def val_dataloader(self):
    """Return the validation dataloader, or an empty list when none is configured.

    A user-supplied ``existing_val_dataloader`` takes precedence. Otherwise
    an unshuffled loader is built from ``config.validation`` when its
    ``csv_file`` is set.

    Returns:
        A dataloader, or an empty list (which Lightning 2.x treats as "skip
        validation"; see https://github.com/Lightning-AI/lightning/issues/17154).
    """
    # `is not None` rather than truthiness: bool() on a DataLoader calls
    # len(), which raises for IterableDataset-backed loaders.
    if self.existing_val_dataloader is not None:
        return self.existing_val_dataloader

    val_cfg = self.config.validation
    if val_cfg.csv_file is None:
        return []

    return self.load_dataset(
        csv_file=val_cfg.csv_file,
        root_dir=val_cfg.root_dir,
        augmentations=val_cfg.augmentations,
        shuffle=False,
        preload_images=val_cfg.preload_images,
        validate_coordinates=val_cfg.validate_coordinates,
        batch_size=self.config.batch_size,
    )
def predict_dataloader(self, ds, batch_size=None):
    """Wrap a prediction dataset in an unshuffled PyTorch DataLoader.

    In a distributed run with fewer samples than ranks (and
    ``config.use_distributed_sampler`` enabled), each rank receives at most
    one sample: the index equal to its rank, if that index exists.

    Args:
        ds (torch.utils.data.Dataset): Dataset exposing ``collate_fn``.
        batch_size (int, optional): Defaults to ``config.batch_size``.

    Returns:
        torch.utils.data.DataLoader: A dataloader that can be used for prediction.
    """
    effective_batch_size = self.config.batch_size if batch_size is None else batch_size

    sampler = None
    fewer_samples_than_ranks = (
        self.config.use_distributed_sampler
        and distributed.is_distributed()
        and len(ds) < distributed.get_world_size()
    )
    if fewer_samples_than_ranks:
        rank = distributed.get_rank()
        sampler = distributed.FixedOrderSampler([rank] if rank < len(ds) else [])

    return torch.utils.data.DataLoader(
        ds,
        batch_size=effective_batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=self.config.workers,
        collate_fn=ds.collate_fn,
        pin_memory=self.config.predict.pin_memory,
    )
[docs] def predict_image(self, image: np.ndarray | None = None, path: str | None = None):
"""Predict a single image with a deepforest model.
Args:
image: a float32 numpy array of a RGB with channels last format
path: optional path to read image from disk instead of passing image arg
Returns:
result: A pandas dataframe of predictions (Default)
"""
# Ensure we are in eval mode
self.model.eval()
if path:
image = np.array(Image.open(path).convert("RGB")).astype("float32")
# sanity checks on input images
if not isinstance(image, np.ndarray):
raise TypeError(
f"Input image is of type {type(image)}, expected numpy, if reading "
"from PIL, wrap in "
"np.array(image).astype(float32)"
)
if image.dtype != "float32":
warnings.warn(
f"Image type is {image.dtype}, transforming to float32. "
f"This assumes that the range of pixel values is 0-255, as "
f"opposed to 0-1.To suppress this warning, transform image "
f"(image.astype('float32')",
stacklevel=2,
)
image = image.astype("float32")
result = predict._predict_image_(
model=self.model,
image=image,
path=path,
iou_threshold=self.config.nms_thresh,
nms_distance_thresh=self.config.point.nms_distance_thresh,
)
# If there were no predictions, return None
if result is None:
return None
else:
result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x])
if path is None:
warnings.warn(
"An image was passed directly to predict_image, the result.root_dir attribute "
"will be None in the output dataframe, to use visualize.plot_results, "
"please assign results.root_dir = <directory name>",
stacklevel=2,
)
else:
root_dir = os.path.dirname(path)
result = utilities.read_file(result, root_dir=root_dir)
return result
def predict_file(
    self,
    csv_file,
    root_dir,
    crop_model=None,
):
    """Predict every image referenced by a CSV file.

    The CSV is read by ``prediction.FromCSVFile``. Its ``image_path`` values
    are file names relative to ``root_dir``, not absolute paths.

    Args:
        csv_file: Path to the CSV file.
        root_dir: Directory containing the images.
        crop_model: Optional deepforest CropModel applied to each predicted crop.

    Returns:
        A DataFrame of predictions (boxes, labels and scores) for the images
        in the CSV, with its ``root_dir`` attribute set to ``root_dir``.
    """
    dataset = prediction.FromCSVFile(
        csv_file=csv_file, root_dir=root_dir, return_metadata=True
    )
    loader = self.predict_dataloader(dataset, batch_size=self.config.batch_size)
    predictions = predict._dataloader_wrapper_(
        model=self,
        crop_model=crop_model,
        trainer=self.trainer,
        dataloader=loader,
        root_dir=root_dir,
    )
    predictions.root_dir = root_dir
    return predictions
def predict_tile(
    self,
    path=None,
    image=None,
    patch_size=400,
    patch_overlap=0.05,
    iou_threshold=0.15,
    dataloader_strategy="single",
    crop_model=None,
):
    """Predict on images too large for the model by tiling them into windows.

    The image is cut into overlapping windows and each window is predicted.
    The window predictions are then merged per image with ``predict.mosaic``.
    Overlapping detections are suppressed using ``iou_threshold`` and
    ``config.point.nms_distance_thresh``.

    Args:
        path: Path or list of paths to images on disk. A single string is
            treated as a one-element list.
        image (array): In-memory image array. Only used with
            ``dataloader_strategy="single"``. NOTE(review): the previous docs
            stated BGR/OpenCV channel order, while ``predict_image`` expects
            RGB; confirm against ``prediction.SingleImage``.
        patch_size: Window size in pixels.
        patch_overlap: Fractional overlap between neighbouring windows.
        iou_threshold: Minimum IoU between predictions from different windows
            for one of them to be suppressed.
        dataloader_strategy: "single", "batch", or "window".
            - "single" loads the entire image into memory and passes
              individual windows to the GPU. It cannot be parallelized.
            - "batch" loads the entire image into GPU memory and creates views
              of it as a batch, so the whole tile must fit into GPU memory.
              CPU parallelization is possible for loading images.
            - "window" loads only the requested window from the raster
              dataset. This is the most memory-efficient option, but it cannot
              parallelize across windows and requires ``config.workers == 0``.
        crop_model: Optional deepforest CropModel applied to each predicted
            crop. Requires image paths on disk.

    Returns:
        A GeoDataFrame of predictions with a ``root_dir`` attribute, or None
        (with a warning) if nothing was predicted.

    Raises:
        ValueError: for missing inputs, an unknown ``dataloader_strategy``,
            ``config.workers > 0`` with the "window" strategy, or
            ``crop_model`` used without image paths.
    """
    self.model.eval()
    self.model.nms_thresh = self.config.nms_thresh

    # Validate inputs for the chosen strategy.
    if dataloader_strategy == "single" and path is None and image is None:
        raise ValueError(
            "Either path or image must be provided for single tile prediction"
        )
    if dataloader_strategy == "batch" and path is None:
        raise ValueError(
            "path argument must be provided when using dataloader_strategy='batch'"
        )
    if dataloader_strategy == "window" and path is None:
        raise ValueError(
            "path argument must be provided when using dataloader_strategy='window'"
        )

    # Normalize to a list of paths for consistent handling.
    if isinstance(path, str):
        paths = [path]
    elif path is None:
        paths = [None]
    else:
        paths = list(path)

    # Crops are read from disk per image path, so fail fast instead of
    # crashing after prediction on os.path.basename(None).
    if crop_model is not None and any(p is None for p in paths):
        raise ValueError(
            "crop_model requires image paths on disk; it cannot be used with "
            "an in-memory image."
        )

    if dataloader_strategy in ("single", "window"):
        frames = []
        for image_path in paths:
            if dataloader_strategy == "single":
                ds = prediction.SingleImage(
                    path=image_path,
                    image=image,
                    patch_overlap=patch_overlap,
                    patch_size=patch_size,
                    return_metadata=True,
                )
            else:
                # The out-of-memory dataset does not support worker processes.
                if self.config.workers > 0:
                    raise ValueError(
                        "workers must be 0 when using out-of-memory dataset "
                        "(dataloader_strategy='window'). Set config['workers']=0 and recreate "
                        "trainer self.create_trainer()."
                    )
                ds = prediction.TiledRaster(
                    path=image_path,
                    patch_overlap=patch_overlap,
                    patch_size=patch_size,
                    return_metadata=True,
                )
            batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
            frames.append(predict._flatten_prediction_batches_(batched_results))
    elif dataloader_strategy == "batch":
        self.original_batch_structure.clear()
        ds = prediction.MultiImage(
            paths=paths,
            patch_overlap=patch_overlap,
            patch_size=patch_size,
            return_metadata=True,
        )
        batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
        frames = [predict._flatten_prediction_batches_(batched_results)]
    else:
        raise ValueError(f"Invalid dataloader_strategy: {dataloader_strategy}")

    results = self._gather_prediction_frames(frames) if frames else pd.DataFrame()

    if results.empty:
        warnings.warn("No predictions made, returning None", stacklevel=2)
        return None

    # Mosaic per image_path, or everything at once when no paths are known.
    nms_distance_thresh = self.config.point.nms_distance_thresh
    mosaics = []
    if results["image_path"].isnull().all():
        mosaics.append(
            predict.mosaic(
                results,
                iou_threshold=iou_threshold,
                nms_distance_thresh=nms_distance_thresh,
            )
        )
    else:
        for image_path in results["image_path"].unique():
            image_mosaic = predict.mosaic(
                results[results["image_path"] == image_path],
                iou_threshold=iou_threshold,
                nms_distance_thresh=nms_distance_thresh,
            )
            image_mosaic["image_path"] = image_path
            mosaics.append(image_mosaic)
    mosaic_results = pd.concat(mosaics)

    mosaic_results["label"] = mosaic_results.label.apply(
        lambda x: self.numeric_to_label_dict.get(x, x)
    )

    if paths[0] is not None:
        root_dir = os.path.dirname(paths[0])
    else:
        print(
            "No image path provided, root_dir of the output results dataframe will be None, since either "
            "images were directly provided or there were multiple image paths"
        )
        root_dir = None

    if crop_model is not None:
        cropmodel_results = []
        for image_path in paths:
            image_result = mosaic_results[
                mosaic_results.image_path == os.path.basename(image_path)
            ]
            if image_result.empty:
                continue
            image_result.root_dir = os.path.dirname(image_path)
            cropmodel_results.append(
                predict._crop_models_wrapper_(crop_model, self.trainer, image_result)
            )
        cropmodel_results = pd.concat(cropmodel_results)
    else:
        cropmodel_results = mosaic_results

    formatted_results = utilities.__pandas_to_geodataframe__(cropmodel_results)
    formatted_results.root_dir = root_dir
    return formatted_results
def training_step(self, batch, batch_idx):
    """Train on one batch: log each loss component and their sum, return the sum."""
    # Confirm the wrapped model is in train mode.
    self.model.train()

    images, targets, _ = batch
    loss_dict = self.model.forward(images, targets)
    # Sum of regression and classification losses.
    total_loss = sum(loss_dict.values())

    entries = [(f"train_{name}", value) for name, value in loss_dict.items()]
    entries.append(("train_loss", total_loss))
    for log_name, value in entries:
        self.log(
            log_name,
            value.detach(),
            on_epoch=True,
            batch_size=len(images),
            sync_dist=distributed.should_sync(self.trainer),
        )

    return total_loss
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch.

    Losses are computed with the model in train mode under
    ``torch.no_grad()``, because torchvision detection models only return
    losses in train mode. Predictions are then computed in eval mode and fed
    to the task's metrics. Formatted predictions are appended to
    ``self.predictions``.

    Returns:
        The summed validation loss for the batch.
    """
    images, targets, image_names = batch

    self.model.train()
    with torch.no_grad():
        loss_dict = self.model.forward(images, targets)

    # Sum of regression and classification losses.
    losses = sum(loss_dict.values())

    for key, value in loss_dict.items():
        self.log(
            f"val_{key}",
            value.detach(),
            on_epoch=True,
            batch_size=len(images),
            sync_dist=distributed.should_sync(self.trainer),
        )
    self.log(
        "val_loss",
        losses.detach(),
        on_epoch=True,
        batch_size=len(images),
        sync_dist=distributed.should_sync(self.trainer),
    )

    # In eval mode the model returns predictions for the metrics.
    self.model.eval()
    with torch.no_grad():
        preds = self.model.forward(images, targets)

    task = self.model.task
    if task in ("box", "point"):
        self.precision_recall_metric.update(preds, targets, image_names)

    if task == "box":
        # Frames without boxes (or only all-zero placeholder boxes) are
        # excluded from IoU/mAP.
        non_empty_pred = []
        non_empty_target = []
        for pred, target in zip(preds, targets, strict=True):
            boxes = target["boxes"]
            if boxes.numel() == 0 or torch.all(boxes == 0):
                continue
            non_empty_pred.append(pred)
            non_empty_target.append(target)
        self.iou_metric.update(non_empty_pred, non_empty_target)
        self.mAP_metric.update(non_empty_pred, non_empty_target)
    elif task == "point":
        device = targets[0]["points"].device if targets else torch.device("cpu")
        pred_counts = torch.tensor(
            [float(len(p["points"])) for p in preds],
            dtype=torch.float32,
            device=device,
        )
        true_counts = torch.tensor(
            [float(len(t["points"])) for t in targets],
            dtype=torch.float32,
            device=device,
        )
        self.mae_metric.update(pred_counts, true_counts)
    elif task == "polygon":
        non_empty_pred = []
        non_empty_target = []
        for pred, target in zip(preds, targets, strict=True):
            if "masks" in pred:
                masks = pred["masks"].squeeze(1)
                if masks.dtype.is_floating_point:
                    masks = masks > 0.5
                pred["masks"] = masks.to(torch.uint8)
            if "masks" in target:
                target["masks"] = target["masks"].to(torch.uint8)
            # Fixed precedence: previously len(labels > 0), which only worked
            # by accident because len() of the boolean mask equals len(labels).
            if len(target["labels"]) > 0:
                non_empty_pred.append(pred)
                non_empty_target.append(target)
        self.mAP_metric.update(non_empty_pred, non_empty_target)

    # Keep formatted predictions for evaluation logs.
    for i, result in enumerate(preds):
        formatted_result = utilities.format_geometry(result)
        if formatted_result is not None:
            formatted_result["image_path"] = image_names[i]
            self.predictions.append(formatted_result)

    return losses
def on_validation_epoch_start(self):
    """Begin every validation epoch with a fresh, empty prediction buffer."""
    self.predictions = list()
def _compute_epoch_metrics(self) -> dict:
"""Compute metrics and returns a Lightning-loggable dictionary.
This function is called automatically at the end of validation.
"""
metrics = {}
if self.model.task == "box":
# IoU and mAP
if len(self.iou_metric.groundtruth_labels) > 0:
metrics.update(self.iou_metric.compute())
# Lightning bug: claims this is a warning but it's not. See issue #16218 in Lightning-AI/pytorch-lightning
output = self.mAP_metric.compute()
# Remove classes from output dict
output = {
key: value for key, value in output.items() if not key == "classes"
}
metrics.update(output)
metrics.update(self.precision_recall_metric.compute())
elif self.model.task == "point":
metrics["val_mae"] = self.mae_metric.compute()
metrics.update(self.precision_recall_metric.compute())
return metrics
def _prepare_metrics_for_sync(self, metrics: dict) -> dict:
"""Move scalar metrics onto the module device for NCCL reduction."""
synced_metrics = {}
for key, value in metrics.items():
if isinstance(value, torch.Tensor):
synced_metrics[key] = value.to(self.device)
elif isinstance(value, Number):
synced_metrics[key] = torch.tensor(value, device=self.device)
else:
synced_metrics[key] = value
return synced_metrics
[docs] def on_validation_epoch_end(self):
"""Compute metrics and predictions at the end of the validation
epoch."""
if self.trainer.sanity_checking: # optional skip
return
gathered_predictions = self._gather_prediction_frames(self.predictions)
self.predictions = (
[gathered_predictions] if not gathered_predictions.empty else []
)
# Log epoch metrics
if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0:
metrics = self._compute_epoch_metrics()
should_sync = distributed.should_sync(self.trainer)
if should_sync:
metrics = self._prepare_metrics_for_sync(metrics)
self.log_dict(metrics, sync_dist=should_sync)
# Manual reset. Lightning does not do this automatically
# unless we log the metric objects directly
if self.model.task == "box":
self.iou_metric.reset()
self.mAP_metric.reset()
self.precision_recall_metric.reset()
elif self.model.task == "point":
self.mae_metric.reset()
self.precision_recall_metric.reset()
elif self.model.task == "polygon":
self.mAP_metric.reset()
def predict_step(self, batch, batch_idx):
    """Run inference on one prediction batch.

    ``batch`` is either the images themselves, or a dict with an
    ``"images"`` entry plus optional ``"metadata"`` (one entry per image) and
    ``"batch_indices"``. When ``batch_indices`` is present, it is appended to
    ``self.original_batch_structure`` (used by ``predict_tile``).

    Args:
        batch: Images, or a dict as described above.
        batch_idx (int): The index of the batch (unused).

    Returns:
        The raw model outputs when no metadata is available. Otherwise a list
        with one DataFrame per image, built by ``_format_prediction_frame``.
    """
    images = batch
    metadata = None
    if isinstance(batch, dict):
        images = batch["images"]
        metadata = batch.get("metadata")
        batch_indices = batch.get("batch_indices")
        if batch_indices is not None:
            self.original_batch_structure.append(batch_indices)

    self.model.eval()
    with torch.no_grad():
        preds = self.model.forward(images)

    if metadata is None:
        return preds

    return [
        self._format_prediction_frame(single_pred, single_meta)
        for single_pred, single_meta in zip(preds, metadata, strict=True)
    ]
def predict_batch(self, images, preprocess_fn=None):
    """Predict a batch of images with the deepforest model.

    Args:
        images (torch.Tensor or np.ndarray): A batch of images with shape
            (B, C, H, W). Numpy input is converted to a tensor on
            ``self.device``.
        preprocess_fn (callable, optional): Applied to the batch before
            prediction. If None, images are assumed to be preprocessed.

    Returns:
        List[pd.DataFrame]: One dataframe per image that has predictions.
        Images without predictions are skipped, so list positions do not
        necessarily match batch positions.
    """
    self.model.eval()

    if isinstance(images, np.ndarray):
        images = torch.tensor(images, device=self.device)

    if preprocess_fn:
        images = preprocess_fn(images)

    # Uses PyTorch Lightning's predict_step.
    with torch.no_grad():
        predictions = self.predict_step(images, 0)

    results = []
    for pred in predictions:
        # Only box-style outputs carry "boxes". Indexing it directly raised
        # KeyError for other geometries (e.g. point predictions).
        boxes = pred.get("boxes")
        if boxes is not None and len(boxes) == 0:
            continue
        geom_type = utilities.determine_geometry_type(pred)
        result = utilities.format_geometry(pred, geom_type=geom_type)
        # Elsewhere in this module a None from format_geometry means "nothing
        # to format"; skip those too (assumed to cover empty non-box predictions).
        if result is None:
            continue
        results.append(result)
    return results
def evaluate(
    self,
    csv_file,
    iou_threshold=None,
    root_dir=None,
):
    """Compute validation metrics (IoU, precision/recall, ...) for an annotation file.

    .. deprecated:: 2.0.0
        This method is deprecated. Users should use `trainer.validate()` instead
        to get evaluation statistics during training. This method will be removed
        in a future version.

    The given values are written into ``config.validation``; these changes
    persist after the call. A new trainer is created and
    ``trainer.validate`` is run.

    Args:
        csv_file: Annotation CSV used as the validation set.
        iou_threshold: Optional override for ``config.validation.iou_threshold``.
        root_dir: Optional override for ``config.validation.root_dir``.

    Returns:
        dict: The metrics returned by ``trainer.validate``, plus
        ``"predictions"`` (a DataFrame whose numeric labels are mapped to
        names where possible) and, for box/point models, ``"results"`` from
        the precision/recall metric.
    """
    warnings.warn(
        "deepforest.evaluate() is deprecated and will be removed in a future version. "
        "Please use trainer.validate() instead to get evaluation statistics during training.",
        DeprecationWarning,
        stacklevel=2,
    )

    validation_cfg = self.config.validation
    validation_cfg.csv_file = csv_file
    if root_dir is not None:
        validation_cfg.root_dir = root_dir
    if iou_threshold is not None:
        validation_cfg.iou_threshold = iou_threshold
    validation_cfg.val_accuracy_interval = 1

    self.create_trainer()
    validation_results = self.trainer.validate(self)

    # Collect predictions from every rank in multi-GPU runs.
    world_size = self.trainer.world_size
    if world_size > 1:
        per_rank = [None] * world_size
        torch.distributed.all_gather_object(per_rank, self.predictions)
        self.predictions = [frame for rank_frames in per_rank for frame in rank_frames]

    if len(self.predictions) > 0:
        self.predictions = pd.concat(self.predictions, ignore_index=True)
        if "label" in self.predictions.columns:
            # Map numeric labels to names, leaving missing/unknown values as-is.
            self.predictions["label"] = self.predictions["label"].map(
                lambda x: self.numeric_to_label_dict.get(int(x), x)
                if pd.notna(x)
                else x
            )
    else:
        self.predictions = pd.DataFrame()

    results = {}
    if isinstance(validation_results, dict):
        results.update(validation_results)
    elif isinstance(validation_results, list) and validation_results:
        first_result = validation_results[0]
        if isinstance(first_result, dict):
            results.update(first_result)

    results["predictions"] = self.predictions
    if self.model.task in ("box", "point"):
        results["results"] = self.precision_recall_metric.get_results()
    return results