Source code for deepforest.callbacks
"""A deepforest callback Callbacks must have the following methods
on_epoch_begin, on_epoch_end, on_fit_end, on_fit_begin methods and inject model
and epoch kwargs."""
from deepforest import visualize
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import glob
import tempfile
from pytorch_lightning import Callback
from deepforest import dataset
from deepforest import utilities
from deepforest import predict
import torch
[docs]
class images_callback(Callback):
    """Plot predictions for a sample of validation images during training.

    At the end of a validation epoch whose index is a multiple of
    ``every_n_epochs``, the callback takes the first entry of
    ``pl_module.predictions``, keeps the rows belonging to at most ``n`` image
    paths, draws them into ``savedir`` with
    ``visualize.plot_prediction_dataframe`` and then tries to upload every
    ``.png`` found in ``savedir`` through
    ``pl_module.logger.experiment.log_image``.

    Args:
        savedir: directory to save predicted images
        n: number of images to upload
        every_n_epochs: run epoch interval
        select_random (False): whether to select random images or the first n images
        color: color of the bounding box as a tuple of BGR color, e.g. orange annotations is (0, 165, 255)
        thickness: thickness of the rectangle border line in px
    """

    def __init__(self,
                 savedir,
                 n=2,
                 every_n_epochs=5,
                 select_random=False,
                 color=None,
                 thickness=1):
        super().__init__()
        self.savedir = savedir
        self.n = n
        self.color = color
        self.thickness = thickness
        self.select_random = select_random
        self.every_n_epochs = every_n_epochs

    def log_images(self, pl_module):
        """Plot predictions for the selected images and try to upload the plots.

        Args:
            pl_module: lightning module; this method reads ``predictions``,
                ``config["validation"]["root_dir"]`` and ``logger`` from it.

        Returns:
            None. A failure while uploading is printed, not raised.
        """
        # It is not clear if this is per device, or per batch. If per batch, then this will not work.
        df = pl_module.predictions[0]

        # limit to n images, potentially randomly selected
        image_paths = df.image_path.unique()
        if self.select_random:
            # np.random.choice samples WITH replacement by default, which can
            # pick the same image twice and yield fewer than n distinct images.
            # Sample without replacement, capped at the number available.
            sample_size = min(self.n, len(image_paths))
            selected_images = np.random.choice(image_paths,
                                               sample_size,
                                               replace=False)
        else:
            selected_images = image_paths[:self.n]

        df = df[df.image_path.isin(selected_images)]

        visualize.plot_prediction_dataframe(
            df,
            root_dir=pl_module.config["validation"]["root_dir"],
            savedir=self.savedir,
            color=self.color,
            thickness=self.thickness)

        # Uploading is best-effort: a logger without log_image (or no logger
        # at all) must not interrupt training, so any error is only reported.
        try:
            saved_plots = glob.glob("{}/*.png".format(self.savedir))
            for x in saved_plots:
                pl_module.logger.experiment.log_image(x)
        except Exception as e:
            print("Could not find comet logger in lightning module, "
                  "skipping upload, images were saved to {}, "
                  "error was rasied {}".format(self.savedir, e))

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run ``log_images`` every ``every_n_epochs`` validation epochs.

        Args:
            trainer: lightning trainer; ``sanity_checking`` and
                ``current_epoch`` are read from it.
            pl_module: lightning module passed through to ``log_images``.
        """
        if trainer.sanity_checking:  # optional skip
            return

        if trainer.current_epoch % self.every_n_epochs == 0:
            print("Running image callback")
            self.log_images(pl_module)