From 9505cecde54e78e6a8e9522837d05cbc1fce39a6 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Wed, 3 Sep 2025 22:15:43 -0400 Subject: [PATCH 01/15] Adding testcases and fixes to missing class case --- .../segmentation/generalized_dice.py | 1 + .../segmentation/generalized_dice.py | 10 +++- .../test_generalized_dice_score.py | 49 +++++++++++++++++++ 3 files changed, 59 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 1a110980a32..938bf6023f6 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -90,6 +90,7 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: if not per_class: numerator = torch.sum(numerator, 1) denominator = torch.sum(denominator, 1) + return _safe_divide(numerator, denominator, "nan") return _safe_divide(numerator, denominator) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index a047ecf2b18..6dc9dc88e97 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -109,6 +109,7 @@ class GeneralizedDiceScore(Metric): score: Tensor samples: Tensor + class_present: Tensor full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True @@ -135,6 +136,7 @@ def __init__( num_classes = num_classes - 1 if not include_background else num_classes self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") + self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" @@ -144,9 +146,15 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) self.samples += preds.shape[0] + if self.per_class: + class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0 + self.class_present += class_mask[1:] if not self.include_background else class_mask + def compute(self) -> Tensor: """Compute the final generalized dice score.""" - return self.score / self.samples + if not self.per_class: + return self.score / self.samples + return torch.where(self.class_present > 0, self.score, torch.tensor(float("nan"))) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index fbecdbdb041..5f2d295abef 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -122,3 +122,52 @@ def test_generalized_dice_functional(self, preds, target, input_format, include_ "input_format": input_format, }, ) + + +@pytest.mark.parametrize("per_class", [True, False]) +@pytest.mark.parametrize("include_background", [True, False]) +def test_samples_with_missing_classes(per_class, include_background): + """Test GeneralizedDiceScore with missing classes in some samples.""" + target = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8) + preds = torch.zeros((4, NUM_CLASSES, 128, 128), dtype=torch.int8) + + target[0, 0, 0, 0] = 1 + preds[0, 0, 0, 0] = 1 + + target[2, 1, 0, 0] = 1 + preds[2, 1, 0, 0] = 1 + + metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background) + score = metric(preds, target) + + target_slice = target if include_background else target[:, 1:] + output_classes = NUM_CLASSES if include_background else NUM_CLASSES - 1 + + if per_class: + assert len(score) == output_classes + for c in range(output_classes): + assert score[c] == 1.0 if target_slice[:, c].sum() > 0 else torch.isnan(score[c]) + else: + assert score.isnan() + + +@pytest.mark.parametrize("per_class", [True, False]) +@pytest.mark.parametrize("include_background", [True, False]) +def test_generalized_dice_zero_denominator(per_class, include_background): + """Check that GeneralizedDiceScore returns NaN when the denominator is all zero (no class present).""" + target = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8) + preds = torch.full((4, NUM_CLASSES, 128, 128), 0, dtype=torch.int8) + + metric = GeneralizedDiceScore(num_classes=NUM_CLASSES, per_class=per_class, include_background=include_background) + + score = metric(preds, target) + + if per_class and include_background: + assert len(score) == NUM_CLASSES + assert all(t.isnan() for t in score) + elif per_class and not include_background: + assert len(score) == NUM_CLASSES - 1 + assert all(t.isnan() for t in score) + else: + # Expect scalar NaN + assert score.isnan() From 495413c6c9e1403b4f152f06948df91d5050af79 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Wed, 3 Sep 2025 22:29:42 -0400 Subject: [PATCH 02/15] Adding Changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5d70d22dda..7e257c7dbb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `precision_at_fixed_recall` and `recall_at_fixed_precision` to correctly return `NaN` thresholds when recall/precision conditions are not met ([#3226](https://github.com/Lightning-AI/torchmetrics/pull/3226)) + +- Fixed `GeneralizedDiceScore` to yield `NaN` if there are missing classes ([#2846](https://github.com/Lightning-AI/torchmetrics/issues/2846)) + --- ## [1.8.1] - 2025-08-07 From 19630920ee9d7e367faac0d0dd496a6e339b9737 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Wed, 3 Sep 2025 23:38:23 -0400 Subject: [PATCH 03/15] Fixing docstring error --- src/torchmetrics/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 6dc9dc88e97..720b8d2f481 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -100,7 +100,7 @@ class GeneralizedDiceScore(Metric): tensor(0.4992) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) >>> gds(preds, target) - tensor([0.5001, 0.4993, 0.4982]) + tensor([5.0008, 4.9930, 4.9825]) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) >>> gds(preds, target) tensor([0.4993, 0.4982]) From 981ca9cb3907e4bb77baee91f20060ffdf7ba6f3 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Thu, 4 Sep 2025 00:19:51 -0400 Subject: [PATCH 04/15] Fixing docstring error --- src/torchmetrics/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 720b8d2f481..27c83ea7b3b 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -103,7 +103,7 @@ class GeneralizedDiceScore(Metric): tensor([5.0008, 4.9930, 4.9825]) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) >>> gds(preds, target) - tensor([0.4993, 0.4982]) + tensor([4.9930, 4.9825]) """ From 44175de90844f9e1c9e7fb73452279ff8ca85532 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 8 Sep 2025 14:48:25 -0400 Subject: [PATCH 05/15] Fixing score range --- .../segmentation/generalized_dice.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 27c83ea7b3b..e42ddaafd95 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -107,9 +107,9 @@ class GeneralizedDiceScore(Metric): """ - score: Tensor - samples: Tensor class_present: Tensor + numerator: Tensor + denominator: Tensor full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True @@ -134,27 +134,29 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("score", default=torch.zeros(num_classes if per_class else 1), dist_reduce_fx="sum") - self.add_state("samples", default=torch.zeros(1), dist_reduce_fx="sum") self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum") + self.add_state("numerator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat") + self.add_state("denominator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" numerator, denominator = _generalized_dice_update( preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format ) - self.score += _generalized_dice_compute(numerator, denominator, self.per_class).sum(dim=0) - self.samples += preds.shape[0] - + self.numerator = torch.cat([self.numerator, numerator], dim=0) + self.denominator = torch.cat([self.denominator, denominator], dim=0) if self.per_class: class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0 self.class_present += class_mask[1:] if not self.include_background else class_mask + self.numerator = torch.sum(self.numerator, dim=0, keepdim=True) + self.denominator = torch.sum(self.denominator, dim=0, keepdim=True) def compute(self) -> Tensor: """Compute the final generalized dice score.""" + score = _generalized_dice_compute(self.numerator, self.denominator, self.per_class) if not self.per_class: - return self.score / self.samples - return torch.where(self.class_present > 0, self.score, torch.tensor(float("nan"))) + return score.mean() + return torch.where(self.class_present > 0, score, torch.tensor(float("nan"))).squeeze() def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. From 6b4ac90ac07f2c2913aa43ba95278d8a4e59013b Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 8 Sep 2025 15:12:29 -0400 Subject: [PATCH 06/15] Fixing docstring error --- src/torchmetrics/segmentation/generalized_dice.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index e42ddaafd95..14be9dc2e59 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -100,10 +100,10 @@ class GeneralizedDiceScore(Metric): tensor(0.4992) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True) >>> gds(preds, target) - tensor([5.0008, 4.9930, 4.9825]) + tensor([0.5000, 0.4993, 0.4983]) >>> gds = GeneralizedDiceScore(num_classes=3, per_class=True, include_background=False) >>> gds(preds, target) - tensor([4.9930, 4.9825]) + tensor([0.4993, 0.4983]) """ From 5dacd293945b50a534429df5822e78f3c1b61933 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 8 Sep 2025 15:41:04 -0400 Subject: [PATCH 07/15] Modifying atol --- tests/unittests/segmentation/test_generalized_dice_score.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 5f2d295abef..3d54dc65a17 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -82,6 +82,8 @@ def _reference_generalized_dice( class TestGeneralizedDiceScore(MetricTester): """Test class for `GeneralizedDiceScore` metric.""" + atol = 1e-4 + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): """Test class implementation of metric.""" From 1410b154504fc3292225d925a20aac0384e412f7 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 8 Sep 2025 16:03:23 -0400 Subject: [PATCH 08/15] Debugging unittest failure --- tests/unittests/segmentation/test_generalized_dice_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 3d54dc65a17..51d1be4bb42 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -82,7 +82,7 @@ def _reference_generalized_dice( class TestGeneralizedDiceScore(MetricTester): """Test class for `GeneralizedDiceScore` metric.""" - atol = 1e-4 + atol = 1e-2 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): From b142344556f9471ccc165bca8a66ddf566735031 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 8 Sep 2025 16:12:52 -0400 Subject: [PATCH 09/15] Debugging unittest failure --- tests/unittests/segmentation/test_generalized_dice_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 51d1be4bb42..05dcc69760e 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -82,7 +82,7 @@ def _reference_generalized_dice( class TestGeneralizedDiceScore(MetricTester): """Test class for `GeneralizedDiceScore` metric.""" - atol = 1e-2 + atol = 2e-2 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): From 4ae909a2b45880107d73c50ff86ec1285e456ca2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 12 Sep 2025 19:01:50 +0200 Subject: [PATCH 10/15] atol = 2e-3 --- tests/unittests/segmentation/test_generalized_dice_score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 05dcc69760e..cc55bd58a65 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -82,7 +82,7 @@ def _reference_generalized_dice( class TestGeneralizedDiceScore(MetricTester): """Test class for `GeneralizedDiceScore` metric.""" - atol = 2e-2 + atol = 2e-3 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) def test_generalized_dice_class(self, preds, target, input_format, include_background, ddp): From 8e26b68f70b9426ed9c6b085d06cd30f2d8cfa57 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 29 Sep 2025 13:50:22 -0400 Subject: [PATCH 11/15] Modifying numerator accumulation logic --- .../segmentation/generalized_dice.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index 14be9dc2e59..d7bc1f74a59 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, List, Optional, Union import torch from torch import Tensor @@ -24,6 +24,7 @@ _generalized_dice_validate_args, ) from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -108,8 +109,8 @@ class GeneralizedDiceScore(Metric): """ class_present: Tensor - numerator: Tensor - denominator: Tensor + numerator: List[Tensor] + denominator: List[Tensor] full_state_update: bool = False is_differentiable: bool = False higher_is_better: bool = True @@ -135,28 +136,31 @@ def __init__( num_classes = num_classes - 1 if not include_background else num_classes self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum") - self.add_state("numerator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat") - self.add_state("denominator", default=torch.zeros((0, num_classes)), dist_reduce_fx="cat") + self.add_state("numerator", default=[], dist_reduce_fx="cat") + self.add_state("denominator", default=[], dist_reduce_fx="cat") def update(self, preds: Tensor, target: Tensor) -> None: """Update the state with new data.""" numerator, denominator = _generalized_dice_update( preds, target, self.num_classes, self.include_background, self.weight_type, self.input_format ) - self.numerator = torch.cat([self.numerator, numerator], dim=0) - self.denominator = torch.cat([self.denominator, denominator], dim=0) + self.numerator.append(numerator) + self.denominator.append(denominator) if self.per_class: class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0 self.class_present += class_mask[1:] if not self.include_background else class_mask - self.numerator = torch.sum(self.numerator, dim=0, keepdim=True) - self.denominator = torch.sum(self.denominator, dim=0, keepdim=True) def compute(self) -> Tensor: """Compute the final generalized dice score.""" - score = _generalized_dice_compute(self.numerator, self.denominator, self.per_class) + numerator = dim_zero_cat(self.numerator) + denominator = dim_zero_cat(self.denominator) if not self.per_class: + score = _generalized_dice_compute(numerator, denominator, self.per_class) return score.mean() - return torch.where(self.class_present > 0, score, torch.tensor(float("nan"))).squeeze() + score = _generalized_dice_compute( + torch.sum(numerator, dim=0, keepdim=True), torch.sum(denominator, dim=0, keepdim=True), self.per_class + ) + return torch.where(self.class_present > 0, score.mean(dim=0), torch.tensor(float("nan"))).squeeze() def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. From ce55fea9365d5a7660c02787d3ca2e037763130d Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 29 Sep 2025 15:14:12 -0400 Subject: [PATCH 12/15] Removing class_present logic to reduce redundant computation --- .../functional/segmentation/generalized_dice.py | 6 ++++-- .../segmentation/generalized_dice.py | 16 ++-------------- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 938bf6023f6..c80d37ce703 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -90,8 +90,10 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: if not per_class: numerator = torch.sum(numerator, 1) denominator = torch.sum(denominator, 1) - return _safe_divide(numerator, denominator, "nan") - return _safe_divide(numerator, denominator) + else: + numerator = torch.sum(numerator, 0, keepdim=True) + denominator = torch.sum(denominator, 0, keepdim=True) + return _safe_divide(numerator, denominator, "nan") def generalized_dice_score( diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index d7bc1f74a59..f0baf490e56 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -14,7 +14,6 @@ from collections.abc import Sequence from typing import Any, List, Optional, Union -import torch from torch import Tensor from typing_extensions import Literal @@ -135,7 +134,6 @@ def __init__( self.input_format = input_format num_classes = num_classes - 1 if not include_background else num_classes - self.add_state("class_present", default=torch.zeros(num_classes, dtype=torch.int), dist_reduce_fx="sum") self.add_state("numerator", default=[], dist_reduce_fx="cat") self.add_state("denominator", default=[], dist_reduce_fx="cat") @@ -146,21 +144,11 @@ def update(self, preds: Tensor, target: Tensor) -> None: ) self.numerator.append(numerator) self.denominator.append(denominator) - if self.per_class: - class_mask = target.sum(dim=(0, *range(2, target.ndim))) > 0 - self.class_present += class_mask[1:] if not self.include_background else class_mask def compute(self) -> Tensor: """Compute the final generalized dice score.""" - numerator = dim_zero_cat(self.numerator) - denominator = dim_zero_cat(self.denominator) - if not self.per_class: - score = _generalized_dice_compute(numerator, denominator, self.per_class) - return score.mean() - score = _generalized_dice_compute( - torch.sum(numerator, dim=0, keepdim=True), torch.sum(denominator, dim=0, keepdim=True), self.per_class - ) - return torch.where(self.class_present > 0, score.mean(dim=0), torch.tensor(float("nan"))).squeeze() + score = _generalized_dice_compute(dim_zero_cat(self.numerator), dim_zero_cat(self.denominator), self.per_class) + return score.mean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. From 59c6150b0f9fc072356ecfbd896359cb4e8f6ead Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 29 Sep 2025 16:22:45 -0400 Subject: [PATCH 13/15] Modifying docstring examples --- .../functional/segmentation/generalized_dice.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index c80d37ce703..1118fc652e9 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -129,10 +129,7 @@ def generalized_dice_score( >>> generalized_dice_score(preds, target, num_classes=5) tensor([0.4830, 0.4935, 0.5044, 0.4880]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) - tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], - [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], - [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], - [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) + tensor([[0.4845, 0.4997, 0.4993, 0.4864, 0.4912]]) Example (with index tensors): >>> from torch import randint @@ -142,10 +139,7 @@ def generalized_dice_score( >>> generalized_dice_score(preds, target, num_classes=5, input_format="index") tensor([0.1991, 0.1971, 0.2350, 0.2216]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index") - tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069], - [0.1837, 0.2162, 0.0962, 0.2692, 0.1895], - [0.3866, 0.1348, 0.2526, 0.2301, 0.2083], - [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) + tensor([[0.1823, 0.2304, 0.2184, 0.2299, 0.2537]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) From ef5dadee0ded22fdaa2f3d4c69e3e511681b897c Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 29 Sep 2025 16:52:19 -0400 Subject: [PATCH 14/15] Modifying docstring examples --- src/torchmetrics/functional/segmentation/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index 1118fc652e9..d94d8618b24 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -139,7 +139,7 @@ def generalized_dice_score( >>> generalized_dice_score(preds, target, num_classes=5, input_format="index") tensor([0.1991, 0.1971, 0.2350, 0.2216]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index") - tensor([[0.1823, 0.2304, 0.2184, 0.2299, 0.2537]]) + tensor([[0.2234, 0.2170, 0.1597, 0.2399, 0.2204]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) From 35246096c806ea0d587c931d1dbd46b223b25049 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Tue, 30 Sep 2025 07:18:26 -0400 Subject: [PATCH 15/15] Modifying functional logic --- .../functional/segmentation/generalized_dice.py | 13 ++++++++----- src/torchmetrics/segmentation/generalized_dice.py | 8 +++++++- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/generalized_dice.py b/src/torchmetrics/functional/segmentation/generalized_dice.py index d94d8618b24..d8d6f17f8d4 100644 --- a/src/torchmetrics/functional/segmentation/generalized_dice.py +++ b/src/torchmetrics/functional/segmentation/generalized_dice.py @@ -90,9 +90,6 @@ def _generalized_dice_compute(numerator: Tensor, denominator: Tensor, per_class: if not per_class: numerator = torch.sum(numerator, 1) denominator = torch.sum(denominator, 1) - else: - numerator = torch.sum(numerator, 0, keepdim=True) - denominator = torch.sum(denominator, 0, keepdim=True) return _safe_divide(numerator, denominator, "nan") @@ -129,7 +126,10 @@ def generalized_dice_score( >>> generalized_dice_score(preds, target, num_classes=5) tensor([0.4830, 0.4935, 0.5044, 0.4880]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True) - tensor([[0.4845, 0.4997, 0.4993, 0.4864, 0.4912]]) + tensor([[0.4724, 0.5185, 0.4710, 0.5062, 0.4500], + [0.4571, 0.4980, 0.5191, 0.4380, 0.5649], + [0.5428, 0.4904, 0.5358, 0.4830, 0.4724], + [0.4715, 0.4925, 0.4797, 0.5267, 0.4788]]) Example (with index tensors): >>> from torch import randint @@ -139,7 +139,10 @@ def generalized_dice_score( >>> generalized_dice_score(preds, target, num_classes=5, input_format="index") tensor([0.1991, 0.1971, 0.2350, 0.2216]) >>> generalized_dice_score(preds, target, num_classes=5, per_class=True, input_format="index") - tensor([[0.2234, 0.2170, 0.1597, 0.2399, 0.2204]]) + tensor([[0.1714, 0.2500, 0.1304, 0.2524, 0.2069], + [0.1837, 0.2162, 0.0962, 0.2692, 0.1895], + [0.3866, 0.1348, 0.2526, 0.2301, 0.2083], + [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) """ _generalized_dice_validate_args(num_classes, include_background, per_class, weight_type, input_format) diff --git a/src/torchmetrics/segmentation/generalized_dice.py b/src/torchmetrics/segmentation/generalized_dice.py index f0baf490e56..6bd3c182410 100644 --- a/src/torchmetrics/segmentation/generalized_dice.py +++ b/src/torchmetrics/segmentation/generalized_dice.py @@ -14,6 +14,7 @@ from collections.abc import Sequence from typing import Any, List, Optional, Union +import torch from torch import Tensor from typing_extensions import Literal @@ -147,7 +148,12 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Tensor: """Compute the final generalized dice score.""" - score = _generalized_dice_compute(dim_zero_cat(self.numerator), dim_zero_cat(self.denominator), self.per_class) + numerator = dim_zero_cat(self.numerator) + denominator = dim_zero_cat(self.denominator) + if self.per_class: + numerator = torch.sum(numerator, 0, keepdim=True) + denominator = torch.sum(denominator, 0, keepdim=True) + score = _generalized_dice_compute(dim_zero_cat(numerator), dim_zero_cat(denominator), self.per_class) return score.mean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: