
When using torchmetrics with PyTorch Lightning in DDP mode (multi-GPU), each process accumulates metric state locally, so states must be synchronized across all processes before compute() can return a global result. The synchronization is a collective operation: every process must reach the compute() call, and if some processes skip it, the others block waiting on the all-gather and training hangs.

A recommended approach:

Pass the sync_on_compute=True argument when initializing your metric (this is the default in recent torchmetrics releases). The metric then automatically synchronizes its states across all processes before compute() aggregates them:

from torchmetrics import Accuracy

# Recent torchmetrics versions require a task argument; the values here are illustrative
metric = Accuracy(task="multiclass", num_classes=10, sync_on_compute=True)

Then you can safely call:

metric.update(preds, target)  # accumulate local state on each process
result = metric.compute()     # synchronizes across processes, then aggregates
metric.reset()                # clear state before the next accumulation window
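
When the metric lives inside a LightningModule, the torchmetrics docs recommend logging the metric object itself so that Lightning drives the compute()/reset() cycle at the correct points in DDP. A minimal sketch, assuming a 10-class classifier (the module name, forward pass, and class count are illustrative):

import pytorch_lightning as pl
from torchmetrics import Accuracy

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.val_acc = Accuracy(task="multiclass", num_classes=10, sync_on_compute=True)

    def validation_step(self, batch, batch_idx):
        x, target = batch
        preds = self(x)  # assumes forward() returns class logits
        self.val_acc.update(preds, target)
        # Logging the metric object (not its computed value) lets Lightning
        # handle synchronization, compute(), and reset() across all processes
        self.log("val_acc", self.val_acc, on_epoch=True)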

Alternatively, you can manually call metric.sync() before compute() (and metric.unsync() afterwards) to control exactly when the cross-process synchronization happens, for example with sync_on_compute=False.
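
A minimal sketch of that manual pattern (sync() and unsync() are methods on the torchmetrics Metric base class; the dummy tensors are illustrative, and sync() is a no-op when no distributed process group is initialized):

import torch
from torchmetrics import Accuracy

metric = Accuracy(task="multiclass", num_classes=3, sync_on_compute=False)

preds = torch.tensor([0, 2, 1, 2])   # dummy per-process predictions
target = torch.tensor([0, 1, 1, 2])  # dummy per-process labels

metric.update(preds, target)  # local update on each process
metric.sync()                 # explicitly gather metric states from all processes
result = metric.compute()     # aggregate the now-global state
metric.unsync()               # restore the local, pre-sync states
metric.reset()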
