@@ -38,36 +38,11 @@ def unpack_one(t, ps, pattern):
3838
3939# entropy
4040
41- def binary_entropy (prob ):
42- return - prob * log (prob ) - (1 - prob ) * log (1 - prob )
43-
44- # tensor helpers
45-
4641def log (t , eps = 1e-20 ):
4742 return t .clamp (min = eps ).log ()
4843
49- # convert to bit representations and back
50-
51- def decimal_to_bits (x , bits ):
52- device = x .device
53-
54- x = x .int ()
55-
56- mask = 2 ** torch .arange (bits - 1 , - 1 , - 1 , device = device )
57- x = rearrange (x , 'b n -> b n 1' )
58-
59- bits = ((x & mask ) != 0 ).float ()
60- bits = rearrange (bits , 'b n d -> b n d' )
61- return bits * 2 - 1
62-
63- def bits_to_decimal (x , bits ):
64- device = x .device
65-
66- x = (x > 0 ).int ()
67-
68- mask = 2 ** torch .arange (bits - 1 , - 1 , - 1 , device = device , dtype = torch .int32 )
69- dec = reduce (x * mask , 'b n d -> b n' , 'sum' )
70- return dec
44+ def binary_entropy (prob ):
45+ return - prob * log (prob ) - (1 - prob ) * log (1 - prob )
7146
7247# class
7348
@@ -105,6 +80,7 @@ def __init__(
10580
10681 # for no auxiliary loss, during inference
10782
83+ self .register_buffer ('mask' , 2 ** torch .arange (codebook_dim - 1 , - 1 , - 1 ))
10884 self .register_buffer ('zero' , torch .zeros (1 ,), persistent = False )
10985
11086 def indices_to_codes (
@@ -114,14 +90,10 @@ def indices_to_codes(
11490 ):
11591 is_img_or_video = indices .ndim >= 3
11692
117- # rearrange if image or video into (batch, seq, dimension)
118-
119- if is_img_or_video :
120- indices , ps = pack_one (indices , 'b *' )
121-
12293 # indices to codes, which are bits of either -1 or 1
12394
124- codes = decimal_to_bits (indices , self .codebook_dim )
95+ bits = ((indices [..., None ].int () & self .mask ) != 0 ).float ()
96+ codes = bits * 2 - 1
12597
12698 # whether to project codes out to original dimensions
12799 # if the input feature dimensions were not log2(codebook size)
@@ -132,7 +104,6 @@ def indices_to_codes(
132104 # rearrange codes back to original shape
133105
134106 if is_img_or_video :
135- codes = unpack_one (codes , ps , 'b * d' )
136107 codes = rearrange (codes , 'b ... d -> b d ...' )
137108
138109 return codes
@@ -163,10 +134,8 @@ def forward(
163134
164135 # quantize by eq 3.
165136
166- greater_than_zero = x > 0
167137 ones = torch .ones_like (x )
168-
169- quantized = torch .where (greater_than_zero , ones , - ones )
138+ quantized = torch .where (x > 0 , ones , - ones )
170139
171140 # use straight-through gradients with tanh if training
172141
@@ -178,7 +147,7 @@ def forward(
178147
179148 # calculate indices
180149
181- indices = bits_to_decimal ( x , self .codebook_dim )
150+ indices = reduce (( x > 0 ). int () * self .mask . int (), 'b n d -> b n' , 'sum' )
182151
183152 # entropy aux loss
184153
0 commit comments