Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions benchmarks/benchmark_generation_mamba_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

import argparse
import time
import json

import torch
import torch.nn.functional as F

from einops import rearrange

from transformers import AutoTokenizer, AutoModelForCausalLM

Expand Down
1 change: 0 additions & 1 deletion mamba_ssm/distributed/distributed_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Optional

import torch
from torch import Tensor
Expand Down
2 changes: 0 additions & 2 deletions mamba_ssm/modules/mamba_simple.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

Expand Down
1 change: 0 additions & 1 deletion mamba_ssm/ops/triton/selective_state_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this
"""

import math
import torch
import torch.nn.functional as F

Expand Down
2 changes: 0 additions & 2 deletions mamba_ssm/ops/triton/ssd_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat


def init_to_zero(names):
Expand Down
4 changes: 0 additions & 4 deletions mamba_ssm/ops/triton/ssd_combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
"""We want triton==2.1.0 or 2.2.0 for this
"""

from typing import Optional

import math
from packaging import version

import torch
import torch.nn.functional as F
from torch import Tensor
from mamba_ssm.utils.torch import custom_bwd, custom_fwd

import triton
Expand All @@ -30,7 +28,6 @@
from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd
Expand All @@ -39,7 +36,6 @@
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev
from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd
from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd

Expand Down
3 changes: 1 addition & 2 deletions mamba_ssm/ops/triton/ssd_state_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
"""We want triton==2.1.0 or 2.2.0 for this
"""

import math
import torch
import torch.nn.functional as F

import triton
import triton.language as tl

from einops import rearrange, repeat
from einops import rearrange


@triton.autotune(
Expand Down
8 changes: 1 addition & 7 deletions mamba_ssm/utils/generation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


Expand Down
1 change: 0 additions & 1 deletion mamba_ssm/utils/torch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from functools import partial
from typing import Callable

def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
Expand Down
3 changes: 0 additions & 3 deletions tests/ops/test_selective_scan.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
# Copyright (C) 2023, Tri Dao.

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref
Expand Down
3 changes: 1 addition & 2 deletions tests/ops/triton/test_layernorm_gated.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import math

import torch
import torch.nn.functional as F

import pytest

from einops import rearrange, repeat
from einops import rearrange

from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref

Expand Down
4 changes: 1 addition & 3 deletions tests/ops/triton/test_selective_state_update.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# Copyright (C) 2023, Tri Dao.

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange, repeat
from einops import repeat

from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref

Expand Down
9 changes: 2 additions & 7 deletions tests/ops/triton/test_ssd.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
import math

import torch
import torch.nn.functional as F

import pytest

from einops import rearrange, repeat
from einops import rearrange

from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state
from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen
from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref
from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd
from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref


def detach_clone(*args):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.utils.generation import InferenceParams

import pytest

from einops import rearrange, repeat
from einops import rearrange


def test_generation():
Expand Down