diff --git a/entmax/losses.py b/entmax/losses.py index 0113234..1d6bfc1 100644 --- a/entmax/losses.py +++ b/entmax/losses.py @@ -7,24 +7,31 @@ class _GenericLoss(nn.Module): - def __init__(self, ignore_index=-100, reduction="elementwise_mean"): - assert reduction in ["elementwise_mean", "sum", "none"] + def __init__(self, ignore_index=-100, reduction="mean"): + assert reduction in ["elementwise_mean", "sum", "none", "mean"] + if reduction == "elementwise_mean": + reduction = "mean" self.reduction = reduction self.ignore_index = ignore_index super(_GenericLoss, self).__init__() def forward(self, X, target): + if self.ignore_index is not None: + num_samples = target.size(0) + valid_positions = target != self.ignore_index + target = target[valid_positions] + X = X[valid_positions] + loss = self.loss(X, target) - if self.ignore_index >= 0: - ignored_positions = target == self.ignore_index - size = float((target.size(0) - ignored_positions.sum()).item()) - loss.masked_fill_(ignored_positions, 0.0) - else: - size = float(target.size(0)) + + if self.reduction == "none" and self.ignore_index is not None: + nonzero_loss = loss + loss = torch.zeros(num_samples, device=X.device) + loss[valid_positions] = nonzero_loss if self.reduction == "sum": loss = loss.sum() - elif self.reduction == "elementwise_mean": - loss = loss.sum() / size + elif self.reduction == "mean": + loss = loss.mean() return loss @@ -252,7 +259,7 @@ def loss(self, X, target): class SparsemaxLoss(_GenericLoss): - def __init__(self, k=None, ignore_index=-100, reduction="elementwise_mean"): + def __init__(self, k=None, ignore_index=-100, reduction="mean"): self.k = k super(SparsemaxLoss, self).__init__(ignore_index, reduction) @@ -266,7 +273,7 @@ def __init__( alpha=1.5, n_iter=50, ignore_index=-100, - reduction="elementwise_mean", + reduction="mean", ): self.alpha = alpha self.n_iter = n_iter @@ -277,7 +284,7 @@ def loss(self, X, target): class Entmax15Loss(_GenericLoss): - def __init__(self, k=100, ignore_index=-100, reduction="elementwise_mean"): + def __init__(self, k=100, ignore_index=-100, reduction="mean"): self.k = k super(Entmax15Loss, self).__init__(ignore_index, reduction)