6
6
from keras_rs .src import types
7
7
from keras_rs .src .metrics .utils import standardize_call_inputs_ranks
8
8
from keras_rs .src .api_export import keras_rs_export
9
+ from keras_rs .src .metrics .ranking_metrics_utils import sort_by_scores
10
+
9
11
10
12
@keras_rs_export ("keras_rs.losses.ListMLELoss" )
11
13
class ListMLELoss (keras .losses .Loss ):
@@ -20,7 +22,7 @@ class ListMLELoss(keras.losses.Loss):
20
22
21
23
The loss is computed as the negative log-likelihood of the ground truth
22
24
ranking given the predicted scores:
23
-
25
+
24
26
```
25
27
loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
26
28
```
@@ -65,7 +67,6 @@ def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
65
67
self .temperature = temperature
66
68
self ._epsilon = 1e-10
67
69
68
-
69
70
def compute_unreduced_loss (
70
71
self ,
71
72
labels : types .Tensor ,
@@ -84,7 +85,7 @@ def compute_unreduced_loss(
84
85
Tuple of (losses, weights) where losses has shape [batch_size, 1]
85
86
and weights has the same shape.
86
87
"""
87
-
88
+
88
89
valid_mask = ops .greater_equal (labels , ops .cast (0.0 , labels .dtype ))
89
90
90
91
if mask is not None :
@@ -95,13 +96,17 @@ def compute_unreduced_loss(
95
96
96
97
batch_has_valid_items = ops .greater (num_valid_items , 0.0 )
97
98
99
+
98
100
labels_for_sorting = ops .where (valid_mask , labels , ops .full_like (labels , - 1e9 ))
99
101
logits_masked = ops .where (valid_mask , logits , ops .full_like (logits , - 1e9 ))
100
-
101
- sorted_indices = ops .argsort (- labels_for_sorting , axis = - 1 )
102
-
103
- sorted_logits = ops .take_along_axis (logits_masked , sorted_indices , axis = - 1 )
104
- sorted_valid_mask = ops .take_along_axis (valid_mask , sorted_indices , axis = - 1 )
102
+
103
+ sorted_logits , sorted_valid_mask = sort_by_scores (
104
+ tensors_to_sort = [logits_masked , valid_mask ],
105
+ scores = labels_for_sorting ,
106
+ mask = None ,
107
+ shuffle_ties = False ,
108
+ seed = None
109
+ )
105
110
106
111
sorted_logits = ops .divide (
107
112
sorted_logits ,
@@ -111,26 +116,21 @@ def compute_unreduced_loss(
111
116
valid_logits_for_max = ops .where (sorted_valid_mask , sorted_logits ,
112
117
ops .full_like (sorted_logits , - 1e9 ))
113
118
raw_max = ops .max (valid_logits_for_max , axis = 1 , keepdims = True )
114
-
115
119
raw_max = ops .where (batch_has_valid_items , raw_max , ops .zeros_like (raw_max ))
116
120
sorted_logits = sorted_logits - raw_max
117
-
118
-
121
+
119
122
exp_logits = ops .exp (sorted_logits )
120
-
121
123
exp_logits = ops .where (sorted_valid_mask , exp_logits , ops .zeros_like (exp_logits ))
122
124
123
125
reversed_exp = ops .flip (exp_logits , axis = 1 )
124
126
reversed_cumsum = ops .cumsum (reversed_exp , axis = 1 )
125
127
cumsum_from_right = ops .flip (reversed_cumsum , axis = 1 )
126
128
127
-
128
129
log_normalizers = ops .log (cumsum_from_right + self ._epsilon )
129
130
log_probs = sorted_logits - log_normalizers
130
131
131
132
log_probs = ops .where (sorted_valid_mask , log_probs , ops .zeros_like (log_probs ))
132
133
133
-
134
134
negative_log_likelihood = - ops .sum (log_probs , axis = 1 , keepdims = True )
135
135
136
136
negative_log_likelihood = ops .where (batch_has_valid_items , negative_log_likelihood ,
@@ -187,9 +187,8 @@ def call(
187
187
losses = ops .multiply (losses , weights )
188
188
losses = ops .squeeze (losses , axis = - 1 )
189
189
return losses
190
-
191
190
192
191
def get_config(self) -> dict[str, Any]:
    """Return the serializable configuration of this loss.

    Starts from the configuration dictionary produced by the parent
    class and adds this loss's ``temperature`` setting to it.

    Returns:
        The parent's configuration dictionary with an extra
        ``"temperature"`` entry holding ``self.temperature``.
    """
    base_config: dict[str, Any] = super().get_config()
    # Extend the parent's dict in place so the same object is returned.
    base_config["temperature"] = self.temperature
    return base_config
0 commit comments