# Source code for deepforest.model

# Model - common class
from deepforest.models import *
import torch


class Model:
    """An architecture-agnostic class that controls the basic train, eval
    and predict functions.

    A model should optionally allow a backbone for pretraining. To add a new
    architecture, create a new module in models/ and implement
    ``create_model`` so that it returns a pytorch nn module. This base
    class's ``create_model`` only raises. The constructor immediately
    validates the produced model with ``check_model``.

    Args:
        config: configuration stored as ``self.config``. Presumably it carries
            entries such as ``num_classes`` (number of classes in the model),
            ``nms_thresh`` (non-max suppression IoU threshold, [0, 1]) and
            ``score_thresh`` (minimum prediction score kept, [0, 1]).
            TODO confirm against the config schema.

    Raises:
        ValueError: if ``create_model`` is not implemented, or if the model it
            returns does not pass ``check_model``.
    """

    def __init__(self, config):
        # Check for required properties and formats.
        self.config = config
        # Check the model's input/output format before it is used.
        self.check_model()

    def create_model(self):
        """Convert a deepforest config into a model.

        An architecture should have a list of nested arguments in config that
        match this function. Subclasses must override this method and return
        a pytorch nn module.

        Raises:
            ValueError: always, in this base class.
        """
        raise ValueError(
            "The create_model class method needs to be implemented. Take in args and return a pytorch nn module."
        )

    def check_model(self):
        """Ensure the model built by ``create_model`` follows deepforest
        guidelines.

        The model is created, switched to eval mode and run, without gradient
        tracking, on a dummy batch of two 3-band images of different sizes.
        It must return one prediction per input image. Every prediction must
        expose exactly the keys ``boxes``, ``labels`` and ``scores``.

        Raises:
            ValueError: if the model output does not meet these requirements.
        """
        # This assumes model creation is not expensive.
        test_model = self.create_model()
        test_model.eval()

        # Dummy batch of two 3-band images with different spatial sizes.
        x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]

        # Validation is inference only; skip building an autograd graph.
        with torch.no_grad():
            predictions = test_model(x)

        # The model takes in a batch of images and must return one
        # prediction per image.
        if len(predictions) != len(x):
            raise ValueError(
                f"Model returned {len(predictions)} predictions for a batch of "
                f"{len(x)} images; expected one prediction per image."
            )

        # Every prediction (not just one of them) must carry the required keys.
        expected_keys = ["boxes", "labels", "scores"]
        for index, prediction in enumerate(predictions):
            try:
                model_keys = sorted(prediction.keys())
            except AttributeError as err:
                raise ValueError(
                    f"Prediction {index} is a {type(prediction).__name__}; "
                    f"expected a dict with keys {expected_keys}."
                ) from err
            if model_keys != expected_keys:
                raise ValueError(
                    f"Prediction {index} has keys {model_keys}; "
                    f"expected exactly {expected_keys}."
                )