"""DeepForest callback for logging images during training.
Callbacks must implement on_epoch_begin, on_epoch_end, on_fit_end,
on_fit_begin methods and inject model and epoch kwargs.
"""
import json
import os
import warnings
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import supervision as sv
import torch
from PIL import Image
from pytorch_lightning import Callback
from deepforest import utilities, visualize
from deepforest.datasets.training import BoxDataset
[docs]class ImagesCallback(Callback):
"""Log evaluation images during training.
Args:
save_dir: Directory to save predicted images
n: Number of images to process
every_n_epochs: Run interval in epochs
select_random: Whether to select random images
color: Bounding box color as BGR tuple
thickness: Border line thickness in pixels
"""
def __init__(
self,
save_dir,
prediction_samples=2,
dataset_samples=5,
every_n_epochs=5,
select_random=False,
color=None,
thickness=2,
):
self.savedir = save_dir
self.prediction_samples = prediction_samples
self.dataset_samples = dataset_samples
self.color = color
self.thickness = thickness
self.select_random = select_random
self.every_n_epochs = every_n_epochs
[docs] def on_train_start(self, trainer, pl_module):
"""Log sample images from training and validation datasets at training
start."""
if trainer.fast_dev_run:
return
self.trainer = trainer
self.pl_module = pl_module
# Training samples
pl_module.print("Logging training dataset samples.")
train_ds = trainer.train_dataloader.dataset
self._log_dataset_sample(train_ds, split="train")
# Validation samples
if trainer.val_dataloaders:
pl_module.print("Logging validation dataset samples.")
val_ds = trainer.val_dataloaders.dataset
self._log_dataset_sample(val_ds, split="validation")
[docs] def on_validation_end(self, trainer, pl_module):
"""Run callback at validation end."""
if trainer.sanity_checking or trainer.fast_dev_run:
return
if (trainer.current_epoch + 1) % self.every_n_epochs == 0:
pl_module.print("Logging prediction samples")
self._log_last_predictions(trainer, pl_module)
def _log_dataset_sample(self, dataset: BoxDataset, split: str):
"""Log random samples from a DeepForest BoxDataset."""
if self.dataset_samples == 0:
return
out_dir = os.path.join(self.savedir, split + "_sample")
os.makedirs(out_dir, exist_ok=True)
n_samples = min(self.dataset_samples, len(dataset))
sample_indices = torch.randperm(len(dataset))[:n_samples]
sample_data = [dataset[idx] for idx in sample_indices]
sample_images = [data[0] for data in sample_data]
sample_targets = [data[1] for data in sample_data]
sample_paths = [data[2] for data in sample_data]
for image, target, path in zip(
sample_images, sample_targets, sample_paths, strict=False
):
image_annotations = target.copy()
image_annotations = utilities.format_geometry(image_annotations, scores=False)
basename = Path(path).stem
image = (255 * image.cpu().numpy().transpose((1, 2, 0))).astype(np.uint8)
out_path = os.path.join(out_dir, basename + ".png")
if image_annotations is not None:
image_annotations.root_dir = dataset.root_dir
image_annotations["image_path"] = path
# Plot transformed image
fig = visualize.plot_annotations(
image=image,
annotations=image_annotations,
savedir=out_dir,
basename=basename,
thickness=self.thickness,
show=False,
)
plt.close(fig)
else:
# Save un-annotated image
Image.fromarray(image).save(out_path)
self._log_to_all(
image=out_path,
trainer=self.trainer,
tag=f"{split} dataset sample",
)
def _log_last_predictions(self, trainer, pl_module):
"""Log sample of predictions + targets from last validation."""
if self.prediction_samples == 0:
return
if len(pl_module.predictions) > 0:
df = pd.concat(pl_module.predictions)
else:
df = pd.DataFrame()
out_dir = os.path.join(self.savedir, "predictions")
os.makedirs(out_dir, exist_ok=True)
dataset = trainer.val_dataloaders.dataset
# Add root_dir to the dataframe
if "root_dir" not in df.columns:
df["root_dir"] = dataset.root_dir
# Limit to n images, potentially randomly selected
if self.select_random:
selected_images = np.random.choice(
df.image_path.unique(), self.prediction_samples
)
else:
selected_images = df.image_path.unique()[: self.prediction_samples]
# Ensure color is correctly assigned
if self.color is None:
num_classes = len(df["label"].unique())
results_color = sv.ColorPalette.from_matplotlib("viridis", num_classes)
else:
results_color = self.color
for image_name in selected_images:
pred_df = df[df.image_path == image_name]
targets = utilities.format_geometry(
dataset.annotations_for_path(image_name, return_tensor=True), scores=False
)
# Assume that validation images are un-augmented
basename = Path(image_name).stem + f"_{trainer.global_step}"
fig = visualize.plot_results(
basename=basename,
results=pred_df,
ground_truth=targets,
savedir=out_dir,
results_color=results_color,
thickness=self.thickness,
show=False,
)
plt.close(fig)
# Pred metadata, if supported.
stats = (
pred_df["score"]
.agg(
mean_confidence="mean",
max_confidence="max",
min_confidence="min",
std_confidence="std",
)
.to_dict()
)
metadata = {"pred_count": len(pred_df), "gt_count": len(targets)}
metadata.update(stats)
with open(os.path.join(out_dir, basename + ".json"), "w") as fp:
json.dump(metadata, fp, indent=1)
self._log_to_all(
image=os.path.join(out_dir, basename + ".png"),
trainer=trainer,
tag="prediction sample",
metadata=metadata,
)
def _log_to_all(self, image: str, trainer, tag, metadata: dict | None = None):
"""Log to all connected loggers.
Since Comet will pickup image logs to Tensorboard by default, we
add a check to log images preferentially to Tensorboard if both
are enabled.
"""
try:
img = np.array(Image.open(image).convert("RGB"))
loggers = [lg for lg in trainer.loggers if hasattr(lg, "experiment")]
tb = next((lg for lg in loggers if hasattr(lg.experiment, "add_image")), None)
if tb is not None:
tb.experiment.add_image(
tag=f"{tag}/{os.path.basename(image)}",
img_tensor=img,
global_step=trainer.global_step,
dataformats="HWC",
)
return
comet = next(
(lg for lg in loggers if hasattr(lg.experiment, "log_image")),
None,
)
if comet is not None:
meta = {
"image_name": os.path.basename(image),
"context": tag,
"step": trainer.global_step,
}
if metadata:
meta.update(metadata)
comet.experiment.log_image(
img,
name=tag,
step=trainer.global_step,
metadata=meta,
)
except Exception as e:
warnings.warn(f"Tried to log {image} exception raised: {e}", stacklevel=2)
[docs]class images_callback(ImagesCallback):
def __init__(self, savedir, **kwargs):
warnings.warn(
"Please use ImagesCallback instead.", DeprecationWarning, stacklevel=2
)
super().__init__(save_dir=savedir, **kwargs)