diff --git a/examples/fuse_multiple_tracks.py b/examples/fuse_multiple_tracks.py
new file mode 100644
index 000000000..de93dc36e
--- /dev/null
+++ b/examples/fuse_multiple_tracks.py
@@ -0,0 +1,285 @@
+"""Fuse multiple tracking sources
+===================================
+
+Demonstrate how to combine tracking data from multiple sources to
+produce a more accurate trajectory. This is particularly useful when
+different tracking methods fail in different situations, such as with
+ID swaps.
+"""
+
+# %%
+# Imports
+# -------
+
+from matplotlib import pyplot as plt
+
+from movement import sample_data
+from movement.io import load_poses
+from movement.plots import plot_centroid_trajectory
+from movement.track_fusion import fuse_tracks
+
+# %%
+# Load sample datasets
+# --------------------
+# We'll load DeepLabCut and SLEAP predictions for the same mouse in an
+# EPM (Elevated Plus Maze) experiment. The DeepLabCut data is considered
+# more reliable, while the SLEAP data was generated using a model
+# trained on less data.
+
+# DeepLabCut data (considered more reliable)
+dlc_path = sample_data.fetch_dataset_paths(
+    "DLC_single-mouse_EPM.predictions.h5"
+)["poses"]
+ds_dlc = load_poses.from_dlc_file(dlc_path, fps=30)
+
+# SLEAP data (considered less reliable)
+sleap_path = sample_data.fetch_dataset_paths(
+    "SLEAP_single-mouse_EPM.analysis.h5"
+)["poses"]
+ds_sleap = load_poses.from_sleap_file(sleap_path, fps=30)
+
+# %%
+# Inspect the datasets
+# --------------------
+# Let's look at the available keypoints in each dataset.
+
+print("DeepLabCut keypoints:", ds_dlc.keypoints.values)
+print("SLEAP keypoints:", ds_sleap.keypoints.values)
+
+# %%
+# The two datasets might have different keypoints, so we'll focus on the
+# centroid. If "centroid" doesn't exist in one of the datasets, we would
+# need to compute it from other keypoints or choose a different keypoint
+# common to both datasets.
+
+# %%
+# Visualize the tracking from the individual sources
+# ----------------------------------------------------
+# First let's plot the centroid trajectory from both sources separately.
+
+# Create a figure with two subplots side by side
+fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+
+# Plot DLC trajectory
+plot_centroid_trajectory(ds_dlc.position, ax=axes[0])
+axes[0].set_title("DeepLabCut Tracking")
+axes[0].invert_yaxis()  # Invert y-axis to match image coordinates
+
+# Plot SLEAP trajectory
+plot_centroid_trajectory(ds_sleap.position, ax=axes[1])
+axes[1].set_title("SLEAP Tracking")
+axes[1].invert_yaxis()
+
+fig.tight_layout()
+
+# %%
+# Fuse tracks using different methods
+# -----------------------------------
+# Now we'll combine the tracks using different fusion methods and
+# compare the results, after a quick sanity check on the keypoints
+# (see the sketch below).
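+
+# %%
+# The calls below assume that a ``centroid`` keypoint exists in both
+# datasets (``fuse_tracks`` raises a ``ValueError`` otherwise). As a
+# purely illustrative sketch, not required by the fusion API, we could
+# guard against a missing centroid by averaging over all keypoints,
+# which assumes every keypoint is equally informative about the body
+# centre:
+
+for name, ds in {"DeepLabCut": ds_dlc, "SLEAP": ds_sleap}.items():
+    if "centroid" not in ds.keypoints.values:
+        print(f"{name}: no 'centroid' keypoint; averaging all keypoints")
+        # The mean over the keypoints dimension gives a centroid-like
+        # (time, space) track that could be fused instead
+        centroid_track = ds.position.mean(dim="keypoints", skipna=True)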
+
+# List of methods to try
+methods = ["mean", "median", "weighted", "reliability", "kalman"]
+
+# Create a figure with 8 subplots (4 rows, 2 columns): the first two
+# show the original tracks, the next five show the fused tracks, and
+# the unused last one is hidden
+fig, axes = plt.subplots(4, 2, figsize=(12, 20))
+axes = axes.flatten()
+
+# Plot the original tracks in the first two subplots
+plot_centroid_trajectory(ds_dlc.position, ax=axes[0])
+axes[0].set_title("Original: DeepLabCut")
+axes[0].invert_yaxis()
+
+plot_centroid_trajectory(ds_sleap.position, ax=axes[1])
+axes[1].set_title("Original: SLEAP")
+axes[1].invert_yaxis()
+
+# Fuse and plot the tracks with each method
+for i, method in enumerate(methods, 2):
+    # Set weights for the weighted method (0.7 for DLC, 0.3 for SLEAP)
+    weights = [0.7, 0.3] if method == "weighted" else None
+
+    # Fuse the tracks
+    fused_track = fuse_tracks(
+        datasets=[ds_dlc, ds_sleap],
+        method=method,
+        keypoint="centroid",
+        weights=weights,
+        print_report=True,
+    )
+
+    # Plot the fused track
+    plot_centroid_trajectory(fused_track, ax=axes[i])
+    axes[i].set_title(f"Fused: {method.capitalize()}")
+    axes[i].invert_yaxis()
+
+# Hide the unused eighth subplot
+axes[-1].set_visible(False)
+
+fig.tight_layout()
+
+# %%
+# Detailed Comparison: Kalman Filter Fusion
+# -------------------------------------------
+# Let's take a closer look at the Kalman filter method, which often
+# provides good results for trajectory data.
+
+# Create a new figure
+plt.figure(figsize=(10, 8))
+
+# Fuse tracks with the Kalman filter
+kalman_fused = fuse_tracks(
+    datasets=[ds_dlc, ds_sleap],
+    method="kalman",
+    keypoint="centroid",
+    process_noise_scale=0.01,  # Controls smoothness of the trajectory
+    measurement_noise_scales=[
+        0.1,
+        0.3,
+    ],  # Lower values for more reliable sources
+    print_report=True,
+)
+
+# Plot trajectories from both sources and the fused result
+plt.plot(
+    ds_dlc.position.sel(keypoints="centroid", space="x"),
+    ds_dlc.position.sel(keypoints="centroid", space="y"),
+    "b.",
+    alpha=0.5,
+    label="DeepLabCut",
+)
+plt.plot(
+    ds_sleap.position.sel(keypoints="centroid", space="x"),
+    ds_sleap.position.sel(keypoints="centroid", space="y"),
+    "g.",
+    alpha=0.5,
+    label="SLEAP",
+)
+plt.plot(
+    kalman_fused.sel(space="x"),
+    kalman_fused.sel(space="y"),
+    "r-",
+    linewidth=2,
+    label="Kalman Fused",
+)
+
+plt.gca().invert_yaxis()
+plt.grid(True, alpha=0.3)
+plt.legend()
+plt.title("Comparison of Original Tracks and Kalman-Fused Track")
+plt.xlabel("X Position (pixels)")
+plt.ylabel("Y Position (pixels)")
+
+# %%
+# Temporal Analysis: Plotting Coordinate Values Over Time
+# ---------------------------------------------------------
+# Let's look at how the x-coordinate values change over time for the
+# different sources.
+
+# Create a new figure
+plt.figure(figsize=(12, 6))
+
+# Plot x-coordinate over time
+time_values = kalman_fused.time.values
+
+plt.plot(
+    time_values,
+    ds_dlc.position.sel(keypoints="centroid", space="x"),
+    "b-",
+    alpha=0.5,
+    label="DeepLabCut",
+)
+plt.plot(
+    time_values,
+    ds_sleap.position.sel(keypoints="centroid", space="x"),
+    "g-",
+    alpha=0.5,
+    label="SLEAP",
+)
+plt.plot(
+    time_values,
+    kalman_fused.sel(space="x"),
+    "r-",
+    linewidth=2,
+    label="Kalman Fused",
+)
+
+plt.grid(True, alpha=0.3)
+plt.legend()
+plt.title("X-Coordinate Values Over Time")
+plt.xlabel("Time (seconds)")
+plt.ylabel("X Position (pixels)")
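+
+# %%
+# Before moving on, it can be instructive to quantify where the two
+# sources disagree, since those are the frames where fusion matters
+# most. This is an illustrative sketch, not part of the fusion API; it
+# assumes both tracks share the same pixel coordinate system and time
+# axis.
+
+diff = ds_dlc.position.sel(keypoints="centroid") - ds_sleap.position.sel(
+    keypoints="centroid"
+)
+# Euclidean distance between the two centroid estimates at each frame
+inter_source_distance = ((diff**2).sum(dim="space")) ** 0.5
+
+plt.figure(figsize=(12, 4))
+plt.plot(inter_source_distance.time, inter_source_distance)
+plt.xlabel("Time (seconds)")
+plt.ylabel("Distance (pixels)")
+plt.title("Per-frame disagreement between DeepLabCut and SLEAP")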
+
+# %%
+# Multiple-Animal Tracking Example with Potential ID Swaps
+# ----------------------------------------------------------
+# Now let's look at a more complex example with multiple animals,
+# where ID swaps might be an issue.
+# For this, we'll use the SLEAP datasets for three mice.
+
+# Load the two SLEAP datasets with three mice
+ds_proofread = sample_data.fetch_dataset(
+    "SLEAP_three-mice_Aeon_proofread.analysis.h5"
+)
+ds_mixed = sample_data.fetch_dataset(
+    "SLEAP_three-mice_Aeon_mixed-labels.analysis.h5"
+)
+
+print("Proofread dataset individuals:", ds_proofread.individuals.values)
+print("Mixed-labels dataset individuals:", ds_mixed.individuals.values)
+
+# %%
+# For each individual in the dataset, fuse the tracks from both sources.
+
+# Create a figure for comparing original and fused tracks
+fig, axes = plt.subplots(2, 3, figsize=(15, 10))
+
+# Flatten axes for easier iteration
+axes = axes.flatten()
+
+# Plot the original tracks for each mouse in the first row
+for i, individual in enumerate(ds_proofread.individuals.values):
+    if i < 3:  # First row
+        # Plot original trajectory from proofread dataset (more reliable)
+        pos = ds_proofread.position.sel(individuals=individual)
+        plot_centroid_trajectory(pos, ax=axes[i])
+        axes[i].set_title(f"Original: {individual}")
+        axes[i].invert_yaxis()
+
+# Fuse and plot the tracks for each mouse in the second row
+for i, individual in enumerate(ds_proofread.individuals.values):
+    if i < 3:  # We have 3 mice
+        # Get the single-individual datasets
+        individual_ds_proofread = ds_proofread.sel(individuals=individual)
+        individual_ds_mixed = ds_mixed.sel(individuals=individual)
+
+        # Fuse the tracks with the Kalman filter
+        # (can be replaced with other methods)
+        fused_track = fuse_tracks(
+            datasets=[individual_ds_proofread, individual_ds_mixed],
+            method="kalman",
+            keypoint="centroid",
+            # Trust the proofread dataset more, i.e. give it a lower
+            # measurement noise scale
+            measurement_noise_scales=[0.1, 0.3],
+            print_report=False,
+        )
+
+        # Plot the fused track
+        plot_centroid_trajectory(fused_track, ax=axes[i + 3])
+        axes[i + 3].set_title(f"Fused: {individual}")
+        axes[i + 3].invert_yaxis()
+
+fig.tight_layout()
+
+# %%
+# Conclusions
+# -----------
+# We've demonstrated several methods for combining tracking data from
+# multiple sources:
+#
+# 1. **Mean**: Simple averaging of all valid measurements.
+# 2. **Median**: More robust to outliers than the mean.
+# 3. **Weighted**: Weighted average based on source reliability.
+# 4. **Reliability-based**: Selects the most reliable source at each
+#    time point.
+# 5. **Kalman filter**: Probabilistic approach that models position and
+#    velocity.
+#
+# The Kalman filter often provides the best results as it can:
+#
+# - Handle noisy measurements from multiple sources
+# - Model the dynamics of movement (position and velocity)
+# - Provide smooth trajectories that follow physical constraints
+# - Handle missing data and uncertainty in measurements
+#
+# For multi-animal tracking with potential ID swaps, track fusion can be
+# particularly valuable. By combining information from different
+# tracking methods that may fail in different situations, we can produce
+# more accurate trajectories across time.
diff --git a/movement/__init__.py b/movement/__init__.py
index bf5d4a2d2..1ec99216f 100644
--- a/movement/__init__.py
+++ b/movement/__init__.py
@@ -15,3 +15,6 @@
 # initialize logger upon import
 configure_logging()
+
+# Make track fusion functionality available at the top level
+from movement.track_fusion import fuse_tracks  # noqa: E402
diff --git a/movement/track_fusion.py b/movement/track_fusion.py
new file mode 100644
index 000000000..1fe7f8786
--- /dev/null
+++ b/movement/track_fusion.py
@@ -0,0 +1,678 @@
+"""Combine tracking data from multiple sources.
+ +This module provides functions for combining tracking data from multiple +sources to produce more accurate trajectories. This is particularly useful +in cases where different tracking methods may fail in different situations, +such as in multi-animal tracking with ID swaps. +""" + +import logging +from enum import Enum, auto + +import numpy as np +import xarray as xr +from scipy.signal import medfilt + +from movement.filtering import interpolate_over_time +from movement.utils.logging import log_error, log_to_attrs +from movement.utils.reports import report_nan_values + +logger = logging.getLogger(__name__) + + +class FusionMethod(Enum): + """Enumeration of available track fusion methods.""" + + MEAN = auto() + MEDIAN = auto() + WEIGHTED = auto() + RELIABILITY_BASED = auto() + KALMAN = auto() + + +@log_to_attrs +def align_datasets( + datasets: list[xr.Dataset], + keypoint: str = "centroid", + interpolate: bool = True, + max_gap: int | None = 5, +) -> list[xr.DataArray]: + """Aligns multiple datasets to have the same time coordinates. + + Parameters + ---------- + datasets : list of xarray.Dataset + List of datasets containing position data to align. + keypoint : str, optional + The keypoint to extract from each dataset, by default "centroid". + interpolate : bool, optional + Whether to interpolate missing values after alignment, by default True. + max_gap : int, optional + Maximum size of gap to interpolate, by default 5. + + Returns + ------- + list of xarray.DataArray + List of aligned DataArrays containing only the specified keypoint position data. + + Notes + ----- + This function extracts the specified keypoint from each dataset, aligns them to + have the same time coordinates, and optionally interpolates missing values. + + """ + if not datasets: + raise log_error(ValueError, "No datasets provided") + + # Extract the keypoint position data from each dataset + position_arrays = [] + for ds in datasets: + # Check if keypoint exists in this dataset + if "keypoints" in ds.dims and keypoint not in ds.keypoints.values: + available_keypoints = list(ds.keypoints.values) + raise log_error( + ValueError, + f"Keypoint '{keypoint}' not found in dataset. " + f"Available keypoints: {available_keypoints}", + ) + + # Extract position for this keypoint + if "keypoints" in ds.dims: + pos = ds.position.sel(keypoints=keypoint) + else: + # Handle datasets without keypoints dimension + pos = ds.position + + position_arrays.append(pos) + + # Get union of all time coordinates + all_times = sorted( + set().union(*[set(arr.time.values) for arr in position_arrays]) + ) + + # Reindex all arrays to the common time coordinate + aligned_arrays = [] + for arr in position_arrays: + reindexed = arr.reindex(time=all_times) + + # Optionally interpolate missing values + if interpolate: + reindexed = interpolate_over_time(reindexed, max_gap=max_gap) + + aligned_arrays.append(reindexed) + + return aligned_arrays + + +@log_to_attrs +def fuse_tracks_mean( + aligned_tracks: list[xr.DataArray], + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks by taking the mean across all sources. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values averaged across sources. + + Notes + ----- + This function computes the mean of all valid position values at each time point. 
+    If all sources have NaN at a particular time point, the result will
+    also be NaN.
+
+    """
+    if not aligned_tracks:
+        raise log_error(ValueError, "No tracks provided")
+
+    # Stack all tracks along a new 'source' dimension
+    stacked = xr.concat(aligned_tracks, dim="source")
+
+    # Take the mean along the source dimension, ignoring NaNs
+    fused = stacked.mean(dim="source", skipna=True)
+
+    if print_report:
+        print(report_nan_values(fused, "Fused track (mean)"))
+
+    return fused
+
+
+@log_to_attrs
+def fuse_tracks_median(
+    aligned_tracks: list[xr.DataArray],
+    print_report: bool = False,
+) -> xr.DataArray:
+    """Fuse tracks by taking the median across all sources.
+
+    Parameters
+    ----------
+    aligned_tracks : list of xarray.DataArray
+        List of aligned position DataArrays.
+    print_report : bool, optional
+        Whether to print a report on the number of NaNs in the result,
+        by default False.
+
+    Returns
+    -------
+    xarray.DataArray
+        Fused track with position values being the median across sources.
+
+    Notes
+    -----
+    This function computes the median of all valid position values at
+    each time point. If all sources have NaN at a particular time point,
+    the result will also be NaN. This method is more robust to outliers
+    than the mean method.
+
+    """
+    if not aligned_tracks:
+        raise log_error(ValueError, "No tracks provided")
+
+    # Stack all tracks along a new 'source' dimension
+    stacked = xr.concat(aligned_tracks, dim="source")
+
+    # Take the median along the source dimension, ignoring NaNs
+    fused = stacked.median(dim="source", skipna=True)
+
+    if print_report:
+        print(report_nan_values(fused, "Fused track (median)"))
+
+    return fused
+
+
+@log_to_attrs
+def fuse_tracks_weighted(
+    aligned_tracks: list[xr.DataArray],
+    weights: list[float] | None = None,
+    confidence_arrays: list[xr.DataArray] | None = None,
+    print_report: bool = False,
+) -> xr.DataArray:
+    """Fuse tracks using a weighted average.
+
+    Parameters
+    ----------
+    aligned_tracks : list of xarray.DataArray
+        List of aligned position DataArrays.
+    weights : list of float, optional
+        Static weights for each track source. Must sum to 1 if provided.
+        If not provided and confidence_arrays is also None, equal
+        weights are used.
+    confidence_arrays : list of xarray.DataArray, optional
+        Dynamic confidence values for each track. Must match the shape
+        of aligned_tracks. If provided, these are used instead of static
+        weights.
+    print_report : bool, optional
+        Whether to print a report on the number of NaNs in the result,
+        by default False.
+
+    Returns
+    -------
+    xarray.DataArray
+        Fused track with position values weighted by the specified
+        weights or confidence values.
+
+    Notes
+    -----
+    This function computes a weighted average of position values.
+    Weights can be either:
+
+    - Static (one weight per source)
+    - Dynamic (a confidence value for each position at each time point)
+
+    If both weights and confidence_arrays are provided,
+    confidence_arrays takes precedence.
+ + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + n_tracks = len(aligned_tracks) + + # Check and prepare weights + if weights is not None: + if len(weights) != n_tracks: + raise log_error( + ValueError, + f"Number of weights ({len(weights)}) does not match " + f"number of tracks ({n_tracks})", + ) + if abs(sum(weights) - 1.0) > 1e-10: + raise log_error( + ValueError, f"Weights must sum to 1, got sum={sum(weights)}" + ) + else: + # Equal weights if nothing is provided + weights = [1.0 / n_tracks] * n_tracks + + # Use dynamic confidence arrays if provided + if confidence_arrays is not None: + if len(confidence_arrays) != n_tracks: + raise log_error( + ValueError, + f"Number of confidence arrays ({len(confidence_arrays)}) does not match " + f"number of tracks ({n_tracks})", + ) + + # Normalize confidence values per time point + # Stack all confidence arrays along a 'source' dimension + stacked_conf = xr.concat(confidence_arrays, dim="source") + + # Calculate sum of confidences at each time point + sum_conf = stacked_conf.sum(dim="source") + + # Handle zeros by replacing with equal weights + has_zeros = sum_conf == 0 + norm_conf = stacked_conf / sum_conf + norm_conf = norm_conf.where(~has_zeros, 1.0 / n_tracks) + + # Apply confidence-weighted average + stacked_pos = xr.concat(aligned_tracks, dim="source") + weighted_pos = stacked_pos * norm_conf + fused = weighted_pos.sum(dim="source", skipna=True) + + else: + # Apply static weights + weighted_tracks = [ + track * weight + for track, weight in zip(aligned_tracks, weights, strict=False) + ] + + # Stack and sum along a new 'source' dimension + stacked = xr.concat(weighted_tracks, dim="source") + + # Calculate where all tracks are NaN + all_nan = xr.concat( + [track.isnull() for track in aligned_tracks], dim="source" + ).all(dim="source") + + # Sum along source dimension, set result to NaN where all sources are NaN + fused = stacked.sum(dim="source", skipna=True).where(~all_nan) + + if print_report: + print(report_nan_values(fused, "Fused track (weighted average)")) + + return fused + + +@log_to_attrs +def fuse_tracks_reliability( + aligned_tracks: list[xr.DataArray], + reliability_metrics: list[float] = None, + window_size: int = 11, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks by selecting the most reliable source at each time point. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + reliability_metrics : list of float, optional + Global reliability score for each source (higher is better). + If not provided, NaN count is used as an inverse reliability metric. + window_size : int, optional + Window size for filtering the selection of sources, by default 11. + Must be an odd number. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values taken from the most reliable source at each time. + + Notes + ----- + This function selects values from the most reliable source at each time point, + then applies a median filter to avoid rapid switching between sources, which + could create unrealistic jumps in the trajectory. 
+ + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + if window_size % 2 == 0: + raise log_error(ValueError, "Window size must be an odd number") + + n_tracks = len(aligned_tracks) + + # Determine track reliability if not provided + if reliability_metrics is None: + # Count NaNs in each track (fewer NaNs = more reliable) + nan_counts = [ + float(track.isnull().sum().values) for track in aligned_tracks + ] + total_values = float(aligned_tracks[0].size) + # Convert to a reliability score (inverse of NaN proportion) + reliability_metrics = [ + 1.0 - (count / total_values) for count in nan_counts + ] + + # Stack all tracks along a new 'source' dimension + stacked = xr.concat(aligned_tracks, dim="source") + + # For each time point, create a selection array based on reliability and NaN status + time_points = stacked.time.values + selected_sources = np.zeros(len(time_points), dtype=int) + + # Loop through each time point + for i, t in enumerate(time_points): + values_at_t = [track.sel(time=t).values for track in aligned_tracks] + is_nan = [np.isnan(val).any() for val in values_at_t] + + # If all sources have NaN, pick the most reliable one anyway + if all(is_nan): + selected_sources[i] = np.argmax(reliability_metrics) + else: + # Filter out NaN sources + valid_indices = [ + idx for idx, nan_status in enumerate(is_nan) if not nan_status + ] + valid_reliability = [ + reliability_metrics[idx] for idx in valid_indices + ] + + # Select the most reliable valid source + best_valid_idx = valid_indices[np.argmax(valid_reliability)] + selected_sources[i] = best_valid_idx + + # Apply median filter to smooth source selection and avoid rapid switching + if window_size > 1 and len(time_points) > window_size: + selected_sources = medfilt(selected_sources, window_size) + + # Create the fused track by selecting values from the chosen source at each time + fused_data = np.zeros((len(time_points), stacked.sizes["space"])) + + for i, (t, source_idx) in enumerate( + zip(time_points, selected_sources, strict=False) + ): + fused_data[i] = stacked.sel(time=t, source=source_idx).values + + # Create a new DataArray with the fused data + fused = xr.DataArray( + data=fused_data, + dims=["time", "space"], + coords={"time": time_points, "space": stacked.space.values}, + ) + + if print_report: + print(report_nan_values(fused, "Fused track (reliability-based)")) + + return fused + + +@log_to_attrs +def fuse_tracks_kalman( + aligned_tracks: list[xr.DataArray], + process_noise_scale: float = 0.01, + measurement_noise_scales: list[float] = None, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks using a Kalman filter. + + Parameters + ---------- + aligned_tracks : list of xarray.DataArray + List of aligned position DataArrays. + process_noise_scale : float, optional + Scale factor for the process noise covariance, by default 0.01. + measurement_noise_scales : list of float, optional + Scale factors for measurement noise for each source. + Lower values indicate more reliable sources. Default is equal values. + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values estimated by the Kalman filter. + + Notes + ----- + This function implements a simple Kalman filter for track fusion. The filter: + 1. Models position and velocity in a state vector + 2. Predicts the next state based on constant velocity assumptions + 3. 
Updates the prediction using measurements from all available sources + 4. Handles missing measurements (NaNs) by skipping the update step + + The Kalman filter is particularly effective for trajectory smoothing and + handling noisy measurements from multiple sources. + + """ + if not aligned_tracks: + raise log_error(ValueError, "No tracks provided") + + n_tracks = len(aligned_tracks) + + # Set default measurement noise scales if not provided + if measurement_noise_scales is None: + measurement_noise_scales = [1.0] * n_tracks + + if len(measurement_noise_scales) != n_tracks: + raise log_error( + ValueError, + f"Number of measurement noise scales ({len(measurement_noise_scales)}) " + f"does not match number of tracks ({n_tracks})", + ) + + # Get the common time axis + time_points = aligned_tracks[0].time.values + n_timesteps = len(time_points) + + # Get the dimensionality of the space (2D or 3D) + n_dims = len(aligned_tracks[0].space.values) + + # Initialize state vector [x, y, vx, vy] or [x, y, z, vx, vy, vz] + state_dim = 2 * n_dims + state = np.zeros(state_dim) + + # Initialize state covariance matrix + state_cov = np.eye(state_dim) + + # Define transition matrix (constant velocity model) + dt = 1.0 # Assuming unit time steps + A = np.eye(state_dim) + for i in range(n_dims): + A[i, i + n_dims] = dt + + # Define process noise covariance + Q = np.eye(state_dim) * process_noise_scale + + # Define measurement matrix (extracts position from state) + H = np.zeros((n_dims, state_dim)) + for i in range(n_dims): + H[i, i] = 1.0 + + # Initialize storage for Kalman filter output + kalman_output = np.zeros((n_timesteps, n_dims)) + + # For the first time step, initialize with the average of available measurements + first_measurements = [] + for track in aligned_tracks: + pos = track.sel(time=time_points[0]).values + if not np.isnan(pos).any(): + first_measurements.append(pos) + + if first_measurements: + initial_pos = np.mean(first_measurements, axis=0) + state[:n_dims] = initial_pos + kalman_output[0] = initial_pos + + # Run Kalman filter + for t in range(1, n_timesteps): + # Prediction step + state = A @ state + state_cov = A @ state_cov @ A.T + Q + + # Update step - combine all available measurements + measurements = [] + R_list = [] # Measurement noise covariances + + for i, track in enumerate(aligned_tracks): + pos = track.sel(time=time_points[t]).values + if not np.isnan(pos).any(): + measurements.append(pos) + # Measurement noise covariance for this source + R = np.eye(n_dims) * measurement_noise_scales[i] + R_list.append(R) + + # Skip update if no measurements available + if not measurements: + kalman_output[t] = state[:n_dims] + continue + + # Apply update for each measurement + for z, R in zip(measurements, R_list, strict=False): + y = z - H @ state # Measurement residual + S = H @ state_cov @ H.T + R # Residual covariance + K = state_cov @ H.T @ np.linalg.inv(S) # Kalman gain + state = state + K @ y # Updated state + state_cov = ( + np.eye(state_dim) - K @ H + ) @ state_cov # Updated covariance + + # Store the updated position + kalman_output[t] = state[:n_dims] + + # Create a new DataArray with the Kalman filter output + fused = xr.DataArray( + data=kalman_output, + dims=["time", "space"], + coords={"time": time_points, "space": aligned_tracks[0].space.values}, + ) + + if print_report: + print(report_nan_values(fused, "Fused track (Kalman filter)")) + + return fused + + +@log_to_attrs +def fuse_tracks( + datasets: list[xr.Dataset], + method: str | FusionMethod = "kalman", + keypoint: str 
= "centroid", + interpolate_gaps: bool = True, + max_gap: int = 5, + weights: list[float] = None, + confidence_arrays: list[xr.DataArray] = None, + reliability_metrics: list[float] = None, + window_size: int = 11, + process_noise_scale: float = 0.01, + measurement_noise_scales: list[float] = None, + print_report: bool = False, +) -> xr.DataArray: + """Fuse tracks from multiple datasets using the specified method. + + Parameters + ---------- + datasets : list of xarray.Dataset + List of datasets containing position data to fuse. + method : str or FusionMethod, optional + Track fusion method to use, by default "kalman". Options are: + - "mean": Average position across all sources + - "median": Median position across all sources (robust to outliers) + - "weighted": Weighted average using static weights or confidence values + - "reliability": Select most reliable source at each time point + - "kalman": Apply Kalman filter to estimate the optimal trajectory + keypoint : str, optional + The keypoint to extract from each dataset, by default "centroid". + interpolate_gaps : bool, optional + Whether to interpolate missing values after alignment, by default True. + max_gap : int, optional + Maximum size of gap to interpolate, by default 5. + weights : list of float, optional + Static weights for each track source (used with "weighted" method). + confidence_arrays : list of xarray.DataArray, optional + Dynamic confidence values for each track (used with "weighted" method). + reliability_metrics : list of float, optional + Global reliability score for each source (used with "reliability" method). + window_size : int, optional + Window size for filtering source selection (used with "reliability" method). + process_noise_scale : float, optional + Scale factor for process noise (used with "kalman" method). + measurement_noise_scales : list of float, optional + Scale factors for measurement noise (used with "kalman" method). + print_report : bool, optional + Whether to print a report on the number of NaNs in the result, by default False. + + Returns + ------- + xarray.DataArray + Fused track with position values determined by the specified fusion method. + + Raises + ------ + ValueError + If an unsupported fusion method is specified or parameters are invalid. + + Notes + ----- + This function acts as a high-level interface to various track fusion methods, + automatically handling dataset alignment and applying the selected fusion algorithm. + + """ + # Convert string method to enum if needed + if isinstance(method, str): + method_map = { + "mean": FusionMethod.MEAN, + "median": FusionMethod.MEDIAN, + "weighted": FusionMethod.WEIGHTED, + "reliability": FusionMethod.RELIABILITY_BASED, + "kalman": FusionMethod.KALMAN, + } + + if method.lower() not in method_map: + valid_methods = list(method_map.keys()) + raise log_error( + ValueError, + f"Unsupported fusion method: {method}. 
" + f"Valid methods are: {valid_methods}", + ) + + method = method_map[method.lower()] + + # Align datasets + aligned_tracks = align_datasets( + datasets=datasets, + keypoint=keypoint, + interpolate=interpolate_gaps, + max_gap=max_gap, + ) + + # Apply fusion method + if method == FusionMethod.MEAN: + return fuse_tracks_mean( + aligned_tracks=aligned_tracks, + print_report=print_report, + ) + + elif method == FusionMethod.MEDIAN: + return fuse_tracks_median( + aligned_tracks=aligned_tracks, + print_report=print_report, + ) + + elif method == FusionMethod.WEIGHTED: + return fuse_tracks_weighted( + aligned_tracks=aligned_tracks, + weights=weights, + confidence_arrays=confidence_arrays, + print_report=print_report, + ) + + elif method == FusionMethod.RELIABILITY_BASED: + return fuse_tracks_reliability( + aligned_tracks=aligned_tracks, + reliability_metrics=reliability_metrics, + window_size=window_size, + print_report=print_report, + ) + + elif method == FusionMethod.KALMAN: + return fuse_tracks_kalman( + aligned_tracks=aligned_tracks, + process_noise_scale=process_noise_scale, + measurement_noise_scales=measurement_noise_scales, + print_report=print_report, + ) + + else: + raise log_error(ValueError, f"Unsupported fusion method: {method}") diff --git a/tests/test_unit/test_track_fusion.py b/tests/test_unit/test_track_fusion.py new file mode 100644 index 000000000..db383fb6b --- /dev/null +++ b/tests/test_unit/test_track_fusion.py @@ -0,0 +1,237 @@ +"""Tests for track fusion functions.""" + +import numpy as np +import pytest +import xarray as xr + +from movement.track_fusion import ( + align_datasets, + fuse_tracks, + fuse_tracks_kalman, + fuse_tracks_mean, + fuse_tracks_median, + fuse_tracks_reliability, + fuse_tracks_weighted, +) + + +@pytest.fixture +def mock_datasets(): + """Create mock datasets for testing track fusion.""" + # Create two simple datasets with different time points and some NaNs + # Dataset 1: More reliable (fewer NaNs) + time1 = np.arange(0, 10, 1) + pos1 = np.zeros((10, 1, 2)) + # Simple straight line with a slope + pos1[:, 0, 0] = np.arange(0, 10, 1) # x coordinate + pos1[:, 0, 1] = np.arange(0, 10, 1) # y coordinate + # Add some NaNs + pos1[3, 0, :] = np.nan + + # Dataset 2: Less reliable (more NaNs) + time2 = np.arange(0, 10, 1) + pos2 = np.zeros((10, 1, 2)) + # Similar trajectory but with some noise + pos2[:, 0, 0] = np.arange(0, 10, 1) + np.random.normal(0, 0.5, 10) + pos2[:, 0, 1] = np.arange(0, 10, 1) + np.random.normal(0, 0.5, 10) + # Add more NaNs + pos2[3, 0, :] = np.nan + pos2[7, 0, :] = np.nan + + # Create xarray datasets + ds1 = xr.Dataset( + data_vars={ + "position": (["time", "keypoints", "space"], pos1), + "confidence": (["time", "keypoints"], np.ones((10, 1))), + }, + coords={ + "time": time1, + "keypoints": ["centroid"], + "space": ["x", "y"], + "individuals": ["individual_0"], + }, + ) + + ds2 = xr.Dataset( + data_vars={ + "position": (["time", "keypoints", "space"], pos2), + "confidence": (["time", "keypoints"], np.ones((10, 1))), + }, + coords={ + "time": time2, + "keypoints": ["centroid"], + "space": ["x", "y"], + "individuals": ["individual_0"], + }, + ) + + return [ds1, ds2] + + +def test_align_datasets(mock_datasets): + """Test aligning datasets with different time points.""" + aligned = align_datasets(mock_datasets, interpolate=False) + + # Check that both arrays have the same time coordinates + assert aligned[0].time.equals(aligned[1].time) + + # Check that NaNs are preserved when interpolate=False + assert np.isnan(aligned[0].sel(time=3, 
space="x").values) + assert np.isnan(aligned[1].sel(time=3, space="x").values) + assert np.isnan(aligned[1].sel(time=7, space="x").values) + + # Test with interpolation + aligned_interp = align_datasets(mock_datasets, interpolate=True) + + # Check that NaNs are interpolated + assert not np.isnan(aligned_interp[0].sel(time=3, space="x").values) + assert not np.isnan(aligned_interp[1].sel(time=3, space="x").values) + assert not np.isnan(aligned_interp[1].sel(time=7, space="x").values) + + +def test_fuse_tracks_mean(mock_datasets): + """Test mean fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + fused = fuse_tracks_mean(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Check that the fused track has all time points + assert len(fused.time) == 10 + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + +def test_fuse_tracks_median(mock_datasets): + """Test median fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + fused = fuse_tracks_median(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + +def test_fuse_tracks_weighted(mock_datasets): + """Test weighted fusion method.""" + aligned = align_datasets(mock_datasets, interpolate=True) + + # Test with static weights + weights = [0.7, 0.3] + fused = fuse_tracks_weighted(aligned, weights=weights) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # No NaNs when both sources are interpolated + assert not np.isnan(fused).any() + + # Test with invalid weights (sum != 1) + with pytest.raises(ValueError): + fuse_tracks_weighted(aligned, weights=[0.5, 0.2]) + + # Test with mismatched weights length + with pytest.raises(ValueError): + fuse_tracks_weighted(aligned, weights=[0.5, 0.3, 0.2]) + + +def test_fuse_tracks_reliability(mock_datasets): + """Test reliability-based fusion method.""" + aligned = align_datasets( + mock_datasets, interpolate=False + ) # Keep NaNs for testing + + # Test with automatic reliability metrics + fused = fuse_tracks_reliability(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Test with custom reliability metrics + reliability_metrics = [0.9, 0.5] # First source more reliable + fused = fuse_tracks_reliability( + aligned, reliability_metrics=reliability_metrics + ) + + # Check that we still get a value for time point 7 where only source 1 has data + assert not np.isnan(fused.sel(time=7, space="x").values) + + # Test with invalid window size (even number) + with pytest.raises(ValueError): + fuse_tracks_reliability(aligned, window_size=10) + + +def test_fuse_tracks_kalman(mock_datasets): + """Test Kalman filter fusion method.""" + aligned = align_datasets( + mock_datasets, interpolate=False + ) # Keep NaNs for testing + + # Test with default parameters + fused = fuse_tracks_kalman(aligned) + + # Check output dimensions + assert "source" not in fused.dims + assert "time" in fused.dims + assert "space" in fused.dims + + # Kalman filter should interpolate over missing values + assert not np.isnan(fused).any() + + # Test with custom parameters + fused = fuse_tracks_kalman( + aligned, process_noise_scale=0.1, 
measurement_noise_scales=[0.1, 0.5] + ) + + # Check that we get a smoother trajectory (less variance) + x_vals = fused.sel(space="x").values + diff = np.diff(x_vals) + assert ( + np.std(diff) < 0.5 + ) # Standard deviation of the differences should be low + + # Test with mismatched noise scales length + with pytest.raises(ValueError): + fuse_tracks_kalman(aligned, measurement_noise_scales=[0.1, 0.2, 0.3]) + + +def test_fuse_tracks_high_level(mock_datasets): + """Test the high-level fuse_tracks interface.""" + # Test each method through the high-level interface + methods = ["mean", "median", "weighted", "reliability", "kalman"] + + for method in methods: + fused = fuse_tracks( + datasets=mock_datasets, + method=method, + keypoint="centroid", + interpolate_gaps=True, + ) + + # Check output dimensions + assert "time" in fused.dims + assert "space" in fused.dims + assert len(fused.space) == 2 + + # No NaNs when interpolation is used + assert not np.isnan(fused).any() + + # Test with invalid method + with pytest.raises(ValueError): + fuse_tracks(mock_datasets, method="invalid_method") + + # Test with non-existent keypoint + with pytest.raises(ValueError): + fuse_tracks(mock_datasets, keypoint="non_existent")
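+
+
+def test_align_datasets_empty_list():
+    """Empty input should raise a ValueError.
+
+    A small illustrative addition to the suite above: ``align_datasets``
+    guards against an empty dataset list.
+    """
+    with pytest.raises(ValueError):
+        align_datasets([])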