🐛 Bug
The retrieval metrics (RetrievalMAP, RetrievalRecall, etc.) crash or allocate excessive memory when the indexes tensor contains sparse or high-valued integers, even if the number of unique queries is small.
This is because torchmetrics.utilities.data._bincount() relies on index.max() to determine the size of internal tensors. When deterministic mode is enabled (or on XLA/MPS), the fallback implementation can allocate massive [len(indexes), index.max()] tensors, leading to out-of-memory (OOM) errors.
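For illustration, here is a simplified sketch of the kind of deterministic fallback described above (not the exact torchmetrics code); the intermediate comparison tensor scales with the largest index value rather than with the number of unique queries:

```python
import torch

def bincount_fallback_sketch(x: torch.Tensor) -> torch.Tensor:
    # Deterministic-friendly bincount: compare every element against every
    # candidate value in [0, x.max()]. The comparison tensor has shape
    # [len(x), x.max() + 1], so memory grows with the largest index value,
    # not with the number of unique queries.
    candidates = torch.arange(x.max() + 1, device=x.device)
    return (x.unsqueeze(-1) == candidates).sum(dim=0)

indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])
# The comparison tensor here would be [6, 90000001] -- huge, despite only 3 queries.
# bincount_fallback_sketch(indexes)  # don't actually run this on a small machine
```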
✅ Expected behavior
The metrics should group predictions by query regardless of the numerical values in indexes. The actual values shouldn't impact performance or memory.
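A minimal sketch of what value-independent grouping could look like (hypothetical, not the current torchmetrics implementation): remap arbitrary query IDs to a dense 0..N-1 range before counting, so memory scales with the number of unique queries:

```python
import torch

def group_sizes(indexes: torch.Tensor) -> torch.Tensor:
    # Hypothetical value-independent grouping: torch.unique(..., return_inverse=True)
    # maps arbitrary IDs to dense 0..num_queries-1 labels, preserving only equality.
    _, dense = torch.unique(indexes, return_inverse=True)
    return torch.bincount(dense, minlength=int(dense.max()) + 1)

print(group_sizes(torch.tensor([1000, 50000, 50000, 90000000])))  # tensor([1, 2, 1])
```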
To Reproduce
Steps:
- Simulate a retrieval task with a few queries using high-value IDs
- Use RetrievalMAP() or a similar retrieval metric
- Call .update() and .compute()
Code sample
```python
import torch
from torchmetrics.retrieval import RetrievalMAP
# Simulate predictions and labels for 3 queries with sparse/high index values
preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000]) # only 3 unique queries
# Enable deterministic mode (triggers fallback path)
torch.use_deterministic_algorithms(True)
metric = RetrievalMAP()
metric.update(preds, target, indexes)
# This line will likely cause a crash or massive memory use due to high index values
result = metric.compute()
print(result)
```

🔥 What happens
With torch.use_deterministic_algorithms(True) enabled:
- The _bincount() fallback tries to allocate a tensor of shape [len(indexes), index.max()] = [6, 90000001]
- This results in >67 GB of memory allocation and often crashes
Without deterministic mode, torch.bincount() is used directly, which also scales poorly if index.max() is large.
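For reference, torch.bincount's output length is index.max() + 1, so even the non-deterministic path allocates proportionally to the largest ID rather than to the number of unique queries:

```python
import torch

idx = torch.tensor([1000, 50000, 90000000])
counts = torch.bincount(idx)
print(counts.shape)  # torch.Size([90000001]) -> ~720 MB of int64 for just 3 elements
print(counts.sum())  # tensor(3)
```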
Environment
- TorchMetrics version: 1.8.1 (reproduced on latest)
- Python version: 3.12.10
- PyTorch version: 2.8.0
- OS: Ubuntu 22.04 / Windows 11
- Device: CPU and CUDA (but also applies to MPS/XLA)
Additional context
This is especially common in real-world retrieval problems where indexes come from:
- Row numbers or IDs in large datasets
- Sparse query IDs (e.g., from database keys)
Since the metric only uses indexes to group elements, their actual values are irrelevant — only equality matters.
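As a temporary user-side workaround (and a possible direction for a fix inside the metrics), the sparse IDs can be remapped to a dense range before calling .update(); the result should be identical because only equality between indexes matters:

```python
import torch
from torchmetrics.retrieval import RetrievalMAP

preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])

# Remap sparse query IDs to dense 0..N-1 labels; grouping is unchanged.
_, dense_indexes = torch.unique(indexes, return_inverse=True)

metric = RetrievalMAP()
metric.update(preds, target, dense_indexes)
print(metric.compute())  # same value as with the original sparse indexes
```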