# entry point for deepforest model
import importlib
import os
import warnings
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import CSVLogger
from torch import optim
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.detection import IntersectionOverUnion, MeanAveragePrecision
from deepforest import evaluate as evaluate_iou
from deepforest import predict, utilities
from deepforest.datasets import prediction, training
from deepforest.metrics import RecallPrecision
class deepforest(pl.LightningModule):
    """DeepForest model for tree crown detection in RGB images.

    Args:
        model: Optional pre-built detection model. When None, a model is
            built (or pretrained weights are loaded) from the config.
        transforms: Optional transforms passed to the training dataset
            built by ``train_dataloader()``.
        existing_train_dataloader: PyTorch dataloader for training data,
            used instead of building one from ``config.train.csv_file``.
        existing_val_dataloader: PyTorch dataloader for validation data,
            used instead of building one from ``config.validation.csv_file``.
        config: DeepForest configuration object, config name, or plain dict.
        config_args: Dictionary of config overrides.
    """
def __init__(
    self,
    model=None,
    transforms=None,
    existing_train_dataloader=None,
    existing_val_dataloader=None,
    config: str | dict | DictConfig | None = None,
    config_args: dict | None = None,
):
    """Resolve the configuration, build the model and create a default trainer.

    Args:
        model: Optional pre-built detection model. When None,
            ``create_model()`` builds one from the config.
        transforms: Optional transforms for the training dataset.
        existing_train_dataloader: Optional training dataloader.
        existing_val_dataloader: Optional validation dataloader.
        config: How the configuration is obtained:
            - None: load the default config, applying ``config_args``;
            - str: load the named config, applying ``config_args``;
            - dict: merge ``config_args`` into it and use the result as
              overrides on the default config (e.g. checkpoint hparams);
            - a config object containing a ``config_args`` key: use those
              overrides on the default config;
            - any other config object: used as-is, and a non-None
              ``config_args`` is ignored with a warning.
        config_args: Optional dictionary of config overrides.
    """
    super().__init__()
    if config is None:
        config = utilities.load_config(overrides=config_args)
    elif isinstance(config, str):
        # Named config
        config = utilities.load_config(config_name=config, overrides=config_args)
    elif isinstance(config, dict):
        # Plain dict, e.g. hyper-parameters restored from a checkpoint
        config = OmegaConf.merge(config, config_args or {})
        config = utilities.load_config(overrides=config)
    elif "config_args" in config:
        # Hub overrides
        config = utilities.load_config(overrides=config["config_args"])
    elif config_args is not None:
        warnings.warn(
            f"Ignoring options as configuration object was provided: {config_args}",
            stacklevel=2,
        )
    self.config = config

    # Release version id, used to flag whether a release is being used.
    self.__release_version__ = None
    self.existing_train_dataloader = existing_train_dataloader
    self.existing_val_dataloader = existing_val_dataloader
    self.model = model
    # Window-to-image indices recorded by predict_step for batched tiles.
    self.original_batch_structure = []
    if self.model is None:
        self.create_model()
    # NOTE(review): when a model is supplied, set_labels() is not called
    # here, so label_dict must be set before metrics/predictions need it.

    # Create a default trainer.
    self.create_trainer()
    # User supplied transforms (None means no extra transforms).
    self.transforms = transforms
    # Store the resolved config as hyper-parameters for checkpoints.
    self.save_hyperparameters(
        {"config": OmegaConf.to_container(self.config, resolve=True)}
    )
def setup_metrics(self):
    """Create the torchmetrics objects used during validation.

    Nothing is created when ``config.validation.csv_file`` is unset, because
    the precision/recall metric is built from that file. Reads
    ``self.label_dict``, so labels must already be set.
    """
    validation_cfg = self.config.validation
    if validation_cfg.csv_file:
        self.iou_metric = IntersectionOverUnion(
            class_metrics=True, iou_threshold=validation_cfg.iou_threshold
        )
        self.mAP_metric = MeanAveragePrecision(backend="faster_coco_eval")
        # Accuracy on frames whose ground truth contains no objects
        self.empty_frame_accuracy = BinaryAccuracy()
        self.precision_recall_metric = RecallPrecision(
            csv_file=validation_cfg.csv_file,
            label_dict=self.label_dict,
        )
def load_model(self, model_name=None, revision=None):
    """Load pretrained weights for a specific task, like tree crown detection.

    Model weights are distributed via Hugging Face and identified by a
    repository ID (``model_name``) of the form 'organization/repository'.
    For the models distributed by the DeepForest team see:
    https://deepforest.readthedocs.io/en/latest/installation_and_setup/prebuilt.html

    Args:
        model_name (str): Hugging Face repository ID. Defaults to
            ``config.model.name``.
        revision (str): Model version ('main', 'v1.0.0', etc.). Defaults to
            ``config.model.revision``.

    Returns:
        None
    """
    name = self.config.model.name if model_name is None else model_name
    version = self.config.model.revision if revision is None else revision
    architecture = importlib.import_module(
        f"deepforest.models.{self.config.architecture}"
    )
    builder = architecture.Model(config=self.config)
    self.model = builder.create_model(pretrained=name, revision=version)
    # Keep the config and label mappings in sync with the loaded weights.
    self.config.num_classes = self.model.num_classes
    self.set_labels(self.model.label_dict)
def set_labels(self, label_dict):
    """Install a new label mapping and its inverse.

    Args:
        label_dict (dict): Mapping of class name (str) to numeric ID (int).

    Raises:
        ValueError: If ``label_dict`` is None, its size differs from
            ``config.num_classes``, or two names share an ID.
    """
    if label_dict is None:
        raise ValueError(
            "Label dictionary not found. Check it was set in your config file or config_args."
        )
    expected = self.config.num_classes
    if len(label_dict) != expected:
        raise ValueError(
            f"label_dict {label_dict} does not match requested number of "
            f"classes {expected}, please supply a label_dict argument "
            '{"label1":0, "label2":1, "label3":2 ... etc} '
            "for each label in the "
            "dataset"
        )
    ids = list(label_dict.values())
    if len(ids) != len(set(ids)):
        raise ValueError("Found duplicate label IDs in label_dict.")
    self.label_dict = label_dict
    self.numeric_to_label_dict = dict(zip(ids, label_dict.keys()))
def create_model(self, initialize_model=False):
    """Build or load the detection model described by the config.

    If ``config.model.name`` is set and ``initialize_model`` is False, the
    pretrained weights are loaded through ``load_model()``. Otherwise the
    ``Model`` class in ``deepforest.models.<config.architecture>`` builds a
    fresh model and labels are taken from ``config.label_dict``.

    Args:
        initialize_model (bool): Build a fresh model even when
            ``config.model.name`` is set.

    Returns:
        None
    """
    use_pretrained = self.config.model.name is not None and not initialize_model
    if use_pretrained:
        self.load_model()
        return
    architecture = importlib.import_module(
        f"deepforest.models.{self.config.architecture}"
    )
    self.model = architecture.Model(config=self.config).create_model()
    self.set_labels(self.config.label_dict)
def create_trainer(self, logger=None, callbacks=None, **kwargs):
    """Create a PyTorch Lightning trainer from the config as ``self.trainer``.

    Args:
        logger: Optional logger. Defaults to a ``CSVLogger`` writing to
            ``config.log_root``. Pass ``False`` to disable logging.
        callbacks: Optional list of callbacks. The list is copied, so the
            caller's list is never modified.
        **kwargs: Additional ``pl.Trainer`` arguments; these override the
            values derived from the config.
    """
    # Metrics depend on the config (e.g. validation.csv_file), which may
    # have changed since the last call.
    self.setup_metrics()
    # Copy: the learning-rate monitor appended below must not leak into the
    # caller's list, where repeated calls would stack monitors.
    callbacks = list(callbacks) if callbacks is not None else []
    # Default logger uses log_root from config instead of lightning_logs.
    if logger is None:
        logger = CSVLogger(save_dir=self.config.log_root, name="")
    if self.config.validation.csv_file is not None:
        # LearningRateMonitor needs a logger; skip it when logging is
        # disabled (e.g. logger=False).
        if logger:
            callbacks.append(LearningRateMonitor(logging_interval="epoch"))
        limit_val_batches = 1.0
        num_sanity_val_steps = 2
    else:
        # No validation data: disable validation, don't use trainer defaults.
        limit_val_batches = 0
        num_sanity_val_steps = 0
    # Lightning rejects a ModelCheckpoint callback when
    # enable_checkpointing=False and detects it with isinstance, so
    # subclasses must count too. The name check keeps the previous behavior.
    enable_checkpointing = any(
        isinstance(cb, pl.callbacks.ModelCheckpoint)
        or type(cb).__qualname__ == "ModelCheckpoint"
        for cb in callbacks
    )
    trainer_args = {
        "logger": logger,
        "max_epochs": self.config.train.epochs,
        "enable_checkpointing": enable_checkpointing,
        "devices": self.config.devices,
        "accelerator": self.config.accelerator,
        "fast_dev_run": self.config.train.fast_dev_run,
        "callbacks": callbacks,
        "limit_val_batches": limit_val_batches,
        "num_sanity_val_steps": num_sanity_val_steps,
        "default_root_dir": self.config.log_root,
    }
    # kwargs override the config-derived values.
    trainer_args.update(kwargs)
    self.trainer = pl.Trainer(**trainer_args)
def on_fit_start(self):
    """Fail fast when training has no source of training data.

    Training needs either ``config.train.csv_file`` or a dataloader passed
    to ``__init__`` as ``existing_train_dataloader`` (which
    ``train_dataloader()`` returns in preference to the CSV file).

    Raises:
        AttributeError: If neither source is available.
    """
    if (
        self.config.train.csv_file is None
        and self.existing_train_dataloader is None
    ):
        raise AttributeError(
            "Cannot train without a train annotations file, "
            "please set config['train']['csv_file'] before "
            "calling deepforest.create_trainer()"
        )
def on_save_checkpoint(self, checkpoint):
    """Add the current config and label mappings to ``checkpoint``.

    The config may have been edited since ``__init__``, so the stored
    hyper-parameters are refreshed. Any top-level ``DictConfig`` value is
    converted to a plain container.

    Args:
        checkpoint (dict): Checkpoint dictionary being written by Lightning.
    """
    checkpoint["hyper_parameters"]["config"] = OmegaConf.to_container(
        self.config, resolve=True
    )
    for attr in ("label_dict", "numeric_to_label_dict"):
        checkpoint[attr] = getattr(self, attr)
    for key, value in list(checkpoint.items()):
        if isinstance(value, DictConfig):
            checkpoint[key] = OmegaConf.to_container(value, resolve=True)
def on_load_checkpoint(self, checkpoint):
    """Restore label mappings and back-fill config keys missing from old checkpoints.

    Args:
        checkpoint (dict): Checkpoint dictionary provided by Lightning.
    """
    try:
        self.label_dict = checkpoint["label_dict"]
        self.numeric_to_label_dict = checkpoint["numeric_to_label_dict"]
    except KeyError:
        print(
            "No label_dict found in checkpoint, using default label_dict, "
            "please use deepforest.set_labels() to set the label_dict after loading the checkpoint."
        )
    # Pre 2.0 compatibility, the score_threshold used to be stored under retinanet.score_thresh
    try:
        self.config.score_thresh = self.config.retinanet.score_thresh
    except AttributeError:
        pass
    # Keys that older configs may lack are copied from the default config,
    # which is loaded from disk at most once (and only if needed).
    default_config = None
    for section, key in (
        ("validation", "lr_plateau_target"),
        ("train", "augmentations"),
        ("validation", "augmentations"),
    ):
        section_config = getattr(self.config, section)
        if hasattr(section_config, key):
            continue
        if default_config is None:
            default_config = utilities.load_config()
        setattr(section_config, key, getattr(getattr(default_config, section), key))
def save_model(self, path):
    """Write the trainer checkpoint to ``path`` so the model can be reloaded later.

    Args:
        path: Destination path of the checkpoint file.
    """
    write_checkpoint = self.trainer.save_checkpoint
    write_checkpoint(path)
def load_dataset(
    self,
    csv_file,
    root_dir=None,
    shuffle=True,
    transforms=None,
    augmentations=None,
    preload_images=False,
    batch_size=1,
):
    """Build a DataLoader over a box-annotation CSV file.

    The CSV has the columns "image_path", "xmin", "ymin", "xmax", "ymax",
    one bounding box per row; "image_path" is a file name relative to
    ``root_dir``, not an absolute path.

    Args:
        csv_file: path to the csv file
        root_dir: directory of the images, passed to ``training.BoxDataset``
        shuffle: whether the DataLoader shuffles samples
        transforms: Albumentations transforms
        augmentations: augmentation configuration (str, list, or dict)
        preload_images: if True, preload the images into memory
        batch_size: batch size

    Returns:
        torch.utils.data.DataLoader using the dataset's ``collate_fn`` and
        ``config.workers`` workers.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    dataset = training.BoxDataset(
        csv_file=csv_file,
        root_dir=root_dir,
        transforms=transforms,
        label_dict=self.label_dict,
        augmentations=augmentations,
        preload_images=preload_images,
    )
    if not len(dataset):
        raise ValueError(
            f"Dataset from {csv_file} is empty. Check CSV for valid entries and columns."
        )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=dataset.collate_fn,
        num_workers=self.config.workers,
    )
def train_dataloader(self):
    """Return the training dataloader.

    A dataloader passed to ``__init__`` as ``existing_train_dataloader``
    takes precedence. Otherwise one is built from ``config.train`` with
    shuffling and the user-supplied transforms.

    Returns:
        The training dataloader.
    """
    # Compare with None instead of relying on truthiness: bool(DataLoader)
    # calls len(), which raises for iterable datasets without __len__ and
    # is False for an empty loader.
    if self.existing_train_dataloader is not None:
        return self.existing_train_dataloader
    train_cfg = self.config.train
    return self.load_dataset(
        csv_file=train_cfg.csv_file,
        root_dir=train_cfg.root_dir,
        augmentations=train_cfg.augmentations,
        preload_images=train_cfg.preload_images,
        shuffle=True,
        transforms=self.transforms,
        batch_size=self.config.batch_size,
    )
def val_dataloader(self):
    """Return the validation dataloader, or an empty list to skip validation.

    A dataloader passed to ``__init__`` as ``existing_val_dataloader`` takes
    precedence. Otherwise one is built from ``config.validation`` (without
    shuffling) when ``config.validation.csv_file`` is set.

    Returns:
        A dataloader, or an empty list when there is no validation data.
    """
    # Compare with None instead of relying on truthiness: bool(DataLoader)
    # calls len(), which raises for iterable datasets without __len__.
    if self.existing_val_dataloader is not None:
        return self.existing_val_dataloader
    validation_cfg = self.config.validation
    if validation_cfg.csv_file is None:
        # The preferred route for skipping validation is now (pl-2.0) an
        # empty list, see https://github.com/Lightning-AI/lightning/issues/17154
        return []
    return self.load_dataset(
        csv_file=validation_cfg.csv_file,
        root_dir=validation_cfg.root_dir,
        augmentations=validation_cfg.augmentations,
        shuffle=False,
        preload_images=validation_cfg.preload_images,
        batch_size=self.config.batch_size,
    )
def predict_dataloader(self, ds, batch_size=None):
    """Wrap a prediction dataset in a non-shuffling PyTorch DataLoader.

    Args:
        ds: Dataset to wrap; must provide a ``collate_fn`` method.
        batch_size (int, optional): Batch size. Defaults to
            ``config.batch_size``.

    Returns:
        torch.utils.data.DataLoader: A dataloader that can be used for
        prediction, using ``config.workers`` workers.
    """
    if batch_size is None:
        batch_size = self.config.batch_size
    return torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=self.config.workers,
        collate_fn=ds.collate_fn,
    )
[docs] def predict_image(self, image: np.ndarray | None = None, path: str | None = None):
"""Predict a single image with a deepforest model.
Args:
image: a float32 numpy array of a RGB with channels last format
path: optional path to read image from disk instead of passing image arg
Returns:
result: A pandas dataframe of predictions (Default)
"""
# Ensure we are in eval mode
self.model.eval()
if path:
image = np.array(Image.open(path).convert("RGB")).astype("float32")
# sanity checks on input images
if not isinstance(image, np.ndarray):
raise TypeError(
f"Input image is of type {type(image)}, expected numpy, if reading "
"from PIL, wrap in "
"np.array(image).astype(float32)"
)
if image.dtype != "float32":
warnings.warn(
f"Image type is {image.dtype}, transforming to float32. "
f"This assumes that the range of pixel values is 0-255, as "
f"opposed to 0-1.To suppress this warning, transform image "
f"(image.astype('float32')",
stacklevel=2,
)
image = image.astype("float32")
result = predict._predict_image_(
model=self.model, image=image, path=path, nms_thresh=self.config.nms_thresh
)
# If there were no predictions, return None
if result is None:
return None
else:
result["label"] = result.label.apply(lambda x: self.numeric_to_label_dict[x])
if path is None:
warnings.warn(
"An image was passed directly to predict_image, the result.root_dir attribute "
"will be None in the output dataframe, to use visualize.plot_results, "
"please assign results.root_dir = <directory name>",
stacklevel=2,
)
else:
root_dir = os.path.dirname(path)
result = utilities.read_file(result, root_dir=root_dir)
return result
def predict_file(
    self,
    csv_file,
    root_dir,
    crop_model=None,
):
    """Predict every image listed in an annotation CSV file.

    The CSV has the columns "image_path", "xmin", "ymin", "xmax", "ymax",
    one bounding box per row; "image_path" is a file name relative to
    ``root_dir``, not an absolute path.

    Args:
        csv_file: path to the csv file
        root_dir: directory containing the images listed in ``csv_file``
        crop_model: optional deepforest.model.CropModel object applied to
            the predicted crops

    Returns:
        df: dataframe with bounding boxes, label and scores for each image
        in the csv file; its ``root_dir`` attribute is set to ``root_dir``.
    """
    dataset = prediction.FromCSVFile(csv_file=csv_file, root_dir=root_dir)
    loader = self.predict_dataloader(dataset, batch_size=self.config.batch_size)
    results = predict._dataloader_wrapper_(
        model=self,
        crop_model=crop_model,
        trainer=self.trainer,
        dataloader=loader,
        root_dir=root_dir,
    )
    # Record the image directory so plotting utilities can locate the images.
    results.root_dir = root_dir
    return results
def predict_tile(
    self,
    path=None,
    image=None,
    patch_size=400,
    patch_overlap=0.05,
    iou_threshold=0.15,
    dataloader_strategy="single",
    crop_model=None,
):
    """Predict images too large for the model by tiling them into windows.

    The image is cut into overlapping windows, each window is predicted,
    and window predictions are merged per image with ``predict.mosiac``.

    Args:
        path: Path or list of paths to images on disk. A single str or
            path-like object is wrapped in a list.
        image (array): Numpy image array in BGR channel order following
            openCV convention. Not possible in combination with
            dataloader_strategy='batch'.
        patch_size: patch size for each window
        patch_overlap: patch overlap among windows
        iou_threshold: Minimum iou overlap among predictions between
            windows to be suppressed
        dataloader_strategy: "single", "batch", or "window".
            - "single" loads the entire image into memory and passes
              individual windows to GPU and cannot be parallelized.
            - "batch" loads the entire image into GPU memory and creates
              views of an image as batch, requiring the entire tile to fit
              into GPU memory. CPU parallelization is possible for loading
              images.
            - "window" loads only the desired window of the image from the
              raster dataset. Most memory efficient option, but cannot
              parallelize across windows; requires ``config.workers == 0``.
        crop_model: a deepforest.model.CropModel object to predict on
            crops; requires ``path``.

    Returns:
        The dataframe produced by ``utilities.__pandas_to_geodataframe__``
        with a ``root_dir`` attribute, or None if no predictions were made.

    Raises:
        ValueError: On an unknown strategy, a missing path/image, workers > 0
            with the "window" strategy, or ``crop_model`` without ``path``.
    """
    self.model.eval()
    self.model.nms_thresh = self.config.nms_thresh

    # Validate arguments before doing any prediction work.
    if dataloader_strategy not in ("single", "batch", "window"):
        raise ValueError(f"Invalid dataloader_strategy: {dataloader_strategy}")
    if dataloader_strategy == "single" and path is None and image is None:
        raise ValueError(
            "Either path or image must be provided for single tile prediction"
        )
    if dataloader_strategy == "batch" and path is None:
        raise ValueError(
            "path argument must be provided when using dataloader_strategy='batch'"
        )
    if dataloader_strategy == "window" and self.config.workers > 0:
        # The out-of-memory dataset cannot be used with worker processes.
        raise ValueError(
            "workers must be 0 when using out-of-memory dataset "
            "(dataloader_strategy='window'). Set config['workers']=0 and recreate "
            "trainer self.create_trainer()."
        )
    if crop_model is not None and path is None:
        # Crops are matched to detections by image file name, which an
        # in-memory image does not have.
        raise ValueError(
            "crop_model requires the path argument; it cannot be used with an "
            "in-memory image."
        )

    # Normalize to a list of paths; [None] means "use the in-memory image".
    if isinstance(path, (str, os.PathLike)):
        paths = [os.fspath(path)]
    elif path is None:
        paths = [None]
    else:
        paths = path

    window_results = []
    if dataloader_strategy == "batch":
        self.original_batch_structure.clear()
        ds = prediction.MultiImage(
            paths=paths, patch_overlap=patch_overlap, patch_size=patch_size
        )
        batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
        for idx, batch in enumerate(batched_results):
            window_results.append(
                ds.postprocess(batch, idx, self.original_batch_structure)
            )
    else:
        for image_path in paths:
            if dataloader_strategy == "single":
                ds = prediction.SingleImage(
                    path=image_path,
                    image=image,
                    patch_overlap=patch_overlap,
                    patch_size=patch_size,
                )
            else:
                ds = prediction.TiledRaster(
                    path=image_path,
                    patch_overlap=patch_overlap,
                    patch_size=patch_size,
                )
            batched_results = self.trainer.predict(self, self.predict_dataloader(ds))
            # postprocess expects each window's index across all batches of
            # this image, not its position within a batch.
            windows = (window for batch in batched_results for window in batch)
            for window_idx, window_result in enumerate(windows):
                window_results.append(ds.postprocess(window_result, window_idx))

    results = pd.concat(window_results) if window_results else pd.DataFrame()
    if results.empty:
        warnings.warn("No predictions made, returning None", stacklevel=2)
        return None

    # Merge window predictions per image_path, or all together if unknown.
    if results["image_path"].isnull().all():
        mosaics = [predict.mosiac(results, iou_threshold=iou_threshold)]
    else:
        mosaics = []
        for image_path in results["image_path"].unique():
            image_mosaic = predict.mosiac(
                results[results["image_path"] == image_path],
                iou_threshold=iou_threshold,
            )
            image_mosaic["image_path"] = image_path
            mosaics.append(image_mosaic)
    mosaic_results = pd.concat(mosaics)
    mosaic_results["label"] = mosaic_results.label.apply(
        lambda x: self.numeric_to_label_dict[x]
    )

    if paths[0] is not None:
        root_dir = os.path.dirname(paths[0])
    else:
        print(
            "No image path provided, root_dir of the output results dataframe will be None, since either "
            "images were directly provided or there were multiple image paths"
        )
        root_dir = None

    if crop_model is None:
        final_results = mosaic_results
    else:
        cropped = []
        for image_path in paths:
            image_result = mosaic_results[
                mosaic_results.image_path == os.path.basename(image_path)
            ]
            if image_result.empty:
                continue
            image_result.root_dir = os.path.dirname(image_path)
            cropped.append(
                predict._crop_models_wrapper_(crop_model, self.trainer, image_result)
            )
        if cropped:
            final_results = pd.concat(cropped)
        else:
            # pd.concat([]) raises; keep the detections and report why the
            # crop model was not applied.
            warnings.warn(
                "crop_model was not applied: no predictions matched the input image paths",
                stacklevel=2,
            )
            final_results = mosaic_results

    formatted_results = utilities.__pandas_to_geodataframe__(final_results)
    formatted_results.root_dir = root_dir
    return formatted_results
def training_step(self, batch, batch_idx):
    """Compute and log the training losses for one batch.

    Args:
        batch: ``(images, targets, image_names)`` tuple from the training
            dataloader.
        batch_idx: Index of the batch (unused).

    Returns:
        The summed loss, which Lightning optimizes.
    """
    # The wrapped model only returns losses in train mode.
    self.model.train()
    images, targets, _ = batch
    n_images = len(images)
    loss_dict = self.model.forward(images, targets)
    total = sum(loss_dict.values())
    for name, component in loss_dict.items():
        self.log(
            f"train_{name}", component.detach(), on_epoch=True, batch_size=n_images
        )
    self.log("train_loss", total.detach(), on_epoch=True, batch_size=n_images)
    return total
def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch.

    Logs the per-component and summed validation losses, updates the
    epoch metrics, and appends formatted predictions to
    ``self.predictions`` (reset in ``on_validation_epoch_start``).

    Args:
        batch: ``(images, targets, image_names)`` tuple from the validation
            dataloader.
        batch_idx: Index of the batch (unused).

    Returns:
        The summed validation loss.
    """
    images, targets, image_names = batch
    n_images = len(images)

    # Torchvision does not return loss in eval mode, so compute losses in
    # train mode without gradients (no optimization happens here).
    self.model.train()
    with torch.no_grad():
        loss_dict = self.model.forward(images, targets)
    losses = sum(loss_dict.values())
    for name, component in loss_dict.items():
        self.log(
            f"val_{name}", component.detach(), on_epoch=True, batch_size=n_images
        )
    self.log("val_loss", losses.detach(), on_epoch=True, batch_size=n_images)

    # Eval mode returns predictions, which feed the metrics.
    self.model.eval()
    with torch.no_grad():
        preds = self.model.forward(images, targets)

    if len(targets):
        # Empty frames (no boxes, or only all-zero boxes) only count towards
        # empty-frame accuracy; every other frame goes to IoU and mAP.
        kept_preds = []
        kept_targets = []
        for idx, target in enumerate(targets):
            boxes = target["boxes"]
            if boxes.numel() == 0 or torch.all(boxes == 0):
                device = boxes.device
                # 1 if the model predicted any box, 0 for an empty
                # prediction; ground truth for an empty frame is 0.
                predicted_any = min(len(preds[idx]["boxes"]), 1)
                self.empty_frame_accuracy.update(
                    torch.tensor([predicted_any], device=device),
                    torch.tensor([0.0], device=device),
                )
            else:
                kept_preds.append(preds[idx])
                kept_targets.append(target)
        self.iou_metric.update(kept_preds, kept_targets)
        self.mAP_metric.update(kept_preds, kept_targets)
        # Precision recall metric can handle empty frames internally
        self.precision_recall_metric.update(preds, image_names)

    # Keep formatted predictions for evaluation logs.
    for idx, pred in enumerate(preds):
        formatted = utilities.format_geometry(pred)
        if formatted is not None:
            formatted["image_path"] = image_names[idx]
            self.predictions.append(formatted)
    return losses
def on_validation_epoch_start(self):
    """Start a fresh list of formatted predictions for this validation epoch."""
    self.predictions = list()
def calculate_empty_frame_accuracy(self, ground_df, predictions_df):
    """Calculate accuracy for empty frames (frames with no objects).

    An image is empty when a ground-truth row for it has all box coordinates
    equal to 0. The model is correct on an empty image when it has no
    prediction row with a non-null ``xmin`` for that image.

    Args:
        ground_df (pd.DataFrame): Ground truth with columns 'image_path',
            'xmin', 'ymin', 'xmax', 'ymax'.
        predictions_df (pd.DataFrame): Predictions with columns 'image_path'
            and 'xmin'.

    Returns:
        None if there are no empty frames; 1 if ``predictions_df`` is
        empty; otherwise a 0-dim float tensor with the fraction of empty
        frames the model left empty (1.0 = no false positives, 0.0 = objects
        predicted in every empty frame). The value is also logged as
        "empty_frame_accuracy".
    """
    all_zero = (
        (ground_df.xmin == 0)
        & (ground_df.ymin == 0)
        & (ground_df.xmax == 0)
        & (ground_df.ymax == 0)
    )
    empty_images = ground_df.loc[all_zero, "image_path"].unique()
    if len(empty_images) == 0:
        return None
    if predictions_df.empty:
        # Empty predictions with empty ground truth = 100% accuracy
        empty_accuracy = 1
    else:
        # Images with at least one real (non-null) predicted box.
        images_with_detections = set(
            predictions_df.loc[predictions_df.xmin.notnull(), "image_path"]
        )
        # 1 if the model predicted objects in an empty image, else 0.
        predicted_objects = torch.tensor(
            [1.0 if image in images_with_detections else 0.0 for image in empty_images]
        )
        # Ground truth is 0 for every empty frame, so accuracy is the share
        # of zeros. Computed locally rather than with self.empty_frame_accuracy,
        # which only exists when a validation csv_file is configured and
        # accumulates validation-epoch state that must not be reset here.
        empty_accuracy = (predicted_objects == 0).float().mean()
    self.log("empty_frame_accuracy", empty_accuracy)
    return empty_accuracy
def _compute_epoch_metrics(self) -> dict:
    """Gather the accumulated validation metrics into a Lightning-loggable dict.

    Called from ``on_validation_epoch_end``. IoU and mAP are only included
    when the IoU metric has seen ground-truth labels; empty-frame accuracy
    only when that metric received at least one update.
    """
    metrics = {}
    if len(self.iou_metric.groundtruth_labels) > 0:
        metrics.update(self.iou_metric.compute())
        # Lightning bug: claims this is a warning but it's not. See issue
        # #16218 in Lightning-AI/pytorch-lightning
        map_output = self.mAP_metric.compute()
        # Drop the "classes" entry (the observed class ids, not a metric).
        metrics.update(
            {name: value for name, value in map_output.items() if name != "classes"}
        )
    # Box recall/precision
    metrics.update(self.precision_recall_metric.compute())
    if self.empty_frame_accuracy.update_called:
        metrics["empty_frame_accuracy"] = self.empty_frame_accuracy.compute()
    return metrics
def on_validation_epoch_end(self):
    """Log epoch metrics and reset metric state at the end of validation.

    Metrics are computed and logged every
    ``config.validation.val_accuracy_interval`` epochs and never during the
    sanity check. Metric state is reset after every validation epoch,
    including the sanity check, so sanity-check batches do not leak into
    the first real epoch's metrics.
    """
    if (
        not self.trainer.sanity_checking
        and (self.current_epoch + 1) % self.config.validation.val_accuracy_interval
        == 0
    ):
        self.log_dict(self._compute_epoch_metrics())
    # Manual reset. Lightning does not do this automatically unless the
    # metric objects themselves are logged.
    for metric in (
        self.precision_recall_metric,
        self.iou_metric,
        self.mAP_metric,
        self.empty_frame_accuracy,
    ):
        metric.reset()
def predict_step(self, batch, batch_idx):
    """Predict a batch of images with the deepforest model.

    A dict batch (as produced for ``predict_tile``'s "batch" strategy)
    carries "images" and "batch_indices"; the indices are recorded in
    ``self.original_batch_structure`` for the dataset's postprocess step.
    Any other batch is passed to the model as-is.

    Args:
        batch (dict, torch.Tensor or np.ndarray): images with shape
            (B, C, H, W), or a dict as described above.
        batch_idx (int): The index of the batch (unused).

    Returns:
        Whatever ``self.model.forward`` returns for the images.
    """
    if not isinstance(batch, dict):
        images = batch
    else:
        images = batch["images"]
        indices = batch["batch_indices"]
        self.original_batch_structure.append(indices)
    self.model.eval()
    with torch.no_grad():
        return self.model.forward(images)
def predict_batch(self, images, preprocess_fn=None):
    """Predict a batch of images with the deepforest model.

    Args:
        images (torch.Tensor or np.ndarray): A batch of images with shape
            (B, C, H, W). Numpy arrays are converted to a tensor on the
            module's device.
        preprocess_fn (callable, optional): Applied to the images before
            prediction. If None, assumes images are preprocessed.

    Returns:
        List[pd.DataFrame]: One dataframe per image with at least one
        predicted box. Images without predictions are skipped, so list
        positions need not match batch positions.
    """
    self.model.eval()
    batch = (
        torch.tensor(images, device=self.device)
        if isinstance(images, np.ndarray)
        else images
    )
    if preprocess_fn:
        batch = preprocess_fn(batch)
    # Reuse the Lightning predict_step logic.
    with torch.no_grad():
        predictions = self.predict_step(batch, 0)
    return [
        utilities.format_geometry(
            pred, geom_type=utilities.determine_geometry_type(pred)
        )
        for pred in predictions
        if len(pred["boxes"]) > 0
    ]
def __evaluate__(
    self,
    csv_file,
    iou_threshold=None,
    root_dir=None,
    predictions=None,
):
    """Internal method to compute intersection-over-union and
    precision/recall for a given iou_threshold.

    Args:
        csv_file: location of a csv file with columns "image_path", "xmin",
            "ymin", "xmax", "ymax", "label"
        iou_threshold: float [0,1] intersection-over-union threshold for a
            true positive. Defaults to ``config.validation.iou_threshold``.
        root_dir: directory of the images. Defaults to
            ``config.validation.root_dir``.
        predictions: predictions dataframe to evaluate. If None, predictions
            are generated from the model with ``predict_file``.

    Returns:
        dict: Results from ``evaluate.__evaluate_wrapper__`` plus an
        "empty_frame_accuracy" entry.

    Raises:
        ValueError: If no root_dir is given and none is configured.
    """
    self.model.eval()
    if root_dir is None:
        root_dir = self.config.validation.root_dir
        if root_dir is None:
            raise ValueError("root_dir must be specified if not provided in config")
    ground_df = utilities.read_file(csv_file, root_dir=root_dir)
    # Ground-truth labels are class names; map them to numeric ids.
    ground_df["label"] = ground_df.label.apply(lambda x: self.label_dict[x])
    if predictions is None:
        predictions = self.predict_file(csv_file, root_dir)
    if iou_threshold is None:
        iou_threshold = self.config.validation.iou_threshold
    results = evaluate_iou.__evaluate_wrapper__(
        predictions=predictions,
        ground_df=ground_df,
        iou_threshold=iou_threshold,
        numeric_to_label_dict=self.numeric_to_label_dict,
        geometry_type="box",
    )
    results["empty_frame_accuracy"] = self.calculate_empty_frame_accuracy(
        ground_df, predictions
    )
    return results
def evaluate(
    self,
    csv_file,
    iou_threshold=None,
    root_dir=None,
    predictions=None,
):
    """Compute intersection-over-union and precision/recall for a given
    iou_threshold.

    .. deprecated:: 2.0.0
        This method is deprecated. Users should use `trainer.validate()` instead
        to get evaluation statistics during training. This method will be removed
        in a future version.

    Args:
        csv_file: location of a csv file with columns "image_path", "xmin",
            "ymin", "xmax", "ymax", "label"
        iou_threshold: float [0,1] intersection-over-union threshold for a
            true positive. Defaults to ``config.validation.iou_threshold``.
        root_dir: directory of the images. Defaults to
            ``config.validation.root_dir``.
        predictions: predictions dataframe to evaluate. If None, predictions
            are generated from the model.

    Returns:
        dict: Results dictionary containing precision, recall and other metrics
    """
    warnings.warn(
        "deepforest.evaluate() is deprecated and will be removed in a future version. "
        "Please use trainer.validate() instead to get evaluation statistics during training.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.__evaluate__(
        csv_file=csv_file,
        iou_threshold=iou_threshold,
        root_dir=root_dir,
        predictions=predictions,
    )