From 404df4de6380653b28cc15c00bd069a110408146 Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 10:16:04 +0200
Subject: [PATCH 1/8] Support HF integration in Mamba and Mamba2Simple + add metadata

---
 mamba_ssm/modules/mamba2.py        | 12 ++++++++----
 mamba_ssm/modules/mamba2_simple.py | 10 ++++++++--
 mamba_ssm/modules/mamba_simple.py  |  9 ++++++++-
 3 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 1859ab0d..8b3b0fa3 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -5,8 +5,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 from einops import rearrange, repeat
+from huggingface_hub import PyTorchModelHubMixin
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -31,10 +31,14 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
-from huggingface_hub import PyTorchModelHubMixin
-
 
-class Mamba2(nn.Module, PyTorchModelHubMixin):
+class Mamba2(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="mamba_ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["mamba2", "arXiv:2312.00752", "arXiv:2405.21060"],
+    ):
     def __init__(
         self,
         d_model,
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index 026c674b..fb725c90 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
 from einops import rearrange, repeat
+from huggingface_hub import PyTorchModelHubMixin
 
 try:
     from causal_conv1d import causal_conv1d_fn
@@ -21,7 +21,13 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
 
-class Mamba2Simple(nn.Module):
+class Mamba2Simple(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="mamba_ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["mamba2simple", "arXiv:2312.00752", "arXiv:2405.21060"],
+    ):
     def __init__(
         self,
         d_model,
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 4c8a3882..f933b377 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -9,6 +9,7 @@
 from torch import Tensor
 
 from einops import rearrange, repeat
+from huggingface_hub import PyTorchModelHubMixin
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 
@@ -28,7 +29,13 @@
     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 
-class Mamba(nn.Module):
+class Mamba(
+    nn.Module,
+    PyTorchModelHubMixin,
+    library_name="mamba_ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["mamba", "arXiv:2312.00752"],
+    ):
     def __init__(
         self,
         d_model,

From 44458c48ee1f606768c6e3936b4f35cd1ffc8f3f Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 10:17:30 +0200
Subject: [PATCH 2/8] Add huggingface_hub as core dependency

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index dd8d8128..afa16060 100755
--- a/setup.py
+++ b/setup.py
@@ -374,6 +374,7 @@ def run(self):
         "einops",
         "triton",
         "transformers",
+        "huggingface_hub>=0.22",
         # "causal_conv1d>=1.4.0",
     ],
 )

From 961eccb89f92e34459ea3e6959081c833bba19c3 Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 13:37:09 +0200
Subject: [PATCH 3/8] mamba-ssm as library name

---
 mamba_ssm/modules/mamba2.py        | 2 +-
 mamba_ssm/modules/mamba2_simple.py | 2 +-
 mamba_ssm/modules/mamba_simple.py  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 8b3b0fa3..314dab44 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -35,7 +35,7 @@
 class Mamba2(
     nn.Module,
     PyTorchModelHubMixin,
-    library_name="mamba_ssm",
+    library_name="mamba-ssm",
     repo_url="https://github.com/state-spaces/mamba",
     tags=["mamba2", "arXiv:2312.00752", "arXiv:2405.21060"],
     ):
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index fb725c90..a0d954e3 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -24,7 +24,7 @@
 class Mamba2Simple(
     nn.Module,
     PyTorchModelHubMixin,
-    library_name="mamba_ssm",
+    library_name="mamba-ssm",
     repo_url="https://github.com/state-spaces/mamba",
     tags=["mamba2simple", "arXiv:2312.00752", "arXiv:2405.21060"],
     ):
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index f933b377..22f093f4 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -32,7 +32,7 @@
 class Mamba(
     nn.Module,
     PyTorchModelHubMixin,
-    library_name="mamba_ssm",
+    library_name="mamba-ssm",
     repo_url="https://github.com/state-spaces/mamba",
     tags=["mamba", "arXiv:2312.00752"],
     ):

From 4bd4af96bf3f16f0ade65b9da131c22fab5296cb Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 14:22:47 +0200
Subject: [PATCH 4/8] requires 0.23.5

---
 setup.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index afa16060..eeae8c10 100755
--- a/setup.py
+++ b/setup.py
@@ -374,7 +374,7 @@ def run(self):
         "einops",
         "triton",
         "transformers",
-        "huggingface_hub>=0.22",
+        "huggingface_hub>=0.23.5",
         # "causal_conv1d>=1.4.0",
     ],
 )

From a157ec5cf68ef7ae2673c7e646adf8411aeeef46 Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 15:05:53 +0200
Subject: [PATCH 5/8] Mixin in MambaLMHeadModel only

---
 mamba_ssm/models/mixer_seq_simple.py | 47 ++++++++--------------------
 mamba_ssm/modules/mamba2.py          | 14 +++------
 mamba_ssm/modules/mamba2_simple.py   | 12 ++-----
 mamba_ssm/modules/mamba_simple.py    | 11 ++-----
 mamba_ssm/utils/hf.py                | 23 --------------
 5 files changed, 23 insertions(+), 84 deletions(-)
 delete mode 100644 mamba_ssm/utils/hf.py

diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index fae2257a..f3bdf192 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -1,24 +1,20 @@
 # Copyright (c) 2023, Albert Gu, Tri Dao.
-
-import math
-from functools import partial
-import json
-import os
 import copy
-
+import math
 from collections import namedtuple
+from functools import partial
 
 import torch
 import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
 
 from mamba_ssm.models.config_mamba import MambaConfig
-from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.modules.block import Block
 from mamba_ssm.modules.mamba2 import Mamba2
+from mamba_ssm.modules.mamba_simple import Mamba
 from mamba_ssm.modules.mha import MHA
 from mamba_ssm.modules.mlp import GatedMLP
-from mamba_ssm.modules.block import Block
 from mamba_ssm.utils.generation import GenerationMixin
-from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
 
 try:
     from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
@@ -212,7 +208,14 @@ def forward(self, input_ids, inference_params=None, **mixer_kwargs):
         return hidden_states
 
 
-class MambaLMHeadModel(nn.Module, GenerationMixin):
+class MambaLMHeadModel(
+    nn.Module,
+    GenerationMixin,
+    PyTorchModelHubMixin,
+    library_name="mamba-ssm",
+    repo_url="https://github.com/state-spaces/mamba",
+    tags=["arXiv:2312.00752", "arXiv:2405.21060"],
+    ):
 
     def __init__(
         self,
@@ -283,27 +286,3 @@ def forward(self, input_ids, position_ids=None, inference_params=None, num_last_
         CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
         return CausalLMOutput(logits=lm_logits)
 
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
-        config_data = load_config_hf(pretrained_model_name)
-        config = MambaConfig(**config_data)
-        model = cls(config, device=device, dtype=dtype, **kwargs)
-        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
-        return model
-
-    def save_pretrained(self, save_directory):
-        """
-        Minimal implementation of save_pretrained for MambaLMHeadModel.
-        Save the model and its configuration file to a directory.
-        """
-        # Ensure save_directory exists
-        os.makedirs(save_directory, exist_ok=True)
-
-        # Save the model's state_dict
-        model_path = os.path.join(save_directory, 'pytorch_model.bin')
-        torch.save(self.state_dict(), model_path)
-
-        # Save the configuration of the model
-        config_path = os.path.join(save_directory, 'config.json')
-        with open(config_path, 'w') as f:
-            json.dump(self.config.__dict__, f, indent=4)
diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 314dab44..af15be00 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -5,8 +5,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from einops import rearrange, repeat
-from huggingface_hub import PyTorchModelHubMixin
 
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -31,14 +31,10 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
+from huggingface_hub import PyTorchModelHubMixin
+
 
-class Mamba2(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="mamba-ssm",
-    repo_url="https://github.com/state-spaces/mamba",
-    tags=["mamba2", "arXiv:2312.00752", "arXiv:2405.21060"],
-    ):
+class Mamba2(nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
         d_model,
@@ -384,4 +380,4 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
         if initialize_states:
             conv_state.zero_()
             ssm_state.zero_()
-        return conv_state, ssm_state
+        return conv_state, ssm_state
\ No newline at end of file
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index a0d954e3..a4f6e1e4 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
 from einops import rearrange, repeat
-from huggingface_hub import PyTorchModelHubMixin
 
 try:
     from causal_conv1d import causal_conv1d_fn
@@ -21,13 +21,7 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
 
-class Mamba2Simple(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="mamba-ssm",
-    repo_url="https://github.com/state-spaces/mamba",
-    tags=["mamba2simple", "arXiv:2312.00752", "arXiv:2405.21060"],
-    ):
+class Mamba2Simple(nn.Module):
     def __init__(
         self,
         d_model,
@@ -202,4 +196,4 @@ def forward(self, u, seq_idx=None):
         # Multiply "gate" branch and apply extra normalization layer
         y = self.norm(y, z)
         out = self.out_proj(y)
-        return out
+        return out
\ No newline at end of file
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 22f093f4..7c365b6b 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -9,7 +9,6 @@
 from torch import Tensor
 
 from einops import rearrange, repeat
-from huggingface_hub import PyTorchModelHubMixin
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 
@@ -29,13 +28,7 @@
     RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 
-class Mamba(
-    nn.Module,
-    PyTorchModelHubMixin,
-    library_name="mamba-ssm",
-    repo_url="https://github.com/state-spaces/mamba",
-    tags=["mamba", "arXiv:2312.00752"],
-    ):
+class Mamba(nn.Module):
     def __init__(
         self,
         d_model,
@@ -298,4 +291,4 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
         if initialize_states:
             conv_state.zero_()
             ssm_state.zero_()
-        return conv_state, ssm_state
+        return conv_state, ssm_state
\ No newline at end of file
diff --git a/mamba_ssm/utils/hf.py b/mamba_ssm/utils/hf.py
deleted file mode 100644
index 0d7555ac..00000000
--- a/mamba_ssm/utils/hf.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import json
-
-import torch
-
-from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
-from transformers.utils.hub import cached_file
-
-
-def load_config_hf(model_name):
-    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
-    return json.load(open(resolved_archive_file))
-
-
-def load_state_dict_hf(model_name, device=None, dtype=None):
-    # If not fp32, then we don't want to load directly to the GPU
-    mapped_device = "cpu" if dtype not in [torch.float32, None] else device
-    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
-    return torch.load(resolved_archive_file, map_location=mapped_device)
-    # Convert dtype before moving to GPU to save memory
-    if dtype is not None:
-        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
-    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
-    return state_dict

From bc402de69eff97c85d00c39c70e67b9916fadd56 Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 15:08:00 +0200
Subject: [PATCH 6/8] last line

---
 mamba_ssm/modules/mamba2.py        | 2 +-
 mamba_ssm/modules/mamba2_simple.py | 2 +-
 mamba_ssm/modules/mamba_simple.py  | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index af15be00..1859ab0d 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -380,4 +380,4 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
         if initialize_states:
             conv_state.zero_()
             ssm_state.zero_()
-        return conv_state, ssm_state
\ No newline at end of file
+        return conv_state, ssm_state
diff --git a/mamba_ssm/modules/mamba2_simple.py b/mamba_ssm/modules/mamba2_simple.py
index a4f6e1e4..026c674b 100644
--- a/mamba_ssm/modules/mamba2_simple.py
+++ b/mamba_ssm/modules/mamba2_simple.py
@@ -196,4 +196,4 @@ def forward(self, u, seq_idx=None):
         # Multiply "gate" branch and apply extra normalization layer
         y = self.norm(y, z)
         out = self.out_proj(y)
-        return out
\ No newline at end of file
+        return out
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 7c365b6b..4c8a3882 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -291,4 +291,4 @@ def _get_states_from_cache(self, inference_params, batch_size, initialize_states
         if initialize_states:
             conv_state.zero_()
             ssm_state.zero_()
-        return conv_state, ssm_state
\ No newline at end of file
+        return conv_state, ssm_state

From 0c4686fb7a6c4ee90a98e553ad9f0ff0d14643ce Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 15:08:24 +0200
Subject: [PATCH 7/8] remove mixin from mamba2

---
 mamba_ssm/modules/mamba2.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/mamba_ssm/modules/mamba2.py b/mamba_ssm/modules/mamba2.py
index 1859ab0d..85fd6dec 100644
--- a/mamba_ssm/modules/mamba2.py
+++ b/mamba_ssm/modules/mamba2.py
@@ -31,10 +31,8 @@
 from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
 from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined
 
-from huggingface_hub import PyTorchModelHubMixin
-
 
-class Mamba2(nn.Module, PyTorchModelHubMixin):
+class Mamba2(nn.Module):
     def __init__(
         self,
         d_model,

From e536a97a4778144ca4580dee4268b7947fd58c34 Mon Sep 17 00:00:00 2001
From: Wauplin
Date: Tue, 16 Jul 2024 16:04:37 +0200
Subject: [PATCH 8/8] add pipeline_tag

---
 mamba_ssm/models/mixer_seq_simple.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index f3bdf192..bdd7efdb 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -215,6 +215,7 @@ class MambaLMHeadModel(
     library_name="mamba-ssm",
     repo_url="https://github.com/state-spaces/mamba",
     tags=["arXiv:2312.00752", "arXiv:2405.21060"],
+    pipeline_tag="text-generation",
     ):
 
     def __init__(
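
End state of the series: after PATCH 5 and PATCH 8, MambaLMHeadModel inherits save_pretrained, from_pretrained, and push_to_hub from PyTorchModelHubMixin, replacing the hand-rolled helpers deleted together with mamba_ssm/utils/hf.py, and carries the model-card metadata (library_name, repo_url, tags, pipeline_tag) declared on the class. The sketch below is not part of the patches: the local directory and repo id are placeholders, it assumes the dataclass-config handling available in huggingface_hub>=0.23.5 (the floor set in PATCH 4), and running an actual forward pass still requires a CUDA build of mamba-ssm.

import torch

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Build a small model; MambaConfig is a dataclass, so the mixin is expected to
# serialize it into config.json on save and rebuild it on load (assumption
# based on the huggingface_hub>=0.23.5 requirement pinned in PATCH 4).
config = MambaConfig(d_model=256, n_layer=4, vocab_size=50277)
model = MambaLMHeadModel(config)

# save_pretrained / push_to_hub come from PyTorchModelHubMixin and also write
# the metadata declared in the class definition (library_name="mamba-ssm",
# repo_url, tags, pipeline_tag="text-generation") into the generated model card.
model.save_pretrained("./mamba-tiny-local")        # placeholder local path
# model.push_to_hub("your-username/mamba-tiny")    # placeholder repo id, needs a logged-in token

# from_pretrained restores the config kwargs and loads the saved weights.
reloaded = MambaLMHeadModel.from_pretrained("./mamba-tiny-local")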