@@ -54,7 +54,9 @@ def __init__(
         codebook_size = None,
         entropy_loss_weight = 0.1,
         diversity_gamma = 2.5,
-        straight_through_activation = nn.Tanh()
+        straight_through_activation = nn.Tanh(),
+        num_codebooks = 1,
+        keep_num_codebooks_dim = None
     ):
         super().__init__()

@@ -66,13 +68,19 @@ def __init__(
         codebook_size = default(codebook_size, lambda: 2 ** dim)
         codebook_dim = int(log2(codebook_size))

-        dim = default(dim, codebook_dim)
+        codebook_dims = codebook_dim * num_codebooks
+        dim = default(dim, codebook_dims)

-        self.project_in = nn.Linear(dim, codebook_dim) if dim != codebook_dim else nn.Identity()
-        self.project_out = nn.Linear(codebook_dim, dim) if dim != codebook_dim else nn.Identity()
+        self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
+        self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()

         self.dim = dim
         self.codebook_dim = codebook_dim
+        self.num_codebooks = num_codebooks
+
+        keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+        assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+        self.keep_num_codebooks_dim = keep_num_codebooks_dim

         # straight through activation

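With multiple codebooks, the input and output projections now target the concatenated per-codebook dimensions. A minimal sketch of the shape arithmetic, with hypothetical example sizes (names mirror the hunk above):

    import torch
    import torch.nn as nn

    num_codebooks = 2
    codebook_size = 2 ** 8                           # 256 codes per codebook
    codebook_dim = 8                                 # log2(codebook_size) bits each
    codebook_dims = codebook_dim * num_codebooks     # 16 total quantized dims
    dim = 64                                         # incoming feature dimension

    project_in = nn.Linear(dim, codebook_dims)       # 64 -> 16
    project_out = nn.Linear(codebook_dims, dim)      # 16 -> 64

    x = torch.randn(1, 1024, dim)
    assert project_out(project_in(x)).shape == x.shape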
@@ -95,11 +103,16 @@ def indices_to_codes(
     ):
         is_img_or_video = indices.ndim >= 3

+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... -> ... 1')
+
         # indices to codes, which are bits of either -1 or 1

         bits = ((indices[..., None].int() & self.mask) != 0).float()
         codes = bits * 2 - 1

+        codes = rearrange(codes, '... c d -> ... (c d)')
+
         # whether to project codes out to original dimensions
         # if the input feature dimensions were not log2(codebook size)

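The bit unpacking in indices_to_codes can be checked in isolation. A small sketch, assuming an ascending power-of-two mask (the actual registered buffer may order the bits differently):

    import torch
    from einops import rearrange

    codebook_dim = 4
    mask = 2 ** torch.arange(codebook_dim)            # [1, 2, 4, 8]

    indices = torch.tensor([[5, 3]])                  # trailing dim c = 2 codebooks
    bits = ((indices[..., None].int() & mask) != 0).float()
    codes = bits * 2 - 1                              # 5 -> [ 1, -1,  1, -1]
                                                      # 3 -> [ 1,  1, -1, -1]

    codes = rearrange(codes, '... c d -> ... (c d)')  # concatenate the codebooks
    print(codes.shape)                                # torch.Size([1, 8])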
@@ -123,6 +136,7 @@ def forward(
         b - batch
         n - sequence (or flattened spatial dimensions)
         d - feature dimension, which is also log2(codebook size)
+        c - number of codebooks
         """

         is_img_or_video = x.ndim >= 4
@@ -133,10 +147,14 @@ def forward(
             x = rearrange(x, 'b d ... -> b ... d')
             x, ps = pack_one(x, 'b * d')

-        assert x.shape[-1] == self.dim
+        assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'

         x = self.project_in(x)

+        # split out number of codebooks
+
+        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
         # quantize by eq 3.

         ones = torch.ones_like(x)
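Equation 3 quantization is a per-dimension sign function, now applied after splitting the features into per-codebook chunks. A self-contained sketch (the straight-through gradient trick in the real forward is omitted here):

    import torch
    from einops import rearrange

    num_codebooks, codebook_dim = 2, 4
    x = torch.randn(1, 3, num_codebooks * codebook_dim)   # b n (c d)

    x = rearrange(x, 'b n (c d) -> b n c d', c = num_codebooks)

    ones = torch.ones_like(x)
    quantized = torch.where(x > 0, ones, -ones)           # every scalar snaps to ±1
    print(quantized.shape)                                # torch.Size([1, 3, 2, 4])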
@@ -152,7 +170,7 @@ def forward(

         # calculate indices

-        indices = reduce((x > 0).int() * self.mask.int(), 'b n d -> b n', 'sum')
+        indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

         # entropy aux loss

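Each codebook now yields its own integer id: the positive bits are weighted by the power-of-two mask and summed over d only, keeping the codebook axis c. A sketch with the same ascending-mask assumption as above:

    import torch
    from einops import reduce

    num_codebooks, codebook_dim = 2, 4
    x = torch.randn(1, 3, num_codebooks, codebook_dim)    # b n c d
    mask = 2 ** torch.arange(codebook_dim)

    indices = reduce((x > 0).int() * mask.int(), 'b n c d -> b n c', 'sum')
    print(indices.shape)                                  # torch.Size([1, 3, 2])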
@@ -161,7 +179,7 @@ def forward(

         bit_entropy = binary_entropy(prob).mean()

-        avg_prob = reduce(prob, 'b n d -> b d', 'mean')
+        avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
         codebook_entropy = binary_entropy(avg_prob).mean()

         # 1. entropy will be nudged to be low for each bit, so each scalar commits to one latent binary bit or the other
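The entropy terms keep the codebook axis as well. A self-contained sketch, assuming a binary_entropy helper of the usual form (the real file defines its own, possibly with a clamped log):

    import torch
    from einops import reduce

    def binary_entropy(prob, eps = 1e-5):
        # -p * log(p) - (1 - p) * log(1 - p), elementwise
        prob = prob.clamp(eps, 1 - eps)
        return -prob * prob.log() - (1 - prob) * (1 - prob).log()

    prob = torch.rand(1, 3, 2, 4)                         # b n c d

    bit_entropy = binary_entropy(prob).mean()             # pushed low: bits commit

    avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
    codebook_entropy = binary_entropy(avg_prob).mean()    # pushed high: codes all used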
@@ -174,6 +192,10 @@ def forward(

         entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight

+        # merge back codebook dim
+
+        x = rearrange(x, 'b n c d -> b n (c d)')
+
         # project out to feature dimension if needed

         x = self.project_out(x)
@@ -184,6 +206,11 @@ def forward(
             x = unpack_one(x, ps, 'b * d')
             x = rearrange(x, 'b ... d -> b d ...')

-            indices = unpack_one(indices, ps, 'b *')
+            indices = unpack_one(indices, ps, 'b * c')
+
+        # whether to remove single codebook dim
+
+        if not self.keep_num_codebooks_dim:
+            indices = rearrange(indices, '... 1 -> ...')

         return Return(x, indices, entropy_aux_loss)
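End to end, num_codebooks changes the shape of the returned indices. A hypothetical usage sketch (the LFQ class name and the image input path are assumed from the surrounding file):

    import torch

    quantizer = LFQ(
        dim = 256,
        codebook_size = 2 ** 8,
        num_codebooks = 2       # keep_num_codebooks_dim defaults to True here
    )

    x = torch.randn(1, 256, 32, 32)                # b d h w
    quantized, indices, aux_loss = quantizer(x)

    # quantized: (1, 256, 32, 32)
    # indices:   (1, 32, 32, 2) -- one code id per codebook
    # with num_codebooks = 1 the trailing codebook dim would be squeezed away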