Tutorial: training a nest detection model
Install Comet ML
[ ]:
!pip install comet_ml
Install the DeepForest library
[ ]:
# Clone the DeepForest source repository into ./DeepForest (installed in the next cell).
!git clone https://github.com/weecology/DeepForest.git
[ ]:
%cd DeepForest
!pip install -e .
%cd ..
[ ]:
import os
import subprocess
import sys
import time
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.loggers import CometLogger
from tqdm import tqdm

# Put the freshly cloned repository on the import path *before* importing
# deepforest. An editable install registers its location through a .pth file,
# which is normally only read when the interpreter starts, so without this the
# import below can fail until the runtime is restarted.
deepforest_path = os.path.abspath("DeepForest")
print(deepforest_path)
if deepforest_path not in sys.path:
    sys.path.insert(0, deepforest_path)

from deepforest import main, preprocess  # noqa: E402 - must follow the sys.path update
[ ]:
# Prepend the cloned repository to the module search path, but only once.
if deepforest_path not in sys.path:
    sys.path[0:0] = [deepforest_path]
[ ]:
# load the modules
Set up environment variables
In Google Colab
Use Colab's secret storage to store your Comet API key securely.
Open the Secrets tab in the left-hand side panel of your Colab notebook.
Add a new secret named `COMET_API_KEY` and paste your Comet ML API key as its value.
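Locally
Set the environment variable `COMET_API_KEY` in your operating system, then start (or restart) Jupyter from a session that has it, so the notebook can read it with `os.getenv`.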
Windows
Open Command Prompt and set the environment variable permanently:
setx COMET_API_KEY "your_comet_api_key"
Then restart your terminal or IDE; `setx` only affects newly started sessions.
macOS/Linux
Open your terminal and add the following line to your `.bashrc`, `.zshrc`, or `.profile` file:
export COMET_API_KEY="your_comet_api_key"
Save the file and reload the shell configuration:
source ~/.bashrc # or ~/.zshrc, ~/.profile, etc.
[ ]:
# Where the notebook runs: "colab" reads the Comet key from Colab's secret store;
# any other value (e.g. "local") reads the COMET_API_KEY environment variable.
PLATFORM = "colab"  # Platform can be colab or local
environment = {}
if PLATFORM != "colab":
    environment["api_key"] = os.getenv("COMET_API_KEY")
else:
    from google.colab import userdata

    environment["api_key"] = userdata.get("COMET_API_KEY")
[10]:
# Pull the Comet API key out of the `environment` dict filled in by the previous cell.
api_key = environment["api_key"]
[ ]:
# Comet logger for this training run.
# Replace project_name with the name of your own Comet project.
comet_logger = CometLogger(
    api_key=api_key,
    project_name="temporary2",
)
Download the bird nest dataset
[ ]:
# Folder that receives the downloaded archive and the extracted dataset.
# Locally, fall back to the current working directory when ROOT_FOLDER is not
# set, so the os.path.join calls below never receive None.
root_folder = "/content" if PLATFORM == "colab" else os.environ.get("ROOT_FOLDER", os.getcwd())


def download_dataset(output_filename="Dataset.zip", extract_folder_name="dataset"):
    """
    Download the bird-nest dataset archive with `wget` and extract it.

    The archive is written to ``<root_folder>/<output_filename>`` and extracted
    into ``<root_folder>/<extract_folder_name>``, where ``root_folder`` is the
    module-level setting defined above.

    Args:
    - output_filename (str): Name of the downloaded archive file.
    - extract_folder_name (str): Name of the folder to extract the contents into.

    Raises:
    - RuntimeError: If `wget` exits with a non-zero status.
    - FileNotFoundError: If the downloaded zip file does not exist.

    Returns:
    - str: Path of the folder the archive was extracted into.
    """
    # dl=1 asks Dropbox for the raw file; dl=0 serves an HTML preview page,
    # which is not a valid zip archive.
    url = "https://www.dropbox.com/s/iczokehl2c5hcjx/nest_images.zip?dl=1"

    zip_file = os.path.join(root_folder, output_filename)
    extract_folder = os.path.join(root_folder, extract_folder_name)

    # Download straight to the path that is read below, so the result does not
    # depend on the current working directory.
    result = subprocess.run(
        ["wget", "-O", zip_file, url], capture_output=True, text=True
    )
    if result.returncode != 0:
        # `wget -O` creates the output file even when the download fails, so
        # carrying on would only surface later as a confusing BadZipFile error.
        raise RuntimeError(f"Download failed: {result.stderr}")
    print("Download complete.")

    if not os.path.exists(zip_file):
        raise FileNotFoundError(f"The zip file {zip_file} does not exist.")

    os.makedirs(extract_folder, exist_ok=True)
    # Extract member by member so tqdm can display progress.
    with zipfile.ZipFile(zip_file, "r") as zip_ref:
        for member in tqdm(zip_ref.namelist(), desc="Extracting", unit="files"):
            zip_ref.extract(member, extract_folder)
    print(f"Successfully unzipped {zip_file} to {extract_folder}.")
    return extract_folder
[ ]:
# Fetch and unpack the dataset; later cells read from the returned folder.
extract_folder = download_dataset(
    output_filename="Dataset.zip",
    extract_folder_name="dataset",
)
[ ]:
# Sanity check: the annotation CSV shipped inside the archive should now be on disk.
nest_csv_path = os.path.join(extract_folder, "nest_data.csv")
annotations = pd.read_csv(nest_csv_path)
annotations.head()
[17]:
# Gather the raw images whose names end with ".JPG" (case-sensitive match).
# Sorting makes the processing order, and therefore the later train/valid/test
# split, independent of the file system's arbitrary directory-listing order.
image_names = sorted(
    file_name for file_name in os.listdir(extract_folder) if file_name.endswith(".JPG")
)
[ ]:
# Cut every image into 400 px patches (5% overlap) with DeepForest's
# split_raster, using base_dir=crop_dir for the output, and collect the
# resulting per-crop annotations into one DataFrame.
crop_dir = os.path.join(os.getcwd(), "train_data_folder")
annotation_path = os.path.join(extract_folder, "nest_data.csv")
# Later cells also write the split CSVs into this folder, so make sure it exists.
os.makedirs(crop_dir, exist_ok=True)

if not image_names:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate".
    raise FileNotFoundError(f"No .JPG images found in {extract_folder}")

all_annotations = []
for image in image_names:
    image_path = os.path.join(extract_folder, image)
    # Use a dedicated name so the `annotations` DataFrame loaded earlier is not clobbered.
    crop_annotations = preprocess.split_raster(
        path_to_raster=image_path,
        annotations_file=annotation_path,
        patch_size=400,
        patch_overlap=0.05,
        base_dir=crop_dir,
    )
    all_annotations.append(crop_annotations)
train_annotations = pd.concat(all_annotations, ignore_index=True)
[21]:
# Split by crop image (not by annotation) into ~70% train, ~20% validation and
# ~10% test, so all annotations of one crop land in the same subset.
# One seeded permutation sliced into consecutive chunks guarantees the subsets
# are disjoint, have the intended sizes (np.random.choice samples WITH
# replacement by default), and are reproducible across runs.
image_paths = train_annotations.image_path.unique()
rng = np.random.default_rng(seed=42)
shuffled_paths = rng.permutation(image_paths)
n_holdout = int(len(image_paths) * 0.30)  # validation + test
n_valid = int(len(image_paths) * 0.20)
temp_paths = shuffled_paths[:n_holdout]
valid_paths = shuffled_paths[:n_valid]
test_paths = shuffled_paths[n_valid:n_holdout]
valid_annotations = train_annotations.loc[train_annotations.image_path.isin(valid_paths)]
test_annotations = train_annotations.loc[train_annotations.image_path.isin(test_paths)]
train_annotations = train_annotations.loc[~train_annotations.image_path.isin(temp_paths)]
[ ]:
# Inspect the split sizes, then save each subset as a CSV inside crop_dir (the
# folder holding the image crops), which the model config points at.
print(train_annotations.head())
print(f"There are {train_annotations.shape[0]} training nest annotations")
print(f"There are {valid_annotations.shape[0]} validation nest annotations")
print(f"There are {test_annotations.shape[0]} test nest annotations")

annotations_file = os.path.join(crop_dir, "train.csv")
validation_file = os.path.join(crop_dir, "valid.csv")
test_file = os.path.join(crop_dir, "test.csv")

# Files are written WITH a header row (pandas default); index=False avoids an
# extra unnamed index column.
train_annotations.to_csv(annotations_file, index=False)
valid_annotations.to_csv(validation_file, index=False)
test_annotations.to_csv(test_file, index=False)
[ ]:
# Build a single-class ("Nest") detector and point its config at the split CSVs.
m = main.deepforest(config_args={"label_dict": {"Nest": 0}, "num_classes": 1})
train_cfg = m.config["train"]
validation_cfg = m.config["validation"]

# Hardware: "-1" is meant to use all available GPUs.
# NOTE(review): newer DeepForest releases read "accelerator"/"devices" rather
# than "gpus"; confirm the installed version honours this key.
m.config["gpus"] = "-1"
m.config["score_thresh"] = 0.4

train_cfg["csv_file"] = annotations_file
train_cfg["root_dir"] = os.path.dirname(annotations_file)
train_cfg["scheduler"]["type"] = "cosine"  # learning-rate scheduler type
train_cfg["epochs"] = 10

validation_cfg["csv_file"] = validation_file
validation_cfg["root_dir"] = os.path.dirname(validation_file)
[ ]:
# Echo the scheduler type to confirm the "cosine" override from the previous cell was applied.
m.config["train"]["scheduler"]["type"]
[ ]:
# Create the PyTorch Lightning trainer used for training, logging to Comet.
# num_sanity_val_steps=0 disables the validation sanity check before training.
m.create_trainer(logger=comet_logger, num_sanity_val_steps=0)
# Load the latest released tree model (RetinaNet) as the starting weights.
# NOTE(review): in some DeepForest releases load_model() also replaces the
# model's label_dict with the pretrained one ({"Tree": 0}); confirm that the
# "Nest" label is still mapped correctly before training.
m.load_model("weecology/deepforest-tree")
[ ]:
# Train the model and report the elapsed wall-clock time.
# perf_counter is a monotonic clock, so the measurement is unaffected by
# system clock adjustments during a long run (time.time is not).
start_time = time.perf_counter()
m.trainer.fit(m)
elapsed_seconds = time.perf_counter() - start_time
# Neutral label: the run is not guaranteed to use a GPU.
print(f"--- Training time: {elapsed_seconds:.2f} seconds ---")
[ ]:
# Evaluate on the held-out test crops and save the prediction output to save_dir.
save_dir = os.path.join(os.getcwd(), "pred_result_test")
# Create the folder up front so writing into it cannot fail on a missing directory.
os.makedirs(save_dir, exist_ok=True)
results = m.evaluate(
    test_file, os.path.dirname(test_file), iou_threshold=0.4, savedir=save_dir
)
[ ]:
# Display the box precision reported by m.evaluate() on the test set.
results["box_precision"]
[ ]:
# Display the box recall reported by m.evaluate() on the test set.
results["box_recall"]
[30]:
# Persist the detailed evaluation table (results["results"]) as CSV, without the index.
results_table = results["results"]
results_table.to_csv("results_test_lr_cosine.csv", index=False)
[ ]:
# Write a PyTorch Lightning checkpoint of the trained model into root_folder.
checkpoint_path = os.path.join(root_folder, "checkpoint_epochs_10_cosine_lr_retinanet.pl")
m.trainer.save_checkpoint(checkpoint_path)
[ ]:
# Also save just the network's state_dict (weights only) next to the checkpoint.
weights_path = os.path.join(root_folder, "weights_cosine_lr")
torch.save(m.model.state_dict(), weights_path)
[ ]:
# Rebuild a deepforest model from the saved Lightning checkpoint file.
checkpoint_file = os.path.join(root_folder, "checkpoint_epochs_10_cosine_lr_retinanet.pl")
model = main.deepforest.load_from_checkpoint(checkpoint_file)
[ ]:
# Path to an image to test the model on; fill this in before running the cell.
path = ""
if not path:
    # Fail early with a clear message instead of an obscure error inside predict_tile.
    raise ValueError("Set `path` to an image file before running prediction.")

# Predict over overlapping 300 px windows and get back the image with boxes drawn on it.
predicted_image = model.predict_tile(
    path=path, return_plot=True, patch_size=300, patch_overlap=0.25
)
# Per the DeepForest docs, return_plot images use OpenCV's BGR channel order;
# reverse the channels so matplotlib (RGB) shows the true colours.
plt.imshow(predicted_image[:, :, ::-1])
plt.show()