Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions mamba_ssm/modules/mamba2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin

try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
Expand All @@ -31,10 +31,14 @@
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from huggingface_hub import PyTorchModelHubMixin


class Mamba2(nn.Module, PyTorchModelHubMixin):
class Mamba2(
nn.Module,
PyTorchModelHubMixin,
library_name="mamba_ssm",
repo_url="https://github.com/state-spaces/mamba",
tags=["mamba2", "arXiv:2312.00752", "arXiv:2405.21060"],
):
def __init__(
self,
d_model,
Expand Down
10 changes: 8 additions & 2 deletions mamba_ssm/modules/mamba2_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin

try:
from causal_conv1d import causal_conv1d_fn
Expand All @@ -21,7 +21,13 @@
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined


class Mamba2Simple(nn.Module):
class Mamba2Simple(
nn.Module,
PyTorchModelHubMixin,
library_name="mamba_ssm",
repo_url="https://github.com/state-spaces/mamba",
tags=["mamba2simple", "arXiv:2312.00752", "arXiv:2405.21060"],
):
def __init__(
self,
d_model,
Expand Down
9 changes: 8 additions & 1 deletion mamba_ssm/modules/mamba_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch import Tensor

from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

Expand All @@ -28,7 +29,13 @@
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
class Mamba(
nn.Module,
PyTorchModelHubMixin,
library_name="mamba_ssm",
repo_url="https://github.com/state-spaces/mamba",
tags=["mamba", "arXiv:2312.00752"],
):
def __init__(
self,
d_model,
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def run(self):
"einops",
"triton",
"transformers",
"huggingface_hub>=0.22",
# "causal_conv1d>=1.4.0",
],
)