Skip to content

Commit 59f0f45

Browse files
committed
final tweak to lfq
1 parent a0d14af commit 59f0f45

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'vector_quantize_pytorch',
55
packages = find_packages(),
6-
version = '1.9.5',
6+
version = '1.9.6',
77
license='MIT',
88
description = 'Vector Quantization - Pytorch',
99
long_description_content_type = 'text/markdown',

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,8 @@ def __init__(
5353
dim = None,
5454
codebook_size = None,
5555
entropy_loss_weight = 0.1,
56-
diversity_gamma = 2.5
56+
diversity_gamma = 2.5,
57+
straight_through_activation = nn.Tanh()
5758
):
5859
super().__init__()
5960

@@ -73,6 +74,10 @@ def __init__(
7374
self.dim = dim
7475
self.codebook_dim = codebook_dim
7576

77+
# straight through activation
78+
79+
self.activation = straight_through_activation
80+
7681
# entropy aux loss related weights
7782

7883
self.diversity_gamma = diversity_gamma
@@ -122,7 +127,7 @@ def forward(
122127

123128
is_img_or_video = x.ndim >= 4
124129

125-
# rearrange if image or video into (batch, seq, dimension)
130+
# standardize image or video into (batch, seq, dimension)
126131

127132
if is_img_or_video:
128133
x = rearrange(x, 'b d ... -> b ... d')
@@ -137,10 +142,10 @@ def forward(
137142
ones = torch.ones_like(x)
138143
quantized = torch.where(x > 0, ones, -ones)
139144

140-
# use straight-through gradients with tanh if training
145+
# use straight-through gradients with tanh (or custom activation fn) if training
141146

142147
if self.training:
143-
x = torch.tanh(x * inv_temperature)
148+
x = self.activation(x * inv_temperature)
144149
x = x - x.detach() + quantized
145150
else:
146151
x = quantized
@@ -181,6 +186,4 @@ def forward(
181186

182187
indices = unpack_one(indices, ps, 'b *')
183188

184-
# bits to decimal for the codebook indices
185-
186189
return Return(x, indices, entropy_aux_loss)

0 commit comments

Comments (0)