273 changes: 214 additions & 59 deletions movement/io/load_poses.py
@@ -1,7 +1,7 @@
"""Load pose tracking data from various frameworks into ``movement``."""

from pathlib import Path
from typing import Literal
from typing import Any, Literal, cast

import h5py
import numpy as np
@@ -42,7 +42,7 @@ def from_numpy(
Array of shape (n_frames, n_keypoints, n_individuals) containing
the point-wise confidence scores. It will be converted to a
:class:`xarray.DataArray` object named "confidence".
If None (default), the scores will be set to an array of NaNs.
If None (default), no confidence data variable is included.
individual_names : list of str, optional
List of unique names for the individuals in the video. If None
(default), the individuals will be named "individual_0",
@@ -61,8 +61,8 @@ def from_numpy(
Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata.
``movement`` dataset containing the pose tracks, confidence scores
(if provided), and associated metadata.

Examples
--------
@@ -101,6 +101,7 @@ def from_file(
"SLEAP",
"LightningPose",
"Anipose",
"animovement",
"NWB",
],
fps: float | None = None,
@@ -113,8 +114,9 @@ def from_file(
file_path : pathlib.Path or str
Path to the file containing predicted poses. The file format must
be among those supported by the ``from_dlc_file()``,
``from_slp_file()`` or ``from_lp_file()`` functions. One of these
these functions will be called internally, based on
``from_slp_file()``, ``from_lp_file()``, ``from_anipose_file()``,
or ``from_animovement_file()`` functions. One of these
functions will be called internally, based on
the value of ``source_software``.
source_software : {"DeepLabCut", "SLEAP", "LightningPose", "Anipose", \
"animovement", "NWB"}
@@ -167,6 +169,8 @@ def from_file(
"metadata in the file."
)
return from_nwb_file(file_path, **kwargs)
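# animovement stores pose tracks as a tidy-format Parquet file,
# so delegate to the dedicated reader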
elif source_software == "animovement":
return from_animovement_file(file_path, fps)
else:
raise logger.error(
ValueError(f"Unsupported source software: {source_software}")
@@ -309,6 +313,7 @@ def from_sleap_file(
# Add metadata as attrs
ds.attrs["source_file"] = file.path.as_posix()
logger.info(f"Loaded pose tracks from {file.path}:\n{ds}")
logger.info(ds)
return ds


@@ -526,18 +531,30 @@ def _ds_from_sleap_labels_file(
file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
labels = read_labels(file.path.as_posix())
tracks_with_scores = _sleap_labels_to_numpy(labels)
individual_names = [track.name for track in labels.tracks] or None
if individual_names is None:

individual_names: list[str] = (
[track.name for track in labels.tracks]
if labels.tracks
else ["individual_0"]
)
if not labels.tracks:
logger.warning(
f"Could not find SLEAP Track in {file.path}. "
"Assuming single-individual dataset and assigning "
"default individual name."
)

keypoint_names: list[str] = [kp.name for kp in labels.skeletons[0].nodes]

# Explicit type assertions for mypy
individual_names = cast(list[str], individual_names)
keypoint_names = cast(list[str], keypoint_names)

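# tracks_with_scores has shape (n_frames, 3, n_nodes, n_tracks); along dim 1
# the first two rows are the x and y coordinates and the last row holds the
# confidence scores, hence the slicing below.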
return from_numpy(
position_array=tracks_with_scores[:, :-1, :, :],
confidence_array=tracks_with_scores[:, -1, :, :],
individual_names=individual_names,
keypoint_names=[kp.name for kp in labels.skeletons[0].nodes],
keypoint_names=keypoint_names,
fps=fps,
source_software="SLEAP",
)
@@ -579,15 +596,19 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]]
# Figure out frame index range
frame_idxs = [lf.frame_idx for lf in lfs]
first_frame = min(0, min(frame_idxs))
last_frame = max(0, max(frame_idxs))
first_frame = min(0, min(frame_idxs)) if frame_idxs else 0
last_frame = max(0, max(frame_idxs)) if frame_idxs else 0

n_tracks = len(labels.tracks) or 1 # If no tracks, assume 1 individual
individuals = labels.tracks or [None]
skeleton = labels.skeletons[-1] # Assume project only uses last skeleton
n_nodes = len(skeleton.nodes)
n_frames = int(last_frame - first_frame + 1)
tracks = np.full((n_frames, 3, n_nodes, n_tracks), np.nan, dtype="float32")

# Initialize tracks array with explicit type
tracks: np.ndarray = np.full(
(n_frames, 3, n_nodes, n_tracks), np.nan, dtype=np.float32
)

for lf in lfs:
i = int(lf.frame_idx - first_frame)
@@ -603,12 +624,18 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray:
# Use user-labelled instance if available
if user_track_instances:
inst = user_track_instances[-1]
tracks[i, ..., j] = np.hstack(
(inst.numpy(), np.full((n_nodes, 1), np.nan))
).T
points = inst.numpy()
for k in range(n_nodes):
tracks[i, 0, k, j] = points[k, 0] # x-coordinate
tracks[i, 1, k, j] = points[k, 1] # y-coordinate
tracks[i, 2, k, j] = np.nan # No scores for user instances
elif predicted_track_instances:
inst = predicted_track_instances[-1]
tracks[i, ..., j] = inst.numpy(scores=True).T
points = inst.numpy(scores=True)
for k in range(n_nodes):
tracks[i, 0, k, j] = points[k, 0] # x-coordinate
tracks[i, 1, k, j] = points[k, 1] # y-coordinate
tracks[i, 2, k, j] = points[k, 2] # confidence score
return tracks


@@ -690,8 +717,8 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata.
``movement`` dataset containing the pose tracks, confidence scores
(if provided), and associated metadata.

"""
n_frames = data.position_array.shape[0]
@@ -713,14 +740,18 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
dataset_attrs["time_unit"] = time_unit

DIM_NAMES = ValidPosesDataset.DIM_NAMES
# Initialize data_vars dictionary with position
data_vars = {
"position": xr.DataArray(data.position_array, dims=DIM_NAMES),
}
# Add confidence only if confidence_array is provided
if data.confidence_array is not None:
data_vars["confidence"] = xr.DataArray(
data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:]
)
# Convert data to an xarray.Dataset
return xr.Dataset(
data_vars={
"position": xr.DataArray(data.position_array, dims=DIM_NAMES),
"confidence": xr.DataArray(
data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:]
),
},
data_vars=data_vars,
coords={
DIM_NAMES[0]: time_coords,
DIM_NAMES[1]: ["x", "y", "z"][:n_space],
@@ -736,42 +767,13 @@ def from_anipose_style_df(
fps: float | None = None,
individual_name: str = "individual_0",
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from an Anipose 3D dataframe.

Parameters
----------
df : pd.DataFrame
Anipose triangulation dataframe
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame units.
individual_name : str, optional
Name of the individual, by default "individual_0"

Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata.


Notes
-----
Reshape dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z,
keypoint1_score,keypoint2_x, keypoint2_y, keypoint2_z,
keypoint2_score...to array of positions with dimensions
time, space, keypoints, individuals, and array of confidence (from scores)
with dimensions time, keypoints, individuals.

"""
keypoint_names = sorted(
"""Create a ``movement`` poses dataset from an Anipose 3D dataframe."""
keypoint_names: list[str] = sorted(
list(
set(
[
col.rsplit("_", 1)[0]
for col in df.columns
if any(col.endswith(f"_{s}") for s in ["x", "y", "z"])
]
col.rsplit("_", 1)[0]
for col in df.columns
if any(col.endswith(f"_{s}") for s in ["x", "y", "z"])
)
)
)
@@ -789,7 +791,7 @@ def from_anipose_style_df(
position_array[:, j, i, 0] = df[f"{kp}_{coord}"]
confidence_array[:, i, 0] = df[f"{kp}_score"]

individual_names = [individual_name]
individual_names: list[str] = [individual_name]

return from_numpy(
position_array=position_array,
@@ -978,3 +980,156 @@ def _ds_from_nwb_object(
)
)
return xr.merge(single_keypoint_datasets)


def from_tidy_df(
df: pd.DataFrame,
fps: float | None = None,
source_software: str = "animovement",
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from a tidy DataFrame.

Parameters
----------
df : pandas.DataFrame
Tidy DataFrame containing pose tracks and confidence scores.
Expected columns: 'frame', 'track_id', 'keypoint', 'x', 'y',
and optionally 'confidence'.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.
source_software : str, optional
Name of the pose estimation software or package from which the
data originate. Defaults to "animovement".

Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores
(if provided), and associated metadata.

Notes
-----
The DataFrame must have at least the following columns:
- 'frame': integer, the frame number (time index)
- 'track_id': string or integer, the individual ID
- 'keypoint': string, the keypoint name
- 'x': float, x-coordinate
- 'y': float, y-coordinate
- 'confidence': float, optional, point-wise confidence scores

Examples
--------
>>> import pandas as pd
>>> from movement.io import load_poses
>>> df = pd.DataFrame(
... {
... "frame": [0, 0, 1, 1],
... "track_id": ["ind1", "ind1", "ind1", "ind1"],
... "keypoint": ["nose", "tail", "nose", "tail"],
... "x": [100.0, 150.0, 101.0, 151.0],
... "y": [200.0, 250.0, 201.0, 251.0],
... "confidence": [0.9, 0.8, 0.85, 0.75],
... }
... )
>>> ds = load_poses.from_tidy_df(df, fps=30)

"""
# Validate DataFrame columns
required_columns = {"frame", "track_id", "keypoint", "x", "y"}
if not required_columns.issubset(df.columns):
missing = required_columns - set(df.columns)
raise ValueError(f"DataFrame missing required columns: {missing}")

# Ensure correct data types
df = df.astype(
{
"frame": int,
"track_id": str,
"keypoint": str,
"x": float,
"y": float,
}
)

# Get unique values for coordinates
time: np.ndarray[Any, np.dtype[np.int_]] = np.sort(df["frame"].unique())
individuals: np.ndarray[Any, np.dtype[np.str_]] = df["track_id"].unique()
keypoints: np.ndarray[Any, np.dtype[np.str_]] = df["keypoint"].unique()
n_frames = len(time)
n_individuals = len(individuals)
n_keypoints = len(keypoints)

# Initialize position and confidence arrays in the layout expected by
# from_numpy: (time, space, keypoints, individuals) and
# (time, keypoints, individuals), respectively
position_array = np.full(
(n_frames, 2, n_keypoints, n_individuals), np.nan, dtype=float
)
confidence_array = (
np.full((n_frames, n_keypoints, n_individuals), np.nan, dtype=float)
if "confidence" in df.columns
else None
)

# Pivot data to fill arrays
for _idx, row in df.iterrows():
t_idx = np.nonzero(time == row["frame"])[0][0]
i_idx = np.nonzero(individuals == row["track_id"])[0][0]
k_idx = np.nonzero(keypoints == row["keypoint"])[0][0]
position_array[t_idx, 0, k_idx, i_idx] = row["x"]
position_array[t_idx, 1, k_idx, i_idx] = row["y"]
if confidence_array is not None and "confidence" in row:
confidence_array[t_idx, k_idx, i_idx] = row["confidence"]

# Explicitly convert to lists to ensure mypy recognizes list[str]
individual_names: list[str] = list(individuals)
keypoint_names: list[str] = list(keypoints)

return from_numpy(
position_array=position_array,
confidence_array=confidence_array,
individual_names=individual_names,
keypoint_names=keypoint_names,
fps=fps,
source_software=source_software,
)


def from_animovement_file(
file_path: Path | str,
fps: float | None = None,
) -> xr.Dataset:
"""Create a ``movement`` poses dataset from an animovement Parquet file.

Parameters
----------
file_path : pathlib.Path or str
Path to the Parquet file containing pose tracks in tidy format.
fps : float, optional
The number of frames per second in the video. If None (default),
the ``time`` coordinates will be in frame numbers.

Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores
(if provided), and associated metadata.

Examples
--------
>>> from movement.io import load_poses
>>> ds = load_poses.from_animovement_file("path/to/file.parquet", fps=30)

"""
file = ValidFile(
file_path,
expected_permission="r",
expected_suffix=[".parquet"],
)
# Load Parquet file into DataFrame
df = pd.read_parquet(file.path)
# Convert to xarray Dataset
ds = from_tidy_df(df, fps=fps, source_software="animovement")
# Add metadata
ds.attrs["source_file"] = file.path.as_posix()
logger.info(f"Loaded pose tracks from {file.path}:\n{ds}")
return ds
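
A minimal end-to-end usage sketch of the new loader, assuming the tidy column layout documented in ``from_tidy_df`` (the ``poses.parquet`` path is hypothetical):

```python
import pandas as pd

from movement.io import load_poses

# Tidy pose data with the columns required by from_tidy_df
df = pd.DataFrame(
    {
        "frame": [0, 0, 1, 1],
        "track_id": ["ind1", "ind1", "ind1", "ind1"],
        "keypoint": ["nose", "tail", "nose", "tail"],
        "x": [100.0, 150.0, 101.0, 151.0],
        "y": [200.0, 250.0, 201.0, 251.0],
        "confidence": [0.9, 0.8, 0.85, 0.75],
    }
)
# Write to Parquet (requires a Parquet engine such as pyarrow);
# "poses.parquet" is a hypothetical path
df.to_parquet("poses.parquet")

# Load via the dedicated reader...
ds = load_poses.from_animovement_file("poses.parquet", fps=30)
# ...or through the generic dispatcher extended in this diff
ds = load_poses.from_file("poses.parquet", source_software="animovement", fps=30)
```

Both calls should return the same dataset, since ``from_file`` simply dispatches to ``from_animovement_file`` for this source software.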