Retrieval metrics crash or use excessive memory with high-valued query indexes #3290

@ramon-adalia-lmd

Description

🐛 Bug

The retrieval metrics (RetrievalMAP, RetrievalRecall, etc.) crash or allocate excessive memory when the indexes tensor contains sparse or high-valued integers, even if the number of unique queries is small.

This is because torchmetrics.utilities.data._bincount() relies on the maximum index value to size its internal tensors. When deterministic mode is enabled (or on XLA/MPS), the fallback implementation can allocate a massive [len(indexes), indexes.max() + 1] tensor, leading to out-of-memory (OOM) errors.
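
For illustration, here is a minimal sketch of the kind of one-hot comparison fallback that can produce this scaling. This is an assumption about the general pattern, not the exact torchmetrics implementation:

import torch

def bincount_fallback_sketch(x: torch.Tensor) -> torch.Tensor:
    # Compare every element against every candidate value 0..x.max(),
    # which materializes a [len(x), x.max() + 1] intermediate tensor
    # before reducing it to per-value counts.
    mesh = torch.arange(int(x.max()) + 1, device=x.device)
    return (x.unsqueeze(-1) == mesh).sum(dim=0)

# With the sparse ids from this report, the intermediate is [6, 90000001].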

✅ Expected behavior

The metrics should group predictions by query regardless of the numerical values in indexes. The magnitude of the ids should not affect runtime or memory.
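
One way to make the grouping value-agnostic inside the metric (a sketch only, assuming per-query counts are what is needed; this is not the current torchmetrics code) is to count group sizes with torch.unique, whose cost scales with the number of elements and unique queries rather than with the maximum id:

import torch

def group_counts(indexes: torch.Tensor) -> torch.Tensor:
    # torch.unique ignores the magnitude of the ids: memory scales with
    # len(indexes) and the number of unique queries, not indexes.max().
    _, counts = torch.unique(indexes, return_counts=True)
    return counts

print(group_counts(torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])))
# tensor([2, 2, 2])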


To Reproduce

Steps:

  1. Simulate a retrieval task with a few queries using high-value IDs
  2. Use RetrievalMAP() or similar
  3. Call .update() and .compute()
Code sample
import torch
from torchmetrics.retrieval import RetrievalMAP

# Simulate predictions and labels for 3 queries with sparse/high index values
preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])  # only 3 unique queries

# Enable deterministic mode (triggers fallback path)
torch.use_deterministic_algorithms(True)

metric = RetrievalMAP()
metric.update(preds, target, indexes)

# This line will likely cause a crash or massive memory use due to high index values
result = metric.compute()
print(result)

🔥 What happens

With torch.use_deterministic_algorithms(True) enabled:

  • The _bincount() fallback tries to allocate a tensor of shape [len(indexes), index.max() + 1] = [6, 90000001]
  • This results in >67 GB memory allocation and often crashes

Without deterministic mode, torch.bincount() is used directly; since its output covers every value from 0 to index.max(), it also scales poorly when the maximum index is large.
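
A quick way to see this scaling, independent of torchmetrics:

import torch

# torch.bincount returns counts for every value from 0 to max(input),
# so a single id of 90_000_000 forces a ~90-million-element output.
counts = torch.bincount(torch.tensor([90_000_000]))
print(counts.shape)  # torch.Size([90000001])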


Environment
  • TorchMetrics version: 1.8.1 (reproduced on latest)
  • Python version: 3.12.10
  • PyTorch version: 2.8.0
  • OS: Ubuntu 22.04 / Windows 11
  • Device: CPU and CUDA (but also applies to MPS/XLA)

Additional context

This is especially common in real-world retrieval problems where indexes come from:

  • Row numbers or IDs in large datasets
  • Sparse query IDs (e.g., from database keys)

Since the metric only uses indexes to group elements, their actual values are irrelevant — only equality matters.
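
Until this is fixed, a caller-side workaround is to remap the sparse ids to contiguous values before calling update; equality, and therefore the grouping, is preserved. A minimal sketch (dense_indexes is just an illustrative name):

import torch
from torchmetrics.retrieval import RetrievalMAP

preds = torch.tensor([0.2, 0.8, 0.4, 0.9, 0.1, 0.3])
target = torch.tensor([0, 1, 0, 1, 1, 0])
indexes = torch.tensor([1000, 1000, 50000, 50000, 90000000, 90000000])

# Remap sparse query ids to contiguous 0..N-1 values; the grouping is
# unchanged, but indexes.max() is now tiny.
_, dense_indexes = torch.unique(indexes, return_inverse=True)

metric = RetrievalMAP()
metric.update(preds, target, dense_indexes)
print(metric.compute())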

    Labels

    bug / fix: Something isn't working
    help wanted: Extra attention is needed
