@@ -53,7 +53,8 @@ def __init__(
5353 dim = None ,
5454 codebook_size = None ,
5555 entropy_loss_weight = 0.1 ,
56- diversity_gamma = 2.5
56+ diversity_gamma = 2.5 ,
57+ straight_through_activation = nn .Tanh ()
5758 ):
5859 super ().__init__ ()
5960
@@ -73,6 +74,10 @@ def __init__(
7374 self .dim = dim
7475 self .codebook_dim = codebook_dim
7576
77+ # activation applied before the straight-through estimator (tanh by default)
78+
79+ self .activation = straight_through_activation
80+
7681 # entropy aux loss related weights
7782
7883 self .diversity_gamma = diversity_gamma
@@ -122,7 +127,7 @@ def forward(
122127
123128 is_img_or_video = x .ndim >= 4
124129
125- # rearrange if image or video into (batch, seq, dimension)
130+ # standardize image or video into (batch, seq, dimension)
126131
127132 if is_img_or_video :
128133 x = rearrange (x , 'b d ... -> b ... d' )
@@ -137,10 +142,10 @@ def forward(
137142 ones = torch .ones_like (x )
138143 quantized = torch .where (x > 0 , ones , - ones )
139144
140- # use straight-through gradients with tanh if training
145+ # when training, use straight-through gradients: forward value is the quantized output, gradients flow through tanh (or the custom activation fn)
141146
142147 if self .training :
143- x = torch . tanh (x * inv_temperature )
148+ x = self . activation (x * inv_temperature )
144149 x = x - x .detach () + quantized
145150 else :
146151 x = quantized
@@ -181,6 +186,4 @@ def forward(
181186
182187 indices = unpack_one (indices , ps , 'b *' )
183188
184- # bits to decimal for the codebook indices
185-
186189 return Return (x , indices , entropy_aux_loss )
0 commit comments