6
6
7
7
from einops import rearrange , repeat
8
8
9
- from causal_conv1d import causal_conv1d_fn
10
- import causal_conv1d_cuda
9
+ try :
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError :
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
11
16
import selective_scan_cuda
12
17
13
18
@@ -168,6 +173,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
168
173
"""
169
174
xz: (batch, dim, seqlen)
170
175
"""
176
+ assert causal_conv1d_cuda is not None , "causal_conv1d_cuda is not available. Please install causal-conv1d."
171
177
assert checkpoint_lvl in [0 , 1 ]
172
178
L = xz .shape [- 1 ]
173
179
delta_rank = delta_proj_weight .shape [1 ]
@@ -196,7 +202,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
196
202
assert x .shape [2 ] == (d_conv - 1 ) * len (cu_seqlens [1 :- 1 ]) + z .shape [2 ]
197
203
198
204
conv1d_bias = conv1d_bias .contiguous () if conv1d_bias is not None else None
199
- conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , True )
205
+ conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (
206
+ x , conv1d_weight , conv1d_bias , None , None , None , True
207
+ )
200
208
201
209
# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
202
210
if cu_seqlens is not None :
@@ -262,6 +270,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
262
270
@custom_bwd
263
271
def backward (ctx , dout ):
264
272
# dout: (batch, seqlen, dim)
273
+ assert causal_conv1d_cuda is not None , "causal_conv1d_cuda is not available. Please install causal-conv1d."
265
274
(xz , conv1d_weight , conv1d_bias , x_dbl , x_proj_weight , delta_proj_weight , out_proj_weight ,
266
275
conv1d_out , delta , A , B , C , D , delta_bias , scan_intermediates , out , cu_seqlens ) = ctx .saved_tensors
267
276
L = xz .shape [- 1 ]
@@ -285,8 +294,10 @@ def backward(ctx, dout):
285
294
x = padded_x
286
295
assert x .shape [2 ] == (d_conv - 1 ) * len (cu_seqlens [1 :- 1 ]) + z .shape [2 ]
287
296
288
- conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , True )
289
-
297
+ conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (
298
+ x , conv1d_weight , conv1d_bias , None , None , None , True
299
+ )
300
+
290
301
# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
291
302
if cu_seqlens is not None :
292
303
mask = []
@@ -345,8 +356,8 @@ def backward(ctx, dout):
345
356
dconv1d_out = rearrange (dconv1d_out , "d (b l) -> b d l" , b = x .shape [0 ], l = x .shape [- 1 ])
346
357
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
347
358
# backward of conv1d with the backward of chunk).
348
- dx , dconv1d_weight , dconv1d_bias = causal_conv1d_cuda .causal_conv1d_bwd (
349
- x , conv1d_weight , conv1d_bias , dconv1d_out , None , dx , True
359
+ dx , dconv1d_weight , dconv1d_bias , * _ = causal_conv1d_cuda .causal_conv1d_bwd (
360
+ x , conv1d_weight , conv1d_bias , dconv1d_out , None , None , None , dx , False , True
350
361
)
351
362
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
352
363
dconv1d_weight = rearrange (dconv1d_weight , "d w -> d 1 w" )
@@ -374,11 +385,12 @@ def mamba_inner_ref(
374
385
A , B = None , C = None , D = None , delta_bias = None , B_proj_bias = None ,
375
386
C_proj_bias = None , delta_softplus = True
376
387
):
388
+ assert causal_conv1d_fn is not None , "causal_conv1d_fn is not available. Please install causal-conv1d."
377
389
L = xz .shape [- 1 ]
378
390
delta_rank = delta_proj_weight .shape [1 ]
379
391
d_state = A .shape [- 1 ] * (1 if not A .is_complex () else 2 )
380
392
x , z = xz .chunk (2 , dim = 1 )
381
- x = causal_conv1d_fn (x , rearrange (conv1d_weight , "d 1 w -> d w" ), conv1d_bias , "silu" )
393
+ x = causal_conv1d_fn (x , rearrange (conv1d_weight , "d 1 w -> d w" ), conv1d_bias , activation = "silu" )
382
394
# We're being very careful here about the layout, to avoid extra transposes.
383
395
# We want delta to have d as the slowest moving dimension
384
396
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
0 commit comments