
Commit f90f0cc

return loss breakdown for LFQ for logging purposes
1 parent 8f4e427 commit f90f0cc

2 files changed: 12 additions & 4 deletions


setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'vector_quantize_pytorch',
   packages = find_packages(),
-  version = '1.9.12',
+  version = '1.9.14',
   license='MIT',
   description = 'Vector Quantization - Pytorch',
   long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 11 additions & 3 deletions
@@ -20,6 +20,8 @@
 
 Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
 
+LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
+
 # helper functions
 
 def exists(v):
@@ -151,7 +153,8 @@ def indices_to_codes(
     def forward(
         self,
         x,
-        inv_temperature = 1.
+        inv_temperature = 1.,
+        return_loss_breakdown = False
     ):
         """
         einstein notation
@@ -214,7 +217,7 @@ def forward(
             entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
         else:
             # if not training, just return dummy 0
-            entropy_aux_loss = self.zero
+            entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
 
         # commit loss
 
@@ -248,4 +251,9 @@ def forward(
 
         aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
 
-        return Return(x, indices, aux_loss)
+        ret = Return(x, indices, aux_loss)
+
+        if not return_loss_breakdown:
+            return ret
+
+        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
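
For reference, a minimal usage sketch of the new flag, modeled on the LFQ example in the project README; the constructor arguments and tensor shape here are illustrative, not part of this commit:

import torch
from vector_quantize_pytorch import LFQ

# hyperparameters are illustrative; codebook_size must be a power of two
quantizer = LFQ(
    codebook_size = 65536,
    dim = 16,
    entropy_loss_weight = 0.1,
    diversity_gamma = 1.
)

image_feats = torch.randn(1, 16, 32, 32)

# default behavior is unchanged: a single Return namedtuple
quantized, indices, aux_loss = quantizer(image_feats)

# with return_loss_breakdown = True, forward also returns the new
# LossBreakdown namedtuple, so each loss term can be logged separately
(quantized, indices, aux_loss), breakdown = quantizer(image_feats, return_loss_breakdown = True)

print(breakdown.per_sample_entropy, breakdown.batch_entropy, breakdown.commitment)

Note that in eval mode the two entropy terms come back as the module's zero scalar, per the change above.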
