diff --git a/benchmarks/benchmark_generation_mamba_simple.py b/benchmarks/benchmark_generation_mamba_simple.py index f3513b24..fbd4ac4f 100644 --- a/benchmarks/benchmark_generation_mamba_simple.py +++ b/benchmarks/benchmark_generation_mamba_simple.py @@ -2,12 +2,9 @@ import argparse import time -import json import torch -import torch.nn.functional as F -from einops import rearrange from transformers import AutoTokenizer, AutoModelForCausalLM diff --git a/mamba_ssm/distributed/distributed_utils.py b/mamba_ssm/distributed/distributed_utils.py index 74c55279..1ad79158 100644 --- a/mamba_ssm/distributed/distributed_utils.py +++ b/mamba_ssm/distributed/distributed_utils.py @@ -1,4 +1,3 @@ -from typing import Optional import torch from torch import Tensor diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 4c8a3882..1a5246b5 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -1,12 +1,10 @@ # Copyright (c) 2023, Tri Dao, Albert Gu. import math -from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F -from torch import Tensor from einops import rearrange, repeat diff --git a/mamba_ssm/ops/triton/selective_state_update.py b/mamba_ssm/ops/triton/selective_state_update.py index d425bc72..34fa54b8 100644 --- a/mamba_ssm/ops/triton/selective_state_update.py +++ b/mamba_ssm/ops/triton/selective_state_update.py @@ -3,7 +3,6 @@ """We want triton==2.1.0 or triton==2.2.0 or triton==2.3.0 for this """ -import math import torch import torch.nn.functional as F diff --git a/mamba_ssm/ops/triton/ssd_bmm.py b/mamba_ssm/ops/triton/ssd_bmm.py index 48fd4f06..d549f83c 100644 --- a/mamba_ssm/ops/triton/ssd_bmm.py +++ b/mamba_ssm/ops/triton/ssd_bmm.py @@ -5,12 +5,10 @@ import math import torch -import torch.nn.functional as F import triton import triton.language as tl -from einops import rearrange, repeat def init_to_zero(names): diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index bbf4ecf8..bd1fc950 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -3,14 +3,12 @@ """We want triton==2.1.0 or 2.2.0 for this """ -from typing import Optional import math from packaging import version import torch import torch.nn.functional as F -from torch import Tensor from mamba_ssm.utils.torch import custom_bwd, custom_fwd import triton @@ -30,7 +28,6 @@ from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_cumsum_bwd from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_fwd, _chunk_state_bwd_db -from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_state_bwd_ddAcs_stable from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd, _state_passing_bwd @@ -39,7 +36,6 @@ from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_dC, _chunk_scan_bwd_dcb from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_stable from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref -from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_bwd_ddAcs_prev from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn, _layer_norm_fwd, _layer_norm_bwd from mamba_ssm.ops.triton.k_activations import _swiglu_fwd, _swiglu_bwd diff --git a/mamba_ssm/ops/triton/ssd_state_passing.py b/mamba_ssm/ops/triton/ssd_state_passing.py index 63863b82..98e6e98d 100644 --- a/mamba_ssm/ops/triton/ssd_state_passing.py +++ b/mamba_ssm/ops/triton/ssd_state_passing.py @@ -3,14 +3,13 @@ """We want triton==2.1.0 or 2.2.0 for this """ -import math import torch import torch.nn.functional as F import triton import triton.language as tl -from einops import rearrange, repeat +from einops import rearrange @triton.autotune( diff --git a/mamba_ssm/utils/generation.py b/mamba_ssm/utils/generation.py index 330672af..ba36b62a 100644 --- a/mamba_ssm/utils/generation.py +++ b/mamba_ssm/utils/generation.py @@ -1,16 +1,10 @@ # Copyright (c) 2023, Albert Gu, Tri Dao. import gc -import time -from collections import namedtuple from dataclasses import dataclass, field -from functools import partial -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Optional import torch -import torch.nn.functional as F -from einops import rearrange, repeat from torch import Tensor -from torch.profiler import ProfilerActivity, profile, record_function from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer diff --git a/mamba_ssm/utils/torch.py b/mamba_ssm/utils/torch.py index 37df47c8..0a20eca7 100644 --- a/mamba_ssm/utils/torch.py +++ b/mamba_ssm/utils/torch.py @@ -1,5 +1,4 @@ import torch -from functools import partial from typing import Callable def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): diff --git a/tests/ops/test_selective_scan.py b/tests/ops/test_selective_scan.py index 8a834b3c..42aea4f9 100644 --- a/tests/ops/test_selective_scan.py +++ b/tests/ops/test_selective_scan.py @@ -1,12 +1,9 @@ # Copyright (C) 2023, Tri Dao. -import math import torch -import torch.nn.functional as F import pytest -from einops import rearrange from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, mamba_inner_ref diff --git a/tests/ops/triton/test_layernorm_gated.py b/tests/ops/triton/test_layernorm_gated.py index de669e85..3cd18441 100644 --- a/tests/ops/triton/test_layernorm_gated.py +++ b/tests/ops/triton/test_layernorm_gated.py @@ -1,11 +1,10 @@ -import math import torch import torch.nn.functional as F import pytest -from einops import rearrange, repeat +from einops import rearrange from mamba_ssm.ops.triton.layernorm_gated import layernorm_fn, rms_norm_ref diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py index 55408c89..4c233e1e 100644 --- a/tests/ops/triton/test_selective_state_update.py +++ b/tests/ops/triton/test_selective_state_update.py @@ -1,12 +1,10 @@ # Copyright (C) 2023, Tri Dao. -import math import torch -import torch.nn.functional as F import pytest -from einops import rearrange, repeat +from einops import repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref diff --git a/tests/ops/triton/test_ssd.py b/tests/ops/triton/test_ssd.py index d45152d6..e1bac114 100644 --- a/tests/ops/triton/test_ssd.py +++ b/tests/ops/triton/test_ssd.py @@ -1,20 +1,15 @@ -import math import torch import torch.nn.functional as F import pytest -from einops import rearrange, repeat +from einops import rearrange -from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state, chunk_state_ref +from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state from mamba_ssm.ops.triton.ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd from mamba_ssm.ops.triton.ssd_chunk_state import chunk_state_varlen -from mamba_ssm.ops.triton.ssd_state_passing import state_passing, state_passing_ref from mamba_ssm.ops.triton.ssd_state_passing import _state_passing_fwd -from mamba_ssm.ops.triton.ssd_chunk_scan import chunk_scan, chunk_scan_ref -from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_chunk_scan, ssd_chunk_scan_combined_ref, ssd_selective_scan -from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined, mamba_split_conv1d_scan_ref def detach_clone(*args): diff --git a/tests/test_generation.py b/tests/test_generation.py index 77e1aedf..7daa5533 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,9 +5,8 @@ from mamba_ssm.models.config_mamba import MambaConfig from mamba_ssm.utils.generation import InferenceParams -import pytest -from einops import rearrange, repeat +from einops import rearrange def test_generation():