Skip to content

Commit 38a254e

Browse files
Used sort_by_score function instead of manual sorting
1 parent c0be0bb commit 38a254e

File tree

3 files changed

+45
-28
lines changed

3 files changed

+45
-28
lines changed

api_gen.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import os
1010
import shutil
11-
import pre_commit
1211
import namex
1312

1413
PACKAGE = "keras_rs"

keras_rs/api/layers/__init__.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,34 @@
44
since your modifications would be overwritten.
55
"""
66

7-
8-
from keras_rs.src.layers.embedding.distributed_embedding import DistributedEmbedding as DistributedEmbedding
9-
from keras_rs.src.layers.embedding.distributed_embedding_config import FeatureConfig as FeatureConfig
10-
from keras_rs.src.layers.embedding.distributed_embedding_config import TableConfig as TableConfig
11-
from keras_rs.src.layers.embedding.embed_reduce import EmbedReduce as EmbedReduce
12-
from keras_rs.src.layers.feature_interaction.dot_interaction import DotInteraction as DotInteraction
13-
from keras_rs.src.layers.feature_interaction.feature_cross import FeatureCross as FeatureCross
14-
from keras_rs.src.layers.retrieval.brute_force_retrieval import BruteForceRetrieval as BruteForceRetrieval
15-
from keras_rs.src.layers.retrieval.hard_negative_mining import HardNegativeMining as HardNegativeMining
16-
from keras_rs.src.layers.retrieval.remove_accidental_hits import RemoveAccidentalHits as RemoveAccidentalHits
7+
from keras_rs.src.layers.embedding.distributed_embedding import (
8+
DistributedEmbedding as DistributedEmbedding,
9+
)
10+
from keras_rs.src.layers.embedding.distributed_embedding_config import (
11+
FeatureConfig as FeatureConfig,
12+
)
13+
from keras_rs.src.layers.embedding.distributed_embedding_config import (
14+
TableConfig as TableConfig,
15+
)
16+
from keras_rs.src.layers.embedding.embed_reduce import (
17+
EmbedReduce as EmbedReduce,
18+
)
19+
from keras_rs.src.layers.feature_interaction.dot_interaction import (
20+
DotInteraction as DotInteraction,
21+
)
22+
from keras_rs.src.layers.feature_interaction.feature_cross import (
23+
FeatureCross as FeatureCross,
24+
)
25+
from keras_rs.src.layers.retrieval.brute_force_retrieval import (
26+
BruteForceRetrieval as BruteForceRetrieval,
27+
)
28+
from keras_rs.src.layers.retrieval.hard_negative_mining import (
29+
HardNegativeMining as HardNegativeMining,
30+
)
31+
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
32+
RemoveAccidentalHits as RemoveAccidentalHits,
33+
)
1734
from keras_rs.src.layers.retrieval.retrieval import Retrieval as Retrieval
18-
from keras_rs.src.layers.retrieval.sampling_probability_correction import SamplingProbabilityCorrection as SamplingProbabilityCorrection
35+
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
36+
SamplingProbabilityCorrection as SamplingProbabilityCorrection,
37+
)

keras_rs/src/losses/list_mle_loss.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from keras_rs.src import types
77
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
88
from keras_rs.src.api_export import keras_rs_export
9+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
10+
911

1012
@keras_rs_export("keras_rs.losses.ListMLELoss")
1113
class ListMLELoss(keras.losses.Loss):
@@ -20,7 +22,7 @@ class ListMLELoss(keras.losses.Loss):
2022
2123
The loss is computed as the negative log-likelihood of the ground truth
2224
ranking given the predicted scores:
23-
25+
2426
```
2527
loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
2628
```
@@ -65,7 +67,6 @@ def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
6567
self.temperature = temperature
6668
self._epsilon = 1e-10
6769

68-
6970
def compute_unreduced_loss(
7071
self,
7172
labels: types.Tensor,
@@ -84,7 +85,7 @@ def compute_unreduced_loss(
8485
Tuple of (losses, weights) where losses has shape [batch_size, 1]
8586
and weights has the same shape.
8687
"""
87-
88+
8889
valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
8990

9091
if mask is not None:
@@ -95,13 +96,17 @@ def compute_unreduced_loss(
9596

9697
batch_has_valid_items = ops.greater(num_valid_items, 0.0)
9798

99+
98100
labels_for_sorting = ops.where(valid_mask, labels, ops.full_like(labels, -1e9))
99101
logits_masked = ops.where(valid_mask, logits, ops.full_like(logits, -1e9))
100-
101-
sorted_indices = ops.argsort(-labels_for_sorting, axis=-1)
102-
103-
sorted_logits = ops.take_along_axis(logits_masked, sorted_indices, axis=-1)
104-
sorted_valid_mask = ops.take_along_axis(valid_mask, sorted_indices, axis=-1)
102+
103+
sorted_logits, sorted_valid_mask = sort_by_scores(
104+
tensors_to_sort=[logits_masked, valid_mask],
105+
scores=labels_for_sorting,
106+
mask=None,
107+
shuffle_ties=False,
108+
seed=None
109+
)
105110

106111
sorted_logits = ops.divide(
107112
sorted_logits,
@@ -111,26 +116,21 @@ def compute_unreduced_loss(
111116
valid_logits_for_max = ops.where(sorted_valid_mask, sorted_logits,
112117
ops.full_like(sorted_logits, -1e9))
113118
raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
114-
115119
raw_max = ops.where(batch_has_valid_items, raw_max, ops.zeros_like(raw_max))
116120
sorted_logits = sorted_logits - raw_max
117-
118-
121+
119122
exp_logits = ops.exp(sorted_logits)
120-
121123
exp_logits = ops.where(sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits))
122124

123125
reversed_exp = ops.flip(exp_logits, axis=1)
124126
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
125127
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
126128

127-
128129
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
129130
log_probs = sorted_logits - log_normalizers
130131

131132
log_probs = ops.where(sorted_valid_mask, log_probs, ops.zeros_like(log_probs))
132133

133-
134134
negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
135135

136136
negative_log_likelihood = ops.where(batch_has_valid_items, negative_log_likelihood,
@@ -187,9 +187,8 @@ def call(
187187
losses = ops.multiply(losses, weights)
188188
losses = ops.squeeze(losses, axis=-1)
189189
return losses
190-
191190

192191
def get_config(self) -> dict[str, Any]:
193192
config: dict[str, Any] = super().get_config()
194193
config.update({"temperature": self.temperature})
195-
return config
194+
return config

0 commit comments

Comments
 (0)