@@ -11,6 +11,7 @@
 
 import torch
 from torch import nn, Tensor
+import torch.nn.functional as F
 from torch.nn import Module
 
 from einops import rearrange, reduce, pack, unpack
@@ -53,8 +54,9 @@ def __init__(
         dim = None,
         codebook_size = None,
         entropy_loss_weight = 0.1,
+        commitment_loss_weight = 1.,
         diversity_gamma = 2.5,
-        straight_through_activation = nn.Tanh(),
+        straight_through_activation = nn.Identity(),
         num_codebooks = 1,
         keep_num_codebooks_dim = None
     ):
@@ -91,6 +93,10 @@ def __init__(
         self.diversity_gamma = diversity_gamma
         self.entropy_loss_weight = entropy_loss_weight
 
+        # commitment loss
+
+        self.commitment_loss_weight = commitment_loss_weight
+
         # for no auxiliary loss, during inference
 
         self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
@@ -157,6 +163,8 @@ def forward(
 
         # quantize by eq 3.
 
+        original_input = x
+
         ones = torch.ones_like(x)
         quantized = torch.where(x > 0, ones, -ones)
 
@@ -190,7 +198,12 @@ def forward(
             # if not training, just return dummy 0
             entropy_aux_loss = self.zero
 
-        entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
+        # commit loss
+
+        if self.training:
+            commit_loss = F.mse_loss(original_input, quantized.detach())
+        else:
+            commit_loss = self.zero
 
         # merge back codebook dim
 
@@ -213,4 +226,8 @@ def forward(
         if not self.keep_num_codebooks_dim:
             indices = rearrange(indices, '... 1 -> ...')
 
+        # complete aux loss
+
+        aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
-        return Return(x, indices, entropy_aux_loss)
+        return Return(x, indices, aux_loss)
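For context on what the diff adds up to: the new commitment term is the standard VQ-style commitment loss, an MSE between the pre-quantization input and the detached binary codes, and it is folded together with the weighted entropy loss into a single auxiliary loss. Below is a minimal standalone sketch of that combined term under those assumptions; `lfq_aux_loss` is a hypothetical helper for illustration only (not part of the repo), using the new defaults `entropy_loss_weight = 0.1` and `commitment_loss_weight = 1.`:

```python
import torch
import torch.nn.functional as F

def lfq_aux_loss(
    x,
    entropy_aux_loss,
    entropy_loss_weight = 0.1,
    commitment_loss_weight = 1.
):
    # sign quantization, mirroring the module's "quantize by eq 3." step
    quantized = torch.where(x > 0, torch.ones_like(x), -torch.ones_like(x))

    # commitment term: MSE against the detached codes, so the gradient
    # pulls the pre-quantization activations toward the +/-1 codes
    commit_loss = F.mse_loss(x, quantized.detach())

    # combine both weighted terms into a single auxiliary loss,
    # as the commit does before returning
    return entropy_aux_loss * entropy_loss_weight + commit_loss * commitment_loss_weight

x = torch.randn(2, 16, requires_grad = True)
aux = lfq_aux_loss(x, entropy_aux_loss = torch.tensor(0.))
aux.backward()  # gradient reaches x only through the commitment MSE
```

Because the codes are detached, the MSE only regularizes the encoder output; LFQ has no learned codebook to update, so a single commitment term suffices (there is no codebook-side loss as in classic VQ-VAE).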