Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions requirements-sam2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Requirements for SAM2 Tree Detection
torch>=2.1.0
torchvision>=0.16.0
opencv-python>=4.8.0
numpy>=1.24.0
sam2>=1.0
rasterio>=1.3.9
geopandas>=0.14.0
scipy>=1.11.0
tqdm>=4.66.0
matplotlib>=3.8.0
scikit-image>=0.21.0
1 change: 1 addition & 0 deletions tools/sam2_tree_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Intentionally left blank
197 changes: 197 additions & 0 deletions tools/sam2_tree_detection/tree_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import torch
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from pathlib import Path
from typing import List, Tuple, Optional
import rasterio
import geopandas as gpd
from shapely.geometry import Polygon

class TreeDetector:
    """Detect tree crowns in aerial rasters with a SAM2 image predictor.

    The detector prompts SAM2 with a regular grid of foreground points,
    keeps the masks whose predicted IoU exceeds a threshold and returns
    them as georeferenced polygons.
    """

    # Maps the SAM-v1 style names used by the plugin UI onto the
    # (config suffix, checkpoint suffix) pairs of the SAM 2.1 Hiera releases.
    # SAM 2.1 ships no "huge" variant, so ``vit_h`` selects the largest one.
    _MODEL_VARIANTS = {
        "vit_h": ("l", "large"),
        "vit_l": ("l", "large"),
        "vit_b": ("b+", "base_plus"),
    }

    def __init__(self, model_type: str = "vit_h", device: str = "cuda",
                 checkpoint_dir: str = "./checkpoints"):
        """Initialize the TreeDetector with a SAM2 model.

        Args:
            model_type: Model size selector (vit_h, vit_l, vit_b). Any other
                value falls back to the legacy rule of using its last
                character as the config/checkpoint suffix.
            device: Requested device ("cuda" or "cpu"). "cuda" silently
                degrades to "cpu" when CUDA is not available.
            checkpoint_dir: Directory that holds the ``sam2.1_hiera_*.pt``
                checkpoint files.

        Raises:
            ValueError: If ``model_type`` is empty.
        """
        if not model_type:
            raise ValueError("model_type must be a non-empty string")

        self.device = device if torch.cuda.is_available() and device == "cuda" else "cpu"

        legacy_suffix = model_type[-1]
        cfg_suffix, ckpt_suffix = self._MODEL_VARIANTS.get(
            model_type, (legacy_suffix, legacy_suffix)
        )
        model_cfg = f"configs/sam2.1/sam2.1_hiera_{cfg_suffix}.yaml"

        # Prefer the official checkpoint file name; keep accepting the short
        # legacy name (e.g. ``sam2.1_hiera_l.pt``) if only that one exists.
        checkpoint_root = Path(checkpoint_dir)
        official = checkpoint_root / f"sam2.1_hiera_{ckpt_suffix}.pt"
        legacy = checkpoint_root / f"sam2.1_hiera_{legacy_suffix}.pt"
        checkpoint = legacy if (legacy.is_file() and not official.is_file()) else official

        # Pass the *resolved* device: build_sam2 otherwise falls back to its
        # own default device, and the raw ``device`` argument may name CUDA
        # on a CPU-only host.
        self.model = build_sam2(model_cfg, str(checkpoint), device=self.device)
        self.model.to(self.device)
        self.predictor = SAM2ImagePredictor(self.model)

    def load_image(self, image_path: str) -> np.ndarray:
        """Load a raster as an H x W x 3 array for the predictor.

        Extra bands beyond the first three are dropped; rasters with fewer
        than three bands get their first band replicated to three channels.

        Args:
            image_path: Path to a raster readable by rasterio.

        Returns:
            Array of shape (height, width, 3) in the raster's own dtype.
        """
        with rasterio.open(image_path) as src:
            image = src.read().transpose(1, 2, 0)  # (bands, H, W) -> (H, W, bands)

        band_count = image.shape[2]
        if band_count > 3:
            image = image[:, :, :3]  # Use only the first three (RGB) channels
        elif band_count < 3:
            image = np.repeat(image[:, :, :1], 3, axis=2)
        # NOTE(review): the predictor presumably expects 8-bit RGB values;
        # 16-bit / float imagery is passed through unchanged -- confirm.
        return image

    def fine_tune(self,
                  train_images: List[str],
                  train_masks: List[str],
                  epochs: int = 10,
                  learning_rate: float = 1e-5):
        """Fine-tune SAM2 on tree examples.

        Samples whose mask contains no connected component are skipped,
        because they yield no prompt points.

        Args:
            train_images: List of paths to training images
            train_masks: List of paths to corresponding tree masks (.npy)
            epochs: Number of training epochs
            learning_rate: Learning rate for fine-tuning

        Raises:
            ValueError: If the two path lists differ in length.
            RuntimeError: If the computed loss is not attached to the
                autograd graph, in which case no training is possible.
        """
        if len(train_images) != len(train_masks):
            raise ValueError(
                "train_images and train_masks must have the same length "
                f"({len(train_images)} != {len(train_masks)})"
            )

        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        self.model.train()
        try:
            for epoch in range(epochs):
                total_loss = 0.0
                used_samples = 0
                for img_path, mask_path in zip(train_images, train_masks):
                    image = self.load_image(img_path)
                    mask = np.load(mask_path)

                    # Generate prompts from mask; nothing to learn without them
                    points = self._generate_points_from_mask(mask)
                    if len(points) == 0:
                        continue
                    point_labels = np.ones(len(points))

                    # Embed image and predict one mask from all prompts
                    self.predictor.set_image(image)
                    masks, _iou_predictions, _low_res_masks = self.predictor.predict(
                        point_coords=points,
                        point_labels=point_labels,
                        multimask_output=False,
                    )

                    loss = self._calculate_loss(masks, mask)
                    if not loss.requires_grad:
                        # predict() hands back NumPy arrays, so the loss has
                        # no autograd history; fail with an explicit message
                        # rather than inside backward().
                        raise RuntimeError(
                            "Loss is detached from the model graph: "
                            "SAM2ImagePredictor.predict() returns NumPy arrays, "
                            "so gradients cannot flow back to the model."
                        )

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    used_samples += 1

                mean_loss = total_loss / used_samples if used_samples else float("nan")
                print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss}")
        finally:
            # Always leave the model in inference mode, even after an error.
            self.model.eval()

    def detect_trees(self,
                     image_path: str,
                     confidence_threshold: float = 0.93,
                     min_area: int = 100) -> "gpd.GeoDataFrame":
        """Detect trees in aerial image and return as GeoDataFrame.

        Every grid point is sent to the predictor as its own single-point
        prompt, so each point yields its own candidate mask.

        Args:
            image_path: Path to aerial image
            confidence_threshold: Minimum predicted IoU for a mask to be kept
            min_area: Minimum polygon area, in square pixels

        Returns:
            GeoDataFrame with a ``confidence`` column and tree polygons in
            the raster's coordinate system (CRS set when the raster has one).
        """
        from shapely.affinity import affine_transform

        image = self.load_image(image_path)
        with rasterio.open(image_path) as src:
            transform, crs = src.transform, src.crs
        height, width = image.shape[:2]

        # Shapely coefficient order is [a, b, d, e, xoff, yoff]. Contour
        # coordinates index pixel centres, hence the half-pixel shift before
        # applying the raster transform (assumed anchored at pixel corners).
        to_world = [
            transform.a, transform.b, transform.d, transform.e,
            transform.c + 0.5 * (transform.a + transform.b),
            transform.f + 0.5 * (transform.d + transform.e),
        ]

        self.predictor.set_image(image)
        points = self._generate_grid_points((height, width))

        polygons = []
        confidences = []

        # Process points in batches; masks are converted batch by batch so
        # that only one batch of full-resolution masks is alive at a time.
        batch_size = 64
        for start in range(0, len(points), batch_size):
            batch_points = points[start:start + batch_size]

            # (B, 1, 2) coords with (B, 1) labels: B independent objects with
            # one positive point each, instead of one object with B points.
            batch_masks, batch_iou, _ = self.predictor.predict(
                point_coords=batch_points[:, None, :],
                point_labels=np.ones((len(batch_points), 1), dtype=np.int32),
                multimask_output=False,
            )
            # The predictor drops the batch axis when B == 1; normalise.
            batch_masks = np.asarray(batch_masks).reshape(-1, height, width)
            batch_iou = np.asarray(batch_iou).reshape(-1)

            for mask, score in zip(batch_masks, batch_iou):
                if score <= confidence_threshold:
                    continue
                pixel_polygon = self._mask_to_polygon(mask)
                if pixel_polygon is None or pixel_polygon.area <= min_area:
                    continue
                polygons.append(affine_transform(pixel_polygon, to_world))
                confidences.append(float(score))

        # NOTE(review): neighbouring grid points that fall on the same crown
        # produce near-duplicate polygons; no de-duplication is done here.
        return gpd.GeoDataFrame(
            {"confidence": confidences},
            geometry=polygons,
            crs=crs.to_wkt() if crs else None,
        )

    def _generate_points_from_mask(self, mask: np.ndarray) -> np.ndarray:
        """Generate prompt points from the centroids of mask components.

        Args:
            mask: 2-D array; non-zero pixels are foreground.

        Returns:
            Float array of shape (n_components, 2) holding (x, y) centroids,
            shape (0, 2) when the mask has no foreground.
        """
        from scipy import ndimage

        labeled_mask, num_features = ndimage.label(mask)
        if num_features == 0:
            return np.empty((0, 2), dtype=float)

        # Uniform weights reproduce the centroid of each component's pixels.
        centroids = ndimage.center_of_mass(
            np.ones_like(labeled_mask, dtype=float),
            labeled_mask,
            np.arange(1, num_features + 1),
        )
        # center_of_mass yields (row, col); prompts are expected as (x, y).
        return np.asarray(centroids, dtype=float)[:, ::-1].copy()

    def _generate_grid_points(self, shape: Tuple[int, int],
                              spacing: int = 32) -> np.ndarray:
        """Generate a regular grid of (x, y) prompt points.

        Args:
            shape: Image size as (height, width).
            spacing: Distance between neighbouring grid points, in pixels.

        Returns:
            Integer array of shape (n_points, 2); the first point sits at
            ``spacing // 2`` on both axes.

        Raises:
            ValueError: If ``spacing`` is not positive.
        """
        if spacing <= 0:
            raise ValueError("spacing must be a positive integer")
        y, x = np.mgrid[spacing // 2:shape[0]:spacing,
                        spacing // 2:shape[1]:spacing]
        return np.column_stack((x.ravel(), y.ravel()))

    def _mask_to_polygon(self, mask: np.ndarray) -> "Optional[Polygon]":
        """Convert a binary mask to the polygon of its longest contour.

        Args:
            mask: 2-D mask; values above 0.5 are foreground.

        Returns:
            Polygon in pixel coordinates with x = column and y = row, or
            None when no valid polygon can be built.
        """
        from skimage import measure

        # Zero-pad by one pixel so that regions touching the image border
        # still produce closed contours.
        padded = np.pad(np.asarray(mask, dtype=float), 1, mode="constant")
        contours = measure.find_contours(padded, 0.5)
        if not contours:
            return None

        # Find largest contour (by vertex count)
        contour = max(contours, key=len)
        if len(contour) < 3:
            return None

        # (row, col) in the padded frame -> (x, y) in the original frame
        xy = contour[:, ::-1] - 1.0
        try:
            polygon = Polygon(xy)
        except (ValueError, TypeError):
            return None

        if polygon.is_valid and not polygon.is_empty:
            return polygon
        return None

    def _calculate_loss(self, pred_masks: torch.Tensor,
                        true_mask: np.ndarray) -> torch.Tensor:
        """Calculate the Dice loss between predicted and true masks.

        Args:
            pred_masks: Predicted mask(s) as a tensor or NumPy array,
                broadcastable against ``true_mask``.
            true_mask: Ground-truth mask.

        Returns:
            Scalar tensor ``1 - dice`` on ``self.device``.
        """
        # as_tensor accepts both tensors and arrays and, for tensors, keeps
        # any autograd history intact.
        true_t = torch.as_tensor(true_mask, dtype=torch.float32, device=self.device)
        pred_t = torch.as_tensor(pred_masks, device=self.device).float()

        # Dice loss, smoothed so two empty masks do not divide by zero
        intersection = (pred_t * true_t).sum()
        union = pred_t.sum() + true_t.sum()
        dice_loss = 1 - (2.0 * intersection + 1e-6) / (union + 1e-6)

        return dice_loss
135 changes: 108 additions & 27 deletions tools/widgetTool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from qgis.PyQt import QtCore
from qgis.PyQt.QtCore import Qt, QThread, pyqtSignal
from qgis.PyQt.QtGui import QColor, QKeySequence
from qgis.PyQt.QtWidgets import QApplication, QDockWidget, QFileDialog, QShortcut
from qgis.PyQt.QtWidgets import QApplication, QDockWidget, QFileDialog, QShortcut, QPushButton
from rasterio.windows import from_bounds as window_from_bounds
from torchgeo.datasets import BoundingBox
from torchgeo.samplers import Units
Expand All @@ -43,6 +43,7 @@
from .messageTool import MessageTool
from .SAMTool import SAM_Model
from .torchgeo_sam import SamTestGridGeoSampler, SamTestRasterDataset
from .sam2_tree_detection.tree_detector import TreeDetector

SAM_Model_Types_Full: List[str] = ["vit_h (huge)", "vit_l (large)", "vit_b (base)"]
SAM_Model_Types = [i.split(" ")[0].strip() for i in SAM_Model_Types_Full]
Expand Down Expand Up @@ -207,6 +208,14 @@ def open_widget(self):
self.toggle_sam_hover_mode
)

# Add SAM2 Tree Detection button
self.pushButton_detect_trees = QPushButton("Detect Trees (SAM2)")
self.pushButton_detect_trees.setToolTip(
"Automatically detect trees in a raster layer using SAM2."
)
self.wdg_sel.vl_group_tools.addWidget(self.pushButton_detect_trees)
self.pushButton_detect_trees.clicked.connect(self.run_tree_detection)

# threshold of area
self.wdg_sel.Box_min_pixel.valueChanged.connect(self.filter_feature_by_area)
self.wdg_sel.Box_min_pixel_default.valueChanged.connect(
Expand Down Expand Up @@ -393,30 +402,25 @@ def reset_to_project_crs(self):
self.project.setCrs(self.crs_project)

def destruct(self):
# NOTE(review): this span holds two docstring-style strings and two teardown
# sequences, which looks like the removed and the added side of a diff shown
# together -- confirm against the real file before relying on it.
# Read literally it clears the layers (including the extent), restores the
# project CRS, triggers the pan action and then delegates to unload().
"""Destruct actions when closed widget"""
self.clear_layers(clear_extent=True)
self.reset_to_project_crs()
self.iface.actionPan().trigger()
"""Destructor"""
self.unload()
# self.wdg_sel.closed.disconnect(self.destruct) # This will be destroyed, no need to disconnect

def unload(self):
# NOTE(review): two docstring-style strings and two teardown sequences appear
# here, presumably the removed and the added side of a diff shown together --
# confirm against the real file.
# Read literally, this method calls self.destruct() while destruct() calls
# self.unload(), i.e. the two would recurse into each other.
"""Unload actions when plugin is closed"""
self.clear_layers(clear_extent=True)
# First sequence: disconnect each keyboard shortcut that was created, then
# the layer combo box signal, and remove the dock widget.
if hasattr(self, "shortcut_tab"):
self.disconnect_safely(self.shortcut_tab)
if hasattr(self, "shortcut_undo_sam_pg"):
self.disconnect_safely(self.shortcut_undo_sam_pg)
if hasattr(self, "shortcut_clear"):
self.disconnect_safely(self.shortcut_clear)
if hasattr(self, "shortcut_undo"):
self.disconnect_safely(self.shortcut_undo)
if hasattr(self, "shortcut_save"):
self.disconnect_safely(self.shortcut_save)
if hasattr(self, "shortcut_hover_mode"):
self.disconnect_safely(self.shortcut_hover_mode)
if hasattr(self, "wdg_sel"):
self.disconnect_safely(self.wdg_sel.MapLayerComboBox.layerChanged)
self.iface.removeDockWidget(self.wdg_sel)
self.destruct()
"""Unload the plugin"""
# Second sequence: remove the four canvas items from the scene and stop
# reacting to layers being removed from the project.
# NOTE(review): this sequence neither disconnects the shortcuts nor removes
# the dock widget (that call is commented out below) -- confirm intended.
self.canvas.scene().removeItem(self.canvas_points_fg)
self.canvas.scene().removeItem(self.canvas_points_bg)
self.canvas.scene().removeItem(self.canvas_rect)
self.canvas.scene().removeItem(self.canvas_extent)
self.project.layerWillBeRemoved.disconnect(self.clear_layers)

# Drop focus from the prompt buttons and trigger the pan action.
self.wdg_sel.pushButton_fg.clearFocus()
self.wdg_sel.pushButton_bg.clearFocus()
self.wdg_sel.pushButton_rect.clearFocus()
self.iface.actionPan().trigger()
# self.iface.removeDockWidget(self.wdg_sel)
# self.parent.toolbar.setVisible(False)
self.is_plugin_on = False

def load_demo_img(self):
layer_list = QgsProject.instance().mapLayersByName(self.demo_img_name)
Expand Down Expand Up @@ -1212,6 +1216,70 @@ def reset_all_styles(self):
self.reset_prompt_polygon_color()
self.reset_preview_polygon_color()

def run_tree_detection(self):
    """
    Slot for the 'Detect Trees (SAM2)' button.

    Asks the user for a raster file, runs the SAM2 tree detector on it,
    writes the detections to a GeoPackage under ``<cwd>/output`` and adds
    that file to the project as a vector layer. Progress, empty results
    and failures are reported through the QGIS message bar.
    """
    raster_file, _selected_filter = QFileDialog.getOpenFileName(
        self,
        "Select Raster Image for Tree Detection",
        "",
        "Images (*.tif *.tiff *.jp2 *.png *.jpg)",
    )

    # An empty path means the dialog was dismissed.
    if not raster_file:
        self.iface.messageBar().pushMessage(
            "Info",
            "Tree detection cancelled.",
            level=MessageTool.MessageLevel.Info,
            duration=3,
        )
        return

    try:
        self.iface.messageBar().pushMessage(
            "Info",
            "Initializing SAM2 Tree Detector...",
            level=MessageTool.MessageLevel.Info,
            duration=5,
        )
        tree_detector = TreeDetector(model_type='vit_h', device='cuda')

        # Longer-lived message because detection can take a while.
        progress_text = (
            f"Running tree detection on {os.path.basename(raster_file)}. "
            "This may take a while..."
        )
        self.iface.messageBar().pushMessage(
            "Info",
            progress_text,
            level=MessageTool.MessageLevel.Info,
            duration=10,
        )
        # Let Qt repaint so the message shows before the blocking call.
        QApplication.processEvents()

        # Thresholds are looser than the detector defaults: lower confidence
        # for more recall, smaller minimum area for small trees.
        detections = tree_detector.detect_trees(
            image_path=raster_file,
            confidence_threshold=0.90,
            min_area=50,
        )

        if detections.empty:
            self.iface.messageBar().pushMessage(
                "Warning",
                "No trees found in the selected image.",
                level=MessageTool.MessageLevel.Warning,
                duration=5,
            )
            return

        # Persist the result as a GeoPackage named after the input raster.
        results_dir = self.cwd / "output"
        results_dir.mkdir(exist_ok=True)
        raster_stem = Path(raster_file).stem
        gpkg_path = results_dir / f"detected_trees_{raster_stem}.gpkg"
        detections.to_file(str(gpkg_path), driver='GPKG')

        # Load the written file into QGIS through the OGR provider.
        self.iface.addVectorLayer(
            str(gpkg_path), f"Detected_Trees_{raster_stem}", "ogr"
        )

        self.iface.messageBar().pushMessage(
            "Success",
            f"Successfully detected {len(detections)} trees. Layer added to project.",
            level=MessageTool.MessageLevel.Success,
            duration=5,
        )
    except Exception as err:
        self.iface.messageBar().pushMessage(
            "Error",
            f"An error occurred during tree detection: {err}",
            level=MessageTool.MessageLevel.Critical,
            duration=10,
        )


class EncoderCopilot(QDockWidget):
# TODO: support encoding process in this widget
Expand Down Expand Up @@ -1678,9 +1746,22 @@ def reset_canvas(self):
self.canvas_extent.clear()

def destruct(self):
# NOTE(review): two docstring-style strings and two bodies appear here,
# presumably the removed and the added side of a diff shown together --
# confirm against the real file.
# Read literally it resets the canvas and then delegates to unload().
"""Destruct actions when closed widget"""
self.reset_canvas()
"""Destructor"""
self.unload()
# self.wdg_sel.closed.disconnect(self.destruct) # This will be destroyed, no need to disconnect

def unload(self):
# NOTE(review): two docstring-style strings and two bodies appear here,
# presumably the removed and the added side of a diff shown together --
# confirm against the real file.
"""Unload actions when plugin is closed"""
self.reset_canvas()
"""Unload the plugin"""
# Removes four canvas items from the scene and disconnects
# layerWillBeRemoved from self.clear_layers.
# NOTE(review): the visible code of this class only shows reset_canvas() and
# canvas_extent; confirm that canvas_points_fg / canvas_points_bg /
# canvas_rect, self.project, self.clear_layers, self.wdg_sel and self.iface
# actually exist on this class.
self.canvas.scene().removeItem(self.canvas_points_fg)
self.canvas.scene().removeItem(self.canvas_points_bg)
self.canvas.scene().removeItem(self.canvas_rect)
self.canvas.scene().removeItem(self.canvas_extent)
self.project.layerWillBeRemoved.disconnect(self.clear_layers)

# Drop focus from the prompt buttons and trigger the pan action.
self.wdg_sel.pushButton_fg.clearFocus()
self.wdg_sel.pushButton_bg.clearFocus()
self.wdg_sel.pushButton_rect.clearFocus()
self.iface.actionPan().trigger()
# self.iface.removeDockWidget(self.wdg_sel)
# self.parent.toolbar.setVisible(False)
self.is_plugin_on = False