diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py index d425bc72..876ff601 100644 --- a/mamba_ssm/ops/triton/selective_state_update.py +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -199,11 +199,11 @@ def selective_state_update(state, x, dt, A, B, C, D=None, z=None, dt_bias=None, state.stride(0), state.stride(1), state.stride(2), state.stride(3), x.stride(0), x.stride(1), x.stride(2), dt.stride(0), dt.stride(1), dt.stride(2), - *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + *( (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0) ), A.stride(0), A.stride(1), A.stride(2), B.stride(0), B.stride(1), B.stride(2), C.stride(0), C.stride(1), C.stride(2), - *(D.stride(0), D.stride(1)) if D is not None else 0, + *( (D.stride(0), D.stride(1)) if D is not None else (0, 0) ), z_strides[0], z_strides[1], z_strides[2], out.stride(0), out.stride(1), out.stride(2), dt_softplus,