{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial for training a nest detection model\n",
    "\n",
    "This notebook installs DeepForest, downloads a bird nest image dataset, crops the images into training patches, splits the crops into train/validation/test sets, fine-tunes the pretrained DeepForest model while logging to Comet ML, and evaluates the result on the test split."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Install Comet ML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install comet_ml"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7EDk5oa1rphh"
   },
   "source": [
    "### Install DeepForest library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/weecology/DeepForest.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd DeepForest\n",
    "%pip install -e .\n",
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yJNcfr1LAzUp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "import time\n",
    "import zipfile\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from pytorch_lightning.loggers import CometLogger\n",
    "from tqdm import tqdm\n",
    "\n",
    "deepforest_path = os.path.abspath(\"DeepForest\")\n",
    "print(deepforest_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1Aabh_eYA3OV"
   },
   "outputs": [],
   "source": [
    "# Put the cloned repository on sys.path *before* deepforest is imported,\n",
    "# otherwise the path tweak has no effect on the import below\n",
    "if deepforest_path not in sys.path:\n",
    "    sys.path.insert(0, deepforest_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QiccWJYy9oK6"
   },
   "outputs": [],
   "source": [
    "# load the modules\n",
    "from deepforest import main, preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "PfFbezZn_m6X"
   },
   "source": [
    "### Set up Environment Variables\n",
    "\n",
    "#### In Google Colab\n",
    "Use Colab's secret storage to securely store your API key.\n",
    "\n",
    "1. Locate the `Secrets` tab on the left-hand side panel in your Colab notebook.\n",
    "2. Add a new secret with the key name as `COMET_API_KEY` and paste your Comet ML API key as the value.\n",
    "\n",
    "#### Locally\n",
    "Set an environment variable `COMET_API_KEY` in your operating system.\n",
    "Optionally also set `ROOT_FOLDER` to the directory where the dataset and checkpoints should be stored; the current working directory is used when it is not set.\n",
    "\n",
    "##### Windows\n",
    "1. Open Command Prompt and set the environment variable:\n",
    "\n",
    "    ```bash\n",
    "    setx COMET_API_KEY \"your_comet_api_key\"\n",
    "    ```\n",
    "\n",
    "2. Restart your terminal or IDE.\n",
    "\n",
    "##### macOS/Linux\n",
    "1. Open your terminal and add the following line to your `.bashrc`, `.zshrc`, or `.profile` file:\n",
    "\n",
    "    ```bash\n",
    "    export COMET_API_KEY=\"your_comet_api_key\"\n",
    "    ```\n",
    "\n",
    "2. Save the file and reload the shell configuration:\n",
    "\n",
    "    ```bash\n",
    "    source ~/.bashrc  # or ~/.zshrc, ~/.profile, etc.\n",
    "    ```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PLATFORM = \"colab\"  # Platform can be colab or local\n",
    "environment = {}\n",
    "if PLATFORM == \"colab\":\n",
    "    # google.colab only exists inside Colab, so it is imported conditionally\n",
    "    from google.colab import userdata\n",
    "\n",
    "    environment[\"api_key\"] = userdata.get(\"COMET_API_KEY\")\n",
    "else:\n",
    "    environment[\"api_key\"] = os.getenv(\"COMET_API_KEY\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sSsyUgWt_Wf6"
   },
   "outputs": [],
   "source": [
    "api_key = environment[\"api_key\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AKU3QyhKO4rl"
   },
   "outputs": [],
   "source": [
    "# change the project_name\n",
    "comet_logger = CometLogger(project_name=\"temporary2\", api_key=api_key)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aqIGYsao9CY4"
   },
   "source": [
    "### Download the Bird nest dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that holds the downloaded archive, the extracted dataset and the checkpoints.\n",
    "# Locally it falls back to the working directory when ROOT_FOLDER is not set.\n",
    "if PLATFORM == \"colab\":\n",
    "    root_folder = \"/content\"\n",
    "else:\n",
    "    root_folder = os.environ.get(\"ROOT_FOLDER\", os.getcwd())\n",
    "\n",
    "\n",
    "def download_dataset(output_filename=\"Dataset.zip\", extract_folder_name=\"dataset\"):\n",
    "    \"\"\"\n",
    "    Download the nest image archive with 'wget' and extract it under `root_folder`.\n",
    "\n",
    "    Args:\n",
    "    - output_filename (str): Name of the downloaded zip file, stored in `root_folder`.\n",
    "    - extract_folder_name (str): Name of the folder inside `root_folder` to extract\n",
    "      the contents into.\n",
    "\n",
    "    Raises:\n",
    "    - RuntimeError: If 'wget' exits with a non-zero status.\n",
    "    - FileNotFoundError: If the downloaded zip file does not exist.\n",
    "\n",
    "    Returns:\n",
    "    str: Path of the folder the archive was extracted into.\n",
    "    \"\"\"\n",
    "    # dl=1 asks Dropbox for the file itself instead of the preview page\n",
    "    url = \"https://www.dropbox.com/s/iczokehl2c5hcjx/nest_images.zip?dl=1\"\n",
    "\n",
    "    # Paths for zip file and extraction folder\n",
    "    os.makedirs(root_folder, exist_ok=True)\n",
    "    zip_file = os.path.join(root_folder, output_filename)\n",
    "    extract_folder = os.path.join(root_folder, extract_folder_name)\n",
    "\n",
    "    # Download straight to zip_file so the path written is the path checked below,\n",
    "    # whatever the current working directory is\n",
    "    result = subprocess.run(\n",
    "        [\"wget\", \"-O\", zip_file, url], capture_output=True, text=True\n",
    "    )\n",
    "\n",
    "    # Stop here on failure: 'wget -O' leaves an empty file behind, which would\n",
    "    # otherwise surface later as a confusing zip error\n",
    "    if result.returncode != 0:\n",
    "        raise RuntimeError(f\"Download failed: {result.stderr}\")\n",
    "    print(\"Download complete.\")\n",
    "\n",
    "    # Check if the zip file exists\n",
    "    if not os.path.exists(zip_file):\n",
    "        raise FileNotFoundError(f\"The zip file {zip_file} does not exist.\")\n",
    "\n",
    "    # Create the extract folder if it doesn't exist\n",
    "    os.makedirs(extract_folder, exist_ok=True)\n",
    "\n",
    "    # Extract member by member so tqdm can show progress\n",
    "    with zipfile.ZipFile(zip_file, \"r\") as zip_ref:\n",
    "        for file in tqdm(zip_ref.namelist(), desc=\"Extracting\", unit=\"files\"):\n",
    "            zip_ref.extract(file, extract_folder)\n",
    "\n",
    "    print(f\"Successfully unzipped {zip_file} to {extract_folder}.\")\n",
    "    return extract_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "extract_folder = download_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Crop the images and split the annotations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qxyenEzA9vHs"
   },
   "outputs": [],
   "source": [
    "# Check if the annotations file has been extracted from the zip file\n",
    "annotation_path = os.path.join(extract_folder, \"nest_data.csv\")\n",
    "annotations = pd.read_csv(annotation_path)\n",
    "annotations.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BgELotoq9zNS"
   },
   "outputs": [],
   "source": [
    "# Gather all the images ending with .JPG (sorted, because os.listdir order is arbitrary)\n",
    "image_names = sorted(\n",
    "    file for file in os.listdir(extract_folder) if file.endswith(\".JPG\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5t7pSf-293rf"
   },
   "outputs": [],
   "source": [
    "# Generate crops of the image which has Region of Interest (ROI)\n",
    "crop_dir = os.path.join(os.getcwd(), \"train_data_folder\")\n",
    "annotation_path = os.path.join(extract_folder, \"nest_data.csv\")\n",
    "all_annotations = []\n",
    "for image in image_names:\n",
    "    image_path = os.path.join(extract_folder, image)\n",
    "    image_crop_annotations = preprocess.split_raster(\n",
    "        path_to_raster=image_path,\n",
    "        annotations_file=annotation_path,\n",
    "        patch_size=400,\n",
    "        patch_overlap=0.05,\n",
    "        base_dir=crop_dir,\n",
    "    )\n",
    "    all_annotations.append(image_crop_annotations)\n",
    "\n",
    "# Annotations of every crop; kept under its own name so the split cell below\n",
    "# can be re-run without shrinking its own input\n",
    "crop_annotations = pd.concat(all_annotations, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sBoIIt6B-G8B"
   },
   "outputs": [],
   "source": [
    "SEED = 42  # fixed seed so the split is reproducible\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "image_paths = crop_annotations.image_path.unique()\n",
    "\n",
    "# split into 70% train, 20% validation and 10% test annotations.\n",
    "# One shuffle sliced into disjoint parts guarantees that no crop is drawn twice\n",
    "# and that no crop lands in more than one split.\n",
    "shuffled_paths = rng.permutation(image_paths)\n",
    "n_valid = int(len(image_paths) * 0.20)\n",
    "n_test = int(len(image_paths) * 0.10)\n",
    "valid_paths = shuffled_paths[:n_valid]\n",
    "test_paths = shuffled_paths[n_valid : n_valid + n_test]\n",
    "held_out_paths = shuffled_paths[: n_valid + n_test]\n",
    "\n",
    "valid_annotations = crop_annotations.loc[crop_annotations.image_path.isin(valid_paths)]\n",
    "test_annotations = crop_annotations.loc[crop_annotations.image_path.isin(test_paths)]\n",
    "train_annotations = crop_annotations.loc[\n",
    "    ~crop_annotations.image_path.isin(held_out_paths)\n",
    "]\n",
    "\n",
    "# Every annotation must end up in exactly one split\n",
    "assert len(train_annotations) + len(valid_annotations) + len(test_annotations) == len(\n",
    "    crop_annotations\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8HGyeEcB-JGO"
   },
   "outputs": [],
   "source": [
    "# View output\n",
    "print(train_annotations.head())\n",
    "print(f\"There are {train_annotations.shape[0]} training nest annotations\")\n",
    "print(f\"There are {valid_annotations.shape[0]} validation nest annotations\")\n",
    "print(f\"There are {test_annotations.shape[0]} test nest annotations\")\n",
    "\n",
    "# save to file and create the file dir\n",
    "annotations_file = os.path.join(crop_dir, \"train.csv\")\n",
    "validation_file = os.path.join(crop_dir, \"valid.csv\")\n",
    "test_file = os.path.join(crop_dir, \"test.csv\")\n",
    "\n",
    "# Write the split annotation files (header row included, index dropped) to the\n",
    "# same location as the \"base_dir\" above.\n",
    "train_annotations.to_csv(annotations_file, index=False)\n",
    "valid_annotations.to_csv(validation_file, index=False)\n",
    "test_annotations.to_csv(test_file, index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l58loc8xyww5"
   },
   "outputs": [],
   "source": [
    "# initialize the model and change the corresponding config file\n",
    "m = main.deepforest(config_args={\"label_dict\": {\"Nest\": 0}})\n",
    "\n",
    "# move to GPU and use all the GPU resources\n",
    "# NOTE(review): confirm the installed DeepForest release still reads \"gpus\";\n",
    "# newer releases may expect \"accelerator\"/\"devices\" instead\n",
    "m.config[\"gpus\"] = \"-1\"\n",
    "m.config[\"train\"][\"csv_file\"] = annotations_file\n",
    "m.config[\"train\"][\"root_dir\"] = os.path.dirname(annotations_file)\n",
    "\n",
    "# Define the learning scheduler type\n",
    "m.config[\"train\"][\"scheduler\"][\"type\"] = \"cosine\"\n",
    "m.config[\"score_thresh\"] = 0.4\n",
    "m.config[\"train\"][\"epochs\"] = 10\n",
    "m.config[\"validation\"][\"csv_file\"] = validation_file\n",
    "m.config[\"validation\"][\"root_dir\"] = os.path.dirname(validation_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4Lv2tD9MNnuv"
   },
   "outputs": [],
   "source": [
    "m.config[\"train\"][\"scheduler\"][\"type\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OE-5NB4c-LNw"
   },
   "outputs": [],
   "source": [
    "# create a PyTorch Lightning trainer used for training\n",
    "# Disable the sanity check for validation data\n",
    "m.create_trainer(logger=comet_logger, num_sanity_val_steps=0)\n",
    "# load the latest release model (RetinaNet)\n",
    "m.load_model(\"weecology/deepforest-tree\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S7ew6GOcC11U"
   },
   "outputs": [],
   "source": [
    "# Start the training\n",
    "start_time = time.time()\n",
    "m.trainer.fit(m)\n",
    "print(f\"--- Training on GPU: {(time.time() - start_time):.2f} seconds ---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate and save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kh6MdWdwDDup"
   },
   "outputs": [],
   "source": [
    "# save the prediction result to a prediction folder\n",
    "save_dir = os.path.join(os.getcwd(), \"pred_result_test\")\n",
    "results = m.evaluate(\n",
    "    test_file, os.path.dirname(test_file), iou_threshold=0.4, savedir=save_dir\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AYIaznaZeji_"
   },
   "outputs": [],
   "source": [
    "results[\"box_precision\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yzFmrnYoen4r"
   },
   "outputs": [],
   "source": [
    "results[\"box_recall\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zq6j3wXLepGt"
   },
   "outputs": [],
   "source": [
    "# save the results to a csv file\n",
    "results[\"results\"].to_csv(\"results_test_lr_cosine.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "08w7RL8EfPe8"
   },
   "outputs": [],
   "source": [
    "# Save the model checkpoint\n",
    "checkpoint_path = os.path.join(\n",
    "    root_folder, \"checkpoint_epochs_10_cosine_lr_retinanet.pl\"\n",
    ")\n",
    "m.trainer.save_checkpoint(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9wRLVBHchOki"
   },
   "outputs": [],
   "source": [
    "# Also save the bare network weights (state dict only)\n",
    "torch.save(m.model.state_dict(), os.path.join(root_folder, \"weights_cosine_lr\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "q2oL7HgXhTdW"
   },
   "outputs": [],
   "source": [
    "# Load from the saved checkpoint\n",
    "model = main.deepforest.load_from_checkpoint(\n",
    "    os.path.join(root_folder, \"checkpoint_epochs_10_cosine_lr_retinanet.pl\")\n",
    ")"
   ]
  },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "amhOl20SFHUB" }, "outputs": [], "source": [ "# Add a path to an image to test the model on\n", "path = \"\"\n", "predicted_image = model.predict_tile(\n", " path=path, return_plot=True, patch_size=300, patch_overlap=0.25\n", ")\n", "plt.imshow(predicted_image)\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }