6
6
7
7
from einops import rearrange , repeat
8
8
9
- from causal_conv1d import causal_conv1d_fn
10
- import causal_conv1d_cuda
9
+ try :
10
+ from causal_conv1d import causal_conv1d_fn
11
+ import causal_conv1d_cuda
12
+ except ImportError :
13
+ causal_conv1d_fn = None
14
+ causal_conv1d_cuda = None
15
+
11
16
import selective_scan_cuda
12
17
13
18
@@ -168,6 +173,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
168
173
"""
169
174
xz: (batch, dim, seqlen)
170
175
"""
176
+ assert causal_conv1d_cuda is not None , "causal_conv1d_cuda is not available. Please install causal-conv1d."
171
177
assert checkpoint_lvl in [0 , 1 ]
172
178
L = xz .shape [- 1 ]
173
179
delta_rank = delta_proj_weight .shape [1 ]
@@ -196,7 +202,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
196
202
assert x .shape [2 ] == (d_conv - 1 ) * len (cu_seqlens [1 :- 1 ]) + z .shape [2 ]
197
203
198
204
conv1d_bias = conv1d_bias .contiguous () if conv1d_bias is not None else None
199
- conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , True )
205
+ conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , None , None , True )
200
206
201
207
# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
202
208
if cu_seqlens is not None :
@@ -262,6 +268,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
262
268
@custom_bwd
263
269
def backward (ctx , dout ):
264
270
# dout: (batch, seqlen, dim)
271
+ assert causal_conv1d_cuda is not None , "causal_conv1d_cuda is not available. Please install causal-conv1d."
265
272
(xz , conv1d_weight , conv1d_bias , x_dbl , x_proj_weight , delta_proj_weight , out_proj_weight ,
266
273
conv1d_out , delta , A , B , C , D , delta_bias , scan_intermediates , out , cu_seqlens ) = ctx .saved_tensors
267
274
L = xz .shape [- 1 ]
@@ -285,7 +292,7 @@ def backward(ctx, dout):
285
292
x = padded_x
286
293
assert x .shape [2 ] == (d_conv - 1 ) * len (cu_seqlens [1 :- 1 ]) + z .shape [2 ]
287
294
288
- conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , True )
295
+ conv1d_out = causal_conv1d_cuda .causal_conv1d_fwd (x , conv1d_weight , conv1d_bias , None , None , None , True )
289
296
290
297
# (Optional Step2 for cu_seqlens): Mask conv1d ops in cumulative sequences
291
298
if cu_seqlens is not None :
@@ -345,8 +352,8 @@ def backward(ctx, dout):
345
352
dconv1d_out = rearrange (dconv1d_out , "d (b l) -> b d l" , b = x .shape [0 ], l = x .shape [- 1 ])
346
353
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
347
354
# backward of conv1d with the backward of chunk).
348
- dx , dconv1d_weight , dconv1d_bias = causal_conv1d_cuda .causal_conv1d_bwd (
349
- x , conv1d_weight , conv1d_bias , dconv1d_out , None , dx , True
355
+ dx , dconv1d_weight , dconv1d_bias , * _ = causal_conv1d_cuda .causal_conv1d_bwd (
356
+ x , conv1d_weight , conv1d_bias , dconv1d_out , None , None , None , dx , False , True
350
357
)
351
358
dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
352
359
dconv1d_weight = rearrange (dconv1d_weight , "d w -> d 1 w" )
@@ -374,11 +381,12 @@ def mamba_inner_ref(
374
381
A , B = None , C = None , D = None , delta_bias = None , B_proj_bias = None ,
375
382
C_proj_bias = None , delta_softplus = True
376
383
):
384
+ assert causal_conv1d_fn is not None , "causal_conv1d_fn is not available. Please install causal-conv1d."
377
385
L = xz .shape [- 1 ]
378
386
delta_rank = delta_proj_weight .shape [1 ]
379
387
d_state = A .shape [- 1 ] * (1 if not A .is_complex () else 2 )
380
388
x , z = xz .chunk (2 , dim = 1 )
381
- x = causal_conv1d_fn (x , rearrange (conv1d_weight , "d 1 w -> d w" ), conv1d_bias , "silu" )
389
+ x = causal_conv1d_fn (x , rearrange (conv1d_weight , "d 1 w -> d w" ), conv1d_bias , activation = "silu" )
382
390
# We're being very careful here about the layout, to avoid extra transposes.
383
391
# We want delta to have d as the slowest moving dimension
384
392
# and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
0 commit comments