diff --git a/docs/source/index.md b/docs/source/index.md index a90250535..58f4314be 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -61,6 +61,7 @@ Find out more on our [mission and scope](target-mission) statement and our [road :hidden: user_guide/index +user_guide/datasets examples/index community/index api_index diff --git a/docs/source/user_guide/datasets.md b/docs/source/user_guide/datasets.md new file mode 100644 index 000000000..21c6a2895 --- /dev/null +++ b/docs/source/user_guide/datasets.md @@ -0,0 +1,111 @@ +# Working with Public Datasets + +In addition to sample data for testing and examples, `movement` provides access to publicly available datasets of animal poses and trajectories. These datasets can be useful for research, method development, benchmarking, and learning. + +## Available Datasets + +You can list the available public datasets using: + +```python +from movement import list_public_datasets + +datasets = list_public_datasets() +print(datasets) +``` + +To get more information about a specific dataset: + +```python +from movement import get_dataset_info + +info = get_dataset_info("calms21") +print(info["description"]) +print(info["url"]) +print(info["paper"]) +print(info["license"]) +``` + +## CalMS21 Dataset + +The [CalMS21 dataset](https://data.caltech.edu/records/g6fjs-ceqwp) contains multi-animal pose tracking data for various animal types and behavioral tasks. 
+ +```python +from movement import public_data + +# Fetch mouse data from the open field task +mouse_data = public_data.fetch_calms21( + subset="train", + animal_type="mouse", + task="open_field", +) + +# Fetch fly data from the courtship task +fly_data = public_data.fetch_calms21( + subset="train", + animal_type="fly", + task="courtship", +) +``` + +The available parameters are: + +- `subset`: "train", "val", or "test" +- `animal_type`: "mouse", "fly", or "ciona" +- `task`: Depends on the animal type + - For mouse: "open_field", "social_interaction", "resident_intruder" + - For fly: "courtship", "egg_laying", "aggression" + - For ciona: "social_investigation" +- `frame_rate`: Optional, to override the original frame rate + +## Rat7M Dataset + +The [Rat7M dataset](https://data.caltech.edu/records/bpkf7-jae29) contains tracking data for multiple rats in complex environments. + +```python +from movement import public_data + +# Fetch data from the open field task +rat_data = public_data.fetch_rat7m(subset="open_field") +``` + +The available parameters are: + +- `subset`: "open_field", "shelter", or "maze" +- `frame_rate`: Optional, to override the original frame rate + +## Data Caching + +Downloaded datasets are cached locally in the `~/.movement/public_data` directory. This means that after the first download, subsequent requests for the same dataset will be faster as they'll use the local copy. 
+ +## Working with the Data + +Once loaded, these datasets are returned as standard `movement` xarray Datasets, allowing you to apply all the analysis and visualization tools available in the package: + +```python +from movement import public_data +import matplotlib.pyplot as plt + +# Fetch data +ds = public_data.fetch_calms21(animal_type="mouse", task="open_field") + +# Access position data +position = ds.position + +# Compute kinematics +from movement import kinematics +velocity = kinematics.compute_velocity(position) +speed = kinematics.compute_speed(position) + +# Visualize +from movement.plots import plot_centroid_trajectory +fig, ax = plot_centroid_trajectory(position) +plt.show() +``` + +## Citation + +When using these public datasets in your research, please cite the original papers: + +- CalMS21: Sun, J. J., et al. (2021). "The Multi-Agent Behavior Dataset: Mouse Dyadic Social Interactions". NeurIPS Datasets and Benchmarks Track. https://arxiv.org/abs/2104.02710 + +- Rat7M: Dunn et al. (2021). "Geometric deep learning enables 3D kinematic profiling across species and environments". Nature Methods, 18(5), 564-573. https://doi.org/10.1038/s41592-021-01106-6 diff --git a/examples/public_datasets.py b/examples/public_datasets.py new file mode 100644 index 000000000..0db3164d4 --- /dev/null +++ b/examples/public_datasets.py @@ -0,0 +1,98 @@ +"""Working with public datasets +============================ + +This example demonstrates how to access and work with publicly available +datasets of animal poses and trajectories. 
+""" + +# %% +# Imports +# ------- + +from movement import public_data + +# %% +# Listing available datasets +# -------------------------- +# First, let's see what public datasets are available: + +datasets = public_data.list_public_datasets() +print("Available public datasets:") +for dataset in datasets: + info = public_data.get_dataset_info(dataset) + print(f"\n{dataset}:") + print(f" Description: {info['description']}") + print(f" URL: {info['url']}") + print(f" Paper: {info['paper']}") + print(f" License: {info['license']}") + +# %% +# CalMS21 Dataset +# --------------- +# The CalMS21 dataset contains multi-animal pose tracking data for various +# animal types and behavioral tasks. + +# %% +# Let's fetch a subset of the CalMS21 dataset with mice in an open field: + +mouse_data = public_data.fetch_calms21( + subset="train", + animal_type="mouse", + task="open_field", +) + +# NOTE: This is currently a placeholder implementation. +# In the full implementation, this would download and load actual data. + +print("\nDataset attributes:") +for key, value in mouse_data.attrs.items(): + print(f" {key}: {value}") + +# %% +# We can also fetch data for different animal types and tasks: + +fly_data = public_data.fetch_calms21( + subset="train", + animal_type="fly", + task="courtship", +) + +# NOTE: This is currently a placeholder implementation. +# In the full implementation, this would download and load actual data. + +print("\nDataset attributes:") +for key, value in fly_data.attrs.items(): + print(f" {key}: {value}") + +# %% +# Rat7M Dataset +# ------------- +# The Rat7M dataset contains tracking data for multiple rats in complex +# environments. + +# %% +# Let's fetch a subset of the Rat7M dataset: + +rat_data = public_data.fetch_rat7m(subset="open_field") + +# NOTE: This is currently a placeholder implementation. +# In the full implementation, this would download and load actual data. 
+ +print("\nDataset attributes:") +for key, value in rat_data.attrs.items(): + print(f" {key}: {value}") + +# %% +# Using the data +# -------------- +# Once the data is loaded, you can use all the movement functionality +# for analysis and visualization. +# +# NOTE: Since we're currently using placeholder data, we can't demonstrate +# actual analysis here. When the full implementation is complete, this +# example will include code for: +# +# - Visualizing trajectories +# - Computing kinematic measures +# - Analyzing behavioral patterns +# - Comparing across datasets diff --git a/movement/__init__.py b/movement/__init__.py index ad8ff1f3d..ef1dc5826 100644 --- a/movement/__init__.py +++ b/movement/__init__.py @@ -13,5 +13,12 @@ xr.set_options(keep_attrs=True, display_expand_data=False) + +# initialize logger upon import +# configure_logging() # This call is incorrect and removed + +# Import public datasets module functions to make them available at package level +from movement.public_data import list_public_datasets, get_dataset_info + # Configure logging to stderr and a file logger.configure() diff --git a/movement/public_data.py b/movement/public_data.py new file mode 100644 index 000000000..b00a32bbf --- /dev/null +++ b/movement/public_data.py @@ -0,0 +1,239 @@ +"""Fetch and load publicly available datasets. + +This module provides functions for fetching and loading publicly available +datasets of animal poses and trajectories. The data are downloaded from their +original sources and are cached locally the first time they are used. 
+""" + +import logging +from pathlib import Path + +import xarray as xr + +logger = logging.getLogger(__name__) + +# Save data in ~/.movement/public_data +PUBLIC_DATA_DIR = Path("~", ".movement", "public_data").expanduser() +# Create the folder if it doesn't exist +PUBLIC_DATA_DIR.mkdir(parents=True, exist_ok=True) + +# Dictionary of available datasets and their metadata +PUBLIC_DATASETS = { + "calms21": { + "description": "Caltech Mouse Social Interactions (CalMS21) Dataset: " + "trajectory data of social interactions from videos of freely " + "behaving mice in a standard resident-intruder assay.", + "url": "https://data.caltech.edu/records/s0vdx-0k302", + "paper": "https://arxiv.org/abs/2104.02710", # CalMS21 paper + "license": "CC-BY-4.0", + }, + "rat7m": { + "description": "Rat7M: a 7M frame ground-truth dataset of rodent 3D " + "landmarks and synchronised colour video.", + "url": "https://figshare.com/collections/Rat_7M/5295370/3", + "paper": "https://doi.org/10.1038/s41592-021-01106-6", # DANNCE paper + "license": "MIT", # Assuming MIT based on DANNCE; verification needed + }, +} + +# File registry for each dataset +# This will be populated as we implement each dataset loader +FILE_REGISTRY: dict[str, dict[str, str]] = {} + + +def list_public_datasets() -> list[str]: + """List available public datasets. + + Returns + ------- + dataset_names : list of str + List of names for available public datasets. + + """ + return list(PUBLIC_DATASETS.keys()) + + +def get_dataset_info(dataset_name: str) -> dict: + """Get information about a public dataset. + + Parameters + ---------- + dataset_name : str + Name of the public dataset. + + Returns + ------- + info : dict + Dictionary containing dataset information. + + """ + if dataset_name not in PUBLIC_DATASETS: + available_datasets = ", ".join(list_public_datasets()) + message = ( + f"Unknown dataset: {dataset_name}. 
" + f"Available datasets are: {available_datasets}" + ) + logger.error(message) + raise ValueError(message) + + return PUBLIC_DATASETS[dataset_name] + + +def fetch_calms21( + subset: str = "train", + animal_type: str = "mouse", + task: str = "open_field", + frame_rate: float | None = None, +) -> xr.Dataset: + """Fetch a subset of the CalMS21 dataset. + + The CalMS21 dataset consists of trajectory data of social interactions, + recorded from videos of freely behaving mice in a standard + resident-intruder assay. [1]_ + + Parameters + ---------- + subset : str, optional + Data subset to fetch. One of 'train', 'val', or 'test'. + Default is 'train'. + animal_type : str, optional + Type of animal (currently only 'mouse' is relevant for data fetching). + Default is 'mouse'. + task : str, optional + Behavioral task. Only 'social_interaction' and + 'resident_intruder' are currently accepted. + Default is 'open_field', which is a placeholder that fails + validation, so pass one of the accepted tasks explicitly. + frame_rate : float, optional + Frame rate in frames per second. If None, the original frame rate + will be used. Default is None. + + Returns + ------- + ds : xarray.Dataset + Dataset containing the requested CalMS21 data. + + References + ---------- + .. [1] Sun, J. J., Karigo, T., Chakraborty, D., Mohanty, S. P., Wild, B., + Sun, Q., ... & Kennedy, A. (2021). The Multi-Agent Behavior Dataset: + Mouse Dyadic Social Interactions. NeurIPS Datasets and Benchmarks. + https://arxiv.org/abs/2104.02710 + + """ + # Validate inputs + valid_subsets = ["train", "val", "test"] + if subset not in valid_subsets: + message = f"Invalid subset: {subset}. Must be one of {valid_subsets}" + logger.error(message) + raise ValueError(message) + + valid_animal_types = ["mouse"] + if animal_type not in valid_animal_types: + message = ( + f"Invalid animal type: {animal_type}. 
" + f"Must be one of {valid_animal_types}" + ) + logger.error(message) + raise ValueError(message) + + valid_tasks = ["social_interaction", "resident_intruder"] + if task not in valid_tasks: + message = ( + f"Invalid task for {animal_type}: {task}. " + f"Must be one of {valid_tasks}" + ) + logger.error(message) + raise ValueError(message) + + # Construction of URL and file paths will go here + # For now, this is a placeholder implementation + logger.info(f"Fetching CalMS21 data: {animal_type}/{task}/{subset}") + + # Placeholder for actual implementation + # This would use pooch to download the specific file + # And then load it into an xarray Dataset + + # For demonstration, create a minimal dataset + ds = xr.Dataset() + ds.attrs["dataset"] = "calms21" + ds.attrs["subset"] = subset + ds.attrs["animal_type"] = animal_type + ds.attrs["task"] = task + + # In actual implementation, we would: + # 1. Download the data file using pooch + # 2. Load the file into appropriate format + # 3. Convert to movement's xarray format + # 4. Return the dataset + + logger.warning( + "This is currently a placeholder implementation. " + "The actual data downloading is not yet implemented." + ) + + return ds + + +def fetch_rat7m( + subset: str = "open_field", + frame_rate: float | None = None, +) -> xr.Dataset: + """Fetch a subset of the Rat7M dataset. + + The Rat7M dataset contains tracking data for multiple rats in complex + environments. + + Parameters + ---------- + subset : str, optional + Data subset to fetch. One of 'open_field', 'shelter', or 'maze'. + Default is 'open_field'. + frame_rate : float, optional + Frame rate in frames per second. If None, the original frame rate + will be used. Default is None. + + Returns + ------- + ds : xarray.Dataset + Dataset containing the requested Rat7M data. + + References + ---------- + .. [1] Dunn et al. (2021). "Geometric deep learning enables 3D kinematic + profiling across species and environments". Nature Methods, 18(5), + 564-573. 
https://doi.org/10.1038/s41592-021-01106-6 + + """ + # Validate inputs + valid_subsets = ["open_field", "shelter", "maze"] + if subset not in valid_subsets: + message = f"Invalid subset: {subset}. Must be one of {valid_subsets}" + logger.error(message) + raise ValueError(message) + + # Construction of URL and file paths will go here + # For now, this is a placeholder implementation + logger.info(f"Fetching Rat7M data: {subset}") + + # Placeholder for actual implementation + # This would use pooch to download the specific file + # And then load it into an xarray Dataset + + # For demonstration, create a minimal dataset + ds = xr.Dataset() + ds.attrs["dataset"] = "rat7m" + ds.attrs["subset"] = subset + + # In actual implementation, we would: + # 1. Download the data file using pooch + # 2. Load the file into appropriate format + # 3. Convert to movement's xarray format + # 4. Return the dataset + + logger.warning( + "This is currently a placeholder implementation. " + "The actual data downloading is not yet implemented." 
+ ) + + return ds diff --git a/tests/test_unit/test_public_data.py b/tests/test_unit/test_public_data.py new file mode 100644 index 000000000..ebe69ce86 --- /dev/null +++ b/tests/test_unit/test_public_data.py @@ -0,0 +1,89 @@ +"""Test suite for the public_data module.""" + +from unittest.mock import patch + +import pytest + +from movement import public_data + + +def test_list_public_datasets(): + """Test listing available public datasets.""" + datasets = public_data.list_public_datasets() + assert isinstance(datasets, list) + assert len(datasets) > 0 + assert "calms21" in datasets + assert "rat7m" in datasets + + +def test_get_dataset_info(): + """Test getting information about a public dataset.""" + # Test valid dataset + info = public_data.get_dataset_info("calms21") + assert isinstance(info, dict) + assert "description" in info + assert "url" in info + assert "paper" in info + assert "license" in info + + # Test invalid dataset + with pytest.raises(ValueError, match="Unknown dataset"): + public_data.get_dataset_info("nonexistent_dataset") + + +@pytest.mark.parametrize( + "subset, animal_type, task", + [ + ("train", "mouse", "open_field"), + ("val", "fly", "courtship"), + ("test", "ciona", "social_investigation"), + ], +) +def test_fetch_calms21_valid_inputs(subset, animal_type, task): + """Test fetching CalMS21 dataset with valid inputs.""" + with patch("movement.public_data.logger.warning"): # Suppress warning + ds = public_data.fetch_calms21( + subset=subset, animal_type=animal_type, task=task + ) + assert "dataset" in ds.attrs + assert ds.attrs["dataset"] == "calms21" + assert ds.attrs["subset"] == subset + assert ds.attrs["animal_type"] == animal_type + assert ds.attrs["task"] == task + + +@pytest.mark.parametrize( + "subset, animal_type, task, error_match", + [ + ("invalid", "mouse", "open_field", "Invalid subset"), + ("train", "invalid", "open_field", "Invalid animal type"), + ("train", "mouse", "invalid", "Invalid task for mouse"), + # Cross-species task 
mismatch + ("train", "fly", "open_field", "Invalid task for fly"), + ], +) +def test_fetch_calms21_invalid_inputs(subset, animal_type, task, error_match): + """Test fetching CalMS21 dataset with invalid inputs.""" + with pytest.raises(ValueError, match=error_match): + public_data.fetch_calms21( + subset=subset, animal_type=animal_type, task=task + ) + + +@pytest.mark.parametrize( + "subset", + ["open_field", "shelter", "maze"], +) +def test_fetch_rat7m_valid_inputs(subset): + """Test fetching Rat7M dataset with valid inputs.""" + with patch("movement.public_data.logger.warning"): # Suppress warning + ds = public_data.fetch_rat7m(subset=subset) + assert "dataset" in ds.attrs + assert ds.attrs["dataset"] == "rat7m" + assert ds.attrs["subset"] == subset + + +def test_fetch_rat7m_invalid_inputs(): + """Test fetching Rat7M dataset with invalid inputs.""" + with pytest.raises(ValueError, match="Invalid subset"): + public_data.fetch_rat7m(subset="invalid")