Merged
Changes from 7 commits
Commits (52)
f30ed49
marlin
Jan 17, 2024
afd6547
Merge branch 'marlin' of github.com:neuralmagic/vllm into marlin
Jan 17, 2024
837d344
added marlin
Jan 17, 2024
7a43b29
trying to load packed weights turning out to be tricky
Jan 18, 2024
e034640
trying to load packed weights turning out to be tricky due to qkv
Jan 18, 2024
15e8f9c
integrated marlin for single gpu
Jan 18, 2024
d8286fb
Update llama.py
robertgshaw2-redhat Jan 19, 2024
8bc625f
Fixes to Marlin quantization to allow execution via CUDA graphs captu…
alexm-redhat Jan 19, 2024
2691e89
Integrate @efrantar's changes for CUDA graphs
alexm-redhat Jan 19, 2024
92f7290
review comments based on zhyncs
alexm-redhat Jan 19, 2024
bc10e4b
(1) Integrate the latest changes from Elias that improve large batch …
alexm-redhat Jan 30, 2024
47987da
add bug fix
alexm-redhat Jan 30, 2024
43aa818
refactored some of alex's work to be consistent with the gptq config
Feb 1, 2024
5906a60
updated to load model based on hf_config from AutoGPTQ
Feb 1, 2024
8dfeaa2
Reduce Marlin's kernel limitation of thread_n from 256 to 64 (to avoi…
alexm-redhat Feb 2, 2024
c7fb928
Update checks related to MarlinConfig
alexm-redhat Feb 2, 2024
1ea85f3
formatting
alexm-redhat Feb 2, 2024
c876b79
Merge branch 'main' into marlin
robertgshaw2-redhat Feb 7, 2024
a435c97
Update pybind.cpp
robertgshaw2-redhat Feb 7, 2024
90e8b8f
Update ops.h
robertgshaw2-redhat Feb 7, 2024
b03af7d
Update ops.h
robertgshaw2-redhat Feb 7, 2024
9192287
readded marlin
Feb 7, 2024
ce50dd4
Bug fix for determination of the scales size in marlin layer
alexm-redhat Feb 8, 2024
5a305d3
Ensure marlin only compiles for GPU compute capability >= 8.0
alexm-redhat Feb 8, 2024
b1773aa
fix marlin compilation again
alexm-redhat Feb 8, 2024
036e0ca
Merge branch 'vllm-project:main' into marlin
robertgshaw2-redhat Feb 9, 2024
d63627e
added marlin test
Feb 18, 2024
18981b1
added marlin test
Feb 18, 2024
828c621
updated skipping logic
Feb 18, 2024
4f1759b
updated skipping logic
Feb 18, 2024
f1714e9
added memory profiling
Feb 18, 2024
e3a4706
added memory profiling
Feb 18, 2024
efd886c
test wout memory utilization
Feb 18, 2024
70f5850
updating memory profiling
Feb 18, 2024
567fe38
adding more profiling
Feb 18, 2024
01f5e40
updating memory profiling
Feb 18, 2024
fc5310c
removed memory profiling
Feb 18, 2024
99ab19d
cleaned up
Feb 18, 2024
eabeea6
added newline
Feb 18, 2024
d064595
ran ./format.sh
Feb 18, 2024
721351e
Merge branch 'upstream-main' into marlin
Feb 18, 2024
9b1bc5f
merged into upstream main
Feb 18, 2024
013f10f
Update test_marlin.py
robertgshaw2-redhat Feb 18, 2024
7f2165e
Update test_marlin.py
robertgshaw2-redhat Feb 18, 2024
79081ff
Merge branch 'main' into marlin
robertgshaw2-redhat Feb 19, 2024
7a9b828
updated retry testing to use pytest-flaky rather than implementing th…
Feb 19, 2024
c23902f
missed newline
Feb 19, 2024
e7aba66
formatting
Feb 19, 2024
2403f7d
removed silly print
Feb 19, 2024
aabaed2
added license
Feb 29, 2024
a67dc8d
format
Feb 29, 2024
8ff42c0
minor change for ruff
Feb 29, 2024
7 changes: 7 additions & 0 deletions csrc/ops.h
@@ -70,6 +70,13 @@ torch::Tensor awq_gemm(
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);

void marlin_gemm(
const torch::Tensor& input,
const torch::Tensor& weights,
torch::Tensor& output,
const torch::Tensor& scales,
torch::Tensor& workspace);
#endif

void squeezellm_gemm(
3 changes: 2 additions & 1 deletion csrc/pybind.cpp
@@ -51,11 +51,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#ifndef USE_ROCM
// Quantization ops
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
#endif
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
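With the binding above, the kernel is reachable from Python as vllm._C.ops.marlin_gemm, taking the five arguments declared in csrc/ops.h (input, packed weights, output, scales, workspace). A minimal sketch of a call, assuming the extension is built with the Marlin kernel; shapes are illustrative and the uninitialized tensors only demonstrate the expected layouts, not a real quantized checkpoint:

import torch
from vllm._C import ops  # compiled CUDA extension built by setup.py below

K, N = 4096, 4096          # illustrative input/output feature sizes
TILE, PACK = 16, 32 // 4   # Marlin tile size and 4-bit pack factor (see marlin.py below)

x = torch.randn(8, K, dtype=torch.float16, device="cuda")
qweight = torch.empty(K // TILE, N * TILE // PACK, dtype=torch.int32, device="cuda")
scales = torch.empty(K // 128, N, dtype=torch.float16, device="cuda")  # group_size = 128
output = torch.empty(8, N, dtype=torch.float16, device="cuda")
workspace = torch.zeros(256, dtype=torch.int, device="cuda")  # MAX_SMS ints, placed on GPU here

# Positional call mirroring the declaration:
# marlin_gemm(input, weights, output, scales, workspace)
ops.marlin_gemm(x, qweight, output, scales, workspace)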
837 changes: 837 additions & 0 deletions csrc/quantization/marlin/marlin_cuda_kernel.cu

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions setup.py
@@ -226,6 +226,8 @@ def get_torch_arch_list() -> Set[str]:

if _is_cuda():
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
print("\n\n HERE \n\n")
vllm_extension_sources.append("csrc/quantization/marlin/marlin_cuda_kernel.cu")

vllm_extension = CUDAExtension(
name="vllm._C",
11 changes: 6 additions & 5 deletions vllm/config.py
@@ -144,8 +144,8 @@ def _verify_tokenizer_mode(self) -> None:
self.tokenizer_mode = tokenizer_mode

def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm"]
rocm_not_supported_quantization = ["awq"]
supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
rocm_not_supported_quantization = ["awq", "marlin"]
if self.quantization is not None:
self.quantization = self.quantization.lower()

@@ -172,9 +172,10 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"{self.quantization} quantization is currently not supported "
f"in ROCm.")
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")
if self.quantization != "marlin":
logger.warning(f"{self.quantization} quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.")

def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
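With "marlin" accepted by _verify_quantization, a Marlin-format checkpoint can be loaded through the normal vLLM entry point. A hedged usage sketch; the model id is a placeholder for any checkpoint serialized in Marlin format with a quantize_config.json:

from vllm import LLM, SamplingParams

# Placeholder model id; substitute any Marlin-serialized GPTQ checkpoint.
llm = LLM(model="org/llama-2-7b-marlin", quantization="marlin", dtype="float16")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)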
35 changes: 35 additions & 0 deletions vllm/model_executor/layers/linear.py
@@ -280,6 +280,14 @@ def weight_loader(self,
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to account
# for the tiling.
marlin_tile_size = getattr(param, "tile_size", None)
if marlin_tile_size is not None:
shard_size = shard_size * marlin_tile_size
shard_offset = shard_offset * marlin_tile_size

loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -297,6 +305,14 @@ def weight_loader(self,
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to account
# for the tiling.
marlin_tile_size = getattr(param, "tile_size", None)
if marlin_tile_size is not None:
shard_size = shard_size * marlin_tile_size
shard_offset = shard_offset * marlin_tile_size

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
@@ -376,7 +392,10 @@ def weight_loader(self,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)

if loaded_shard_id is None:
print("--------- HERE 2")

# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
@@ -397,6 +416,14 @@ def weight_loader(self,
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to account
# for the tiling.
marlin_tile_size = getattr(param, "tile_size", None)
if marlin_tile_size is not None:
shard_size = shard_size * marlin_tile_size
shard_offset = shard_offset * marlin_tile_size

loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -421,6 +448,14 @@ def weight_loader(self,
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor

# If marlin, we need to adjust the offset and size to account
# for the tiling
marlin_tile_size = getattr(param, "tile_size", None)
if marlin_tile_size is not None:
shard_size = shard_size * marlin_tile_size
shard_offset = shard_offset * marlin_tile_size

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
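The adjustment repeated in each hunk above follows from the Marlin weight layout created in marlin.py: the packed tensor has shape (K // 16, N * 16 // 8), so after an output-dim offset is divided by pack_factor it must be multiplied back by tile_size. A small worked sketch of that index arithmetic, with an assumed shard of 11008 output features:

PACK_FACTOR = 32 // 4  # eight 4-bit values per int32 (see marlin.py)
TILE_SIZE = 16         # Marlin tile size

shard_size, shard_offset = 11008, 11008  # assumed gate/up shard, in output features

# Packed output dim: count int32 columns instead of logical features.
shard_size //= PACK_FACTOR      # 1376
shard_offset //= PACK_FACTOR    # 1376

# Marlin tiling widens the packed dim by tile_size, so scale back up.
shard_size *= TILE_SIZE         # 22016
shard_offset *= TILE_SIZE       # 22016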
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -4,11 +4,13 @@
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig

_QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
}


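Registering MarlinConfig under the "marlin" key lets the existing name-based lookup resolve it. A minimal sketch using the registry dict shown above; the values printed come from this PR's MarlinConfig defaults:

from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

marlin_cls = _QUANTIZATION_CONFIG_REGISTRY["marlin"]
config = marlin_cls.from_config({"group_size": 128})  # as read from quantize_config.json
print(config.get_name(), config.pack_factor, config.tile_size)  # marlin 8 16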
188 changes: 188 additions & 0 deletions vllm/model_executor/layers/quantization/marlin.py
@@ -0,0 +1,188 @@
import numpy as np
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)

# Essentially all reasonable GPUs have fewer than 256 SMs, so this should be safe for now.
MAX_SMS = 256
# Tile size used by the Marlin kernels.
TILE_SIZE = 16
# Eight 4-bit values packed into a 32-bit dtype.
PACK_FACTOR = 32 // 4

class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""

def __init__(
self,
group_size: int,
) -> None:
self.group_size = group_size
# Eight 4-bit values packed into an int32.
self.pack_factor = 32 // 4
# Tile size of 16 used by Marlin.
self.tile_size = 16

# todo(rib-2): add channelwise support (-1).
if self.group_size != 128:
raise ValueError(
"Currently, only group size 128 is supported for Marlin "
f"but got {self.group_size} bits.")

def __repr__(self) -> str:
return (f"MarlinConfig(group_size={self.group_size}")

@classmethod
def get_name(cls) -> str:
return "marlin"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]

@classmethod
# Need to figure it out
Review comment (Collaborator), suggested change: remove the "# Need to figure it out" comment.
def get_min_capability(cls) -> int:
return 60

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)

def get_linear_method(self) -> "MarlinLinearMethod":
return MarlinLinearMethod(self)

def get_scaled_act_names(self) -> List[str]:
return []

class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""

def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
self._perm_len = 1024

def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
del output_size # Unused.
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if input_size_per_partition % 128 != 0:
raise ValueError(
"The input_size_per_partition must be divisible by 128, "
f"but got {input_size_per_partition}")

if output_size_per_partition % 256 != 0:
raise ValueError(
"The output_size_per_partition must be divisible by 256, "
f"but got {output_size_per_partition}")

# check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self._perm_len // (self.quant_config.tile_size ** 2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu"
)

# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"tile_size": TILE_SIZE,
})

# Scales in Float16.
scales = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": None if input_size == input_size_per_partition else 0,
"output_dim": 1,
})

# Workspace for the marlin kernels.
self.workspace = torch.empty(MAX_SMS, dtype=torch.int)

return {
"B": qweight,
"s": scales,
}

def apply_weights(self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["B"]
scales = weights["s"]

output = torch.empty(
x.shape[:-1] + (scales.shape[1],),
dtype=x.dtype,
device=x.device
)
ops.marlin_gemm(
x.view(-1, x.shape[-1]),
qweight,
output.view(-1, output.shape[-1]),
scales,
self.workspace
)

if bias is not None:
output = output + bias
return output
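For concreteness, the shapes create_weights produces for an illustrative 4096x4096 projection on a single GPU (so input_size == input_size_per_partition) work out as in this sketch; the dimensions are assumptions for illustration, not tied to a specific model:

TILE_SIZE = 16
PACK_FACTOR = 32 // 4
GROUP_SIZE = 128

K, N = 4096, 4096  # input_size_per_partition, output_size_per_partition

qweight_shape = (K // TILE_SIZE, N * TILE_SIZE // PACK_FACTOR)  # (256, 8192) int32 weight "B"
scales_shape = (K // GROUP_SIZE, N)                             # (32, 4096) float16 scales "s"

# apply_weights: x of shape (num_tokens, 4096) float16 -> output of shape (num_tokens, 4096)
assert qweight_shape == (256, 8192) and scales_shape == (32, 4096)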