308 changes: 308 additions & 0 deletions examples/UnetOnRelis3D.ipynb
@@ -0,0 +1,308 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "IVxOn3wyEQ76"
},
"source": [
"# UNet on RELLIS-3D with DetectionMetrics\n",
"\n",
"This tutorial shows how to train a simple **UNet** model on the **RELLIS-3D dataset** and then **evaluate it** using the [DetectionMetrics](https://jderobot.github.io/DetectionMetrics/v2/) library. \n",
"\n",
"While training is included here for demonstration, the main focus of DetectionMetrics is **evaluation**. \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Installation\n",
"\n",
"First, install the required dependencies: **PyTorch**, **torchvision**, and **DetectionMetrics**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pip install torch torchvision\n",
"pip install detection-metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Imports\n",
"\n",
"We import PyTorch for model training and DetectionMetrics for dataset handling and evaluation.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4rrg2GbREPx9"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader\n",
"import torchvision.transforms as T\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from detection_metrics.datasets import Rellis3DImageSegmentationDataset\n",
"from detection_metrics.evaluators import SegmentationEvaluator\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Load RELLIS-3D Dataset\n",
"\n",
"DetectionMetrics provides a ready-to-use class `Rellis3DImageSegmentationDataset`. \n",
"Here we create **train** and **validation** splits, and apply basic transformations.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fcyZcTMDEVVV"
},
"outputs": [],
"source": [
"data_root = \"/path/to/rellis3d\" # TODO: replace with your dataset path\n",
"\n",
"transform = T.Compose([\n",
" T.ToTensor(),\n",
" T.Resize((256, 256)),\n",
"])\n",
"\n",
"train_dataset = Rellis3DImageSegmentationDataset(\n",
" root=data_root,\n",
" split=\"train\",\n",
" transforms=transform\n",
")\n",
"\n",
"val_dataset = Rellis3DImageSegmentationDataset(\n",
" root=data_root,\n",
" split=\"val\",\n",
" transforms=transform\n",
")\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)\n",
"val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)\n",
"\n",
"print(\"Train samples:\", len(train_dataset))\n",
"print(\"Val samples:\", len(val_dataset))\n"
]
},
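{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, we can inspect one batch before training. This is a minimal sketch; the exact shapes and dtypes depend on how `Rellis3DImageSegmentationDataset` applies the transforms.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect one batch: images should be (N, 3, H, W) floats, masks integer class IDs\n",
"imgs, masks = next(iter(train_loader))\n",
"print(\"Image batch:\", imgs.shape, imgs.dtype)\n",
"print(\"Mask batch: \", masks.shape, masks.dtype)\n",
"print(\"Classes present:\", torch.unique(masks))\n"
]
},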
{
"cell_type": "markdown",
"metadata": {
"id": "P0dkft82EkeN"
},
"source": [
"## 4. Define UNet Model\n",
"\n",
"We define a simple UNet architecture for semantic segmentation. \n",
"The final layer outputs `n_classes` channels (one for each class in the dataset).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pydGeenMErfZ"
},
"outputs": [],
"source": [
"class DoubleConv(nn.Module):\n",
" def __init__(self, in_channels, out_channels):\n",
" super(DoubleConv, self).__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Conv2d(in_channels, out_channels, 3, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" nn.Conv2d(out_channels, out_channels, 3, padding=1),\n",
" nn.ReLU(inplace=True),\n",
" )\n",
" def forward(self, x):\n",
" return self.net(x)\n",
"\n",
"class UNet(nn.Module):\n",
" def __init__(self, n_classes):\n",
" super(UNet, self).__init__()\n",
" self.enc1 = DoubleConv(3, 64)\n",
" self.pool = nn.MaxPool2d(2)\n",
" self.enc2 = DoubleConv(64, 128)\n",
" self.enc3 = DoubleConv(128, 256)\n",
"\n",
" self.bottleneck = DoubleConv(256, 512)\n",
"\n",
" self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)\n",
" self.dec3 = DoubleConv(512, 256)\n",
" self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)\n",
" self.dec2 = DoubleConv(256, 128)\n",
" self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)\n",
" self.dec1 = DoubleConv(128, 64)\n",
"\n",
" self.final = nn.Conv2d(64, n_classes, 1)\n",
"\n",
" def forward(self, x):\n",
" e1 = self.enc1(x)\n",
" e2 = self.enc2(self.pool(e1))\n",
" e3 = self.enc3(self.pool(e2))\n",
" b = self.bottleneck(self.pool(e3))\n",
" d3 = self.up3(b)\n",
" d3 = self.dec3(torch.cat([d3, e3], dim=1))\n",
" d2 = self.up2(d3)\n",
" d2 = self.dec2(torch.cat([d2, e2], dim=1))\n",
" d1 = self.up1(d2)\n",
" d1 = self.dec1(torch.cat([d1, e1], dim=1))\n",
" return self.final(d1)\n"
]
},
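{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, a quick forward pass with a dummy batch confirms the architecture wires up correctly (256 is divisible by 8, so the three pooling/upsampling stages align; `n_classes=5` below is just a placeholder).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Dummy forward pass: (1, 3, 256, 256) -> (1, n_classes, 256, 256) logits\n",
"dummy = torch.randn(1, 3, 256, 256)\n",
"print(UNet(n_classes=5)(dummy).shape)  # expected: torch.Size([1, 5, 256, 256])\n"
]
},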
{
"cell_type": "markdown",
"metadata": {
"id": "Y9bk1mBjExWa"
},
"source": [
"## 5. Training the Model\n",
"\n",
"We train UNet for a few epochs using **CrossEntropyLoss** and **Adam optimizer**.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SMlf-38xE0dT"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"num_classes = len(train_dataset.classes)\n",
"model = UNet(num_classes).to(device)\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
"\n",
"EPOCHS = 5\n",
"for epoch in range(EPOCHS):\n",
" model.train()\n",
" total_loss = 0\n",
" for imgs, masks in train_loader:\n",
" imgs, masks = imgs.to(device), masks.to(device)\n",
" optimizer.zero_grad()\n",
" outputs = model(imgs)\n",
" loss = criterion(outputs, masks)\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_loss += loss.item()\n",
" print(f\"Epoch [{epoch+1}/{EPOCHS}] Loss: {total_loss/len(train_loader):.4f}\")\n"
]
},
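{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optionally, save the trained weights so the evaluation below can be rerun without retraining (the filename is just an example).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Save the trained weights (illustrative filename)\n",
"torch.save(model.state_dict(), \"unet_rellis3d.pth\")\n"
]
},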
{
"cell_type": "markdown",
"metadata": {
"id": "a-MUDeDfE1qv"
},
"source": [
"## 6. Evaluation with DetectionMetrics\n",
"\n",
"Now we use `SegmentationEvaluator` from DetectionMetrics to compute metrics such as: \n",
"- **Mean Intersection over Union (mIoU)** \n",
"- **Pixel Accuracy** \n",
"- **Per-class metrics**\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KeogQzsqE5US"
},
"outputs": [],
"source": [
"evaluator = SegmentationEvaluator(num_classes=num_classes, class_names=train_dataset.classes)\n",
"\n",
"model.eval()\n",
"with torch.no_grad():\n",
" for imgs, masks in val_loader:\n",
" imgs, masks = imgs.to(device), masks.to(device)\n",
" outputs = model(imgs)\n",
" preds = torch.argmax(outputs, dim=1)\n",
" evaluator.add_batch(preds.cpu().numpy(), masks.cpu().numpy())\n",
"\n",
"results = evaluator.evaluate()\n",
"print(\"Evaluation Results:\")\n",
"for metric, value in results.items():\n",
" print(f\"{metric}: {value:.4f}\")\n"
]
},
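{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, mIoU and pixel accuracy can be derived from a confusion matrix as shown below. This is a self-contained NumPy sketch of the standard definitions, not DetectionMetrics' internal implementation.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def confusion_matrix(preds, targets, num_classes):\n",
"    # Rows index ground-truth classes, columns index predicted classes\n",
"    valid = (targets >= 0) & (targets < num_classes)\n",
"    idx = num_classes * targets[valid].astype(int) + preds[valid].astype(int)\n",
"    return np.bincount(idx, minlength=num_classes**2).reshape(num_classes, num_classes)\n",
"\n",
"def miou_and_pixel_acc(cm):\n",
"    tp = np.diag(cm)\n",
"    union = cm.sum(axis=0) + cm.sum(axis=1) - tp\n",
"    # Classes absent from both predictions and ground truth contribute IoU 0 here\n",
"    iou = tp / np.maximum(union, 1)\n",
"    return iou.mean(), tp.sum() / cm.sum()\n",
"\n",
"# Tiny worked example with 2 classes\n",
"cm = confusion_matrix(np.array([0, 1, 1, 0]), np.array([0, 1, 0, 0]), 2)\n",
"print(miou_and_pixel_acc(cm))  # (0.5833..., 0.75)\n"
]
},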
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 7. Visualizing Predictions\n",
"\n",
"Finally, let’s visualize some input images, their ground-truth masks, and the predicted segmentation maps.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"imgs, masks = next(iter(val_loader))\n",
"imgs = imgs.to(device)\n",
"outputs = model(imgs)\n",
"preds = torch.argmax(outputs, dim=1).cpu()\n",
"\n",
"plt.figure(figsize=(12,6))\n",
"for i in range(2):\n",
" plt.subplot(3, 2, i*2+1)\n",
" plt.imshow(imgs[i].permute(1,2,0).cpu())\n",
" plt.title(\"Input Image\")\n",
" plt.subplot(3, 2, i*2+2)\n",
" plt.imshow(preds[i])\n",
" plt.title(\"Predicted Mask\")\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ✅ Summary\n",
"\n",
"- We trained a UNet model on **RELLIS-3D**. \n",
"- More importantly, we used **DetectionMetrics** to evaluate it. \n",
"- The evaluation step is the main focus of DetectionMetrics and should always be included. \n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}