Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -18478,24 +18478,25 @@ def Torch_GmlQuantizeAffineOp : Torch_Op<"gml.quantize_affine", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `gml::quantize_affine : (Tensor, int[], Tensor, Tensor, int) -> (Tensor)`";
let summary = "Generated op for `gml::quantize_affine : (Tensor, Tensor, Tensor, int, int?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchListOfTorchIntType:$block_size,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
Torch_IntType:$output_dtype
Torch_IntType:$output_dtype,
AnyTorchOptionalIntType:$axis,
AnyTorchOptionalIntType:$block_replication
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult GmlQuantizeAffineOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
return parseDefaultTorchOp(parser, result, 6, 1);
}
void GmlQuantizeAffineOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
printDefaultTorchOp(printer, *this, 6, 1);
}
}];
}
Expand All @@ -18505,25 +18506,26 @@ def Torch_GmlDequantizeAffineOp : Torch_Op<"gml.dequantize_affine", [
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `gml::dequantize_affine : (Tensor, int[], Tensor, Tensor, int, int) -> (Tensor)`";
let summary = "Generated op for `gml::dequantize_affine : (Tensor, Tensor, Tensor, int, int, int?, int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchListOfTorchIntType:$block_size,
AnyTorchTensorType:$scale,
AnyTorchTensorType:$zero_point,
Torch_IntType:$input_dtype,
Torch_IntType:$output_dtype
Torch_IntType:$output_dtype,
AnyTorchOptionalIntType:$axis,
AnyTorchOptionalIntType:$block_replication
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult GmlDequantizeAffineOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 6, 1);
return parseDefaultTorchOp(parser, result, 7, 1);
}
void GmlDequantizeAffineOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 6, 1);
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
Expand Down
12 changes: 6 additions & 6 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16170,17 +16170,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.gml.quantize_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.int) -> !torch.list<int> {\n"
" func.func @\"__torch_mlir_shape_fn.gml.quantize_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.optional<int>, %arg5: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.gml.quantize_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.int) -> !torch.int {\n"
" return %arg4 : !torch.int\n"
" func.func @\"__torch_mlir_dtype_fn.gml.quantize_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.optional<int>, %arg5: !torch.optional<int>) -> !torch.int {\n"
" return %arg3 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.gml.dequantize_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list<int> {\n"
" func.func @\"__torch_mlir_shape_fn.gml.dequantize_affine\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.optional<int>, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.gml.dequantize_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n"
" return %arg5 : !torch.int\n"
" func.func @\"__torch_mlir_dtype_fn.gml.dequantize_affine\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.optional<int>, %arg6: !torch.optional<int>) -> !torch.int {\n"
" return %arg4 : !torch.int\n"
" }\n"
"}\n"
"";
Expand Down
22 changes: 10 additions & 12 deletions projects/pt1/python/torch_mlir/gml_ops/gml_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,50 +45,48 @@ def _fused_moe_cpu(

# Signature:
# input: Tensor (float32, float16, or bfloat16)
# block_size: List[int] - granularity of quantization
# scale: Tensor - quantization scale parameter(s)
# zero_point: Tensor - quantization zero point parameter(s)
# output_dtype: int - requested dtype (e.g. torch.uint8)
# axis: int? - axis along which to quantize (optional)
# block_replication: int? - number of elements per block (optional)
_lib.define(
"quantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, int output_dtype) -> Tensor"
"quantize_affine(Tensor input, Tensor scale, Tensor zero_point, int output_dtype, int? axis, int? block_replication) -> Tensor"
)


@impl("gml::quantize_affine", "Meta")
def _quantize_affine_meta(input, block_size, scale, zero_point, output_dtype):
def _quantize_affine_meta(input, scale, zero_point, output_dtype, axis, block_replication):
# Output has same shape as input but with output_dtype
return torch.empty_like(input, dtype=output_dtype, device="meta")


@impl("gml::quantize_affine", "CPU")
def _quantize_affine_cpu(input, block_size, scale, zero_point, output_dtype):
def _quantize_affine_cpu(input, scale, zero_point, output_dtype, axis, block_replication):
# CPU stub for safety in case CPU is used; maintain shape with output_dtype.
return torch.empty_like(input, dtype=output_dtype)


# Signature:
# input: Tensor (quantized tensor)
# block_size: List[int] - granularity of quantization
# scale: Tensor - quantization scale parameter(s)
# zero_point: Tensor - quantization zero point parameter(s)
# input_dtype: int - dtype of input tensor
# output_dtype: int - desired output dtype (default fp32)
# axis: int? - axis along which to dequantize (optional)
# block_replication: int? - number of elements per block (optional)
_lib.define(
"dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, int input_dtype, int output_dtype) -> Tensor"
"dequantize_affine(Tensor input, Tensor scale, Tensor zero_point, int input_dtype, int output_dtype, int? axis, int? block_replication) -> Tensor"
)


@impl("gml::dequantize_affine", "Meta")
def _dequantize_affine_meta(
input, block_size, scale, zero_point, input_dtype, output_dtype
):
def _dequantize_affine_meta(input, scale, zero_point, input_dtype, output_dtype, axis, block_replication):
# Output has same shape as input but with output_dtype (fp32/fp16/bf16)
return torch.empty_like(input, dtype=output_dtype, device="meta")


@impl("gml::dequantize_affine", "CPU")
def _dequantize_affine_cpu(
input, block_size, scale, zero_point, input_dtype, output_dtype
):
def _dequantize_affine_cpu(input, scale, zero_point, input_dtype, output_dtype, axis, block_replication):
# CPU stub for safety in case CPU is used; maintain shape with output_dtype.
return torch.empty_like(input, dtype=output_dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -5938,36 +5938,48 @@ def gml〇fused_moe〡dtype(input_rank_dtype: Tuple[int, int], gate_proj_rank_dt
# Per-tensor quantization
Invocation(
TensorOfShape(128, 512, dtype=torch.float32), # input
[128, 512], # block_size (same as input for per-tensor)
TensorOfShape(1, dtype=torch.float32), # scale
TensorOfShape(1, dtype=torch.int32), # zero_point
TensorOfShape(1, dtype=torch.float32), # zero_point
torch.int8, # output_dtype
0, # axis
512, # block_replication
),
# Per-channel quantization
Invocation(
TensorOfShape(256, 1024, dtype=torch.float16), # input
[1, 1024], # block_size (per-channel)
TensorOfShape(256, dtype=torch.float32), # scale
TensorOfShape(256, dtype=torch.int32), # zero_point
TensorOfShape(256, dtype=torch.float16), # zero_point
torch.int8, # output_dtype
1, # axis
1024, # block_replication
),
# Block-wise quantization
Invocation(
TensorOfShape(1024, 512, dtype=torch.bfloat16), # input
[128, 128], # block_size
TensorOfShape(8, 4, dtype=torch.float32), # scale (1024/128=8, 512/128=4)
TensorOfShape(8, 4, dtype=torch.int32), # zero_point
TensorOfShape(8, 4, dtype=torch.float32), # scale
TensorOfShape(8, 4, dtype=torch.bfloat16), # zero_point
torch.float8_e4m3fn, # output_dtype
1, # axis
2048, # block_replication
),
# Quantization with optional params as None
Invocation(
TensorOfShape(64, 256, dtype=torch.float32), # input
TensorOfShape(1, dtype=torch.float32), # scale
TensorOfShape(1, dtype=torch.float32), # zero_point
torch.int8, # output_dtype
None, # axis (optional)
None, # block_replication (optional)
),
]

@check_shape_function(GML_QUANTIZE_AFFINE_TESTS)
def gml〇quantize_affine〡shape(input: List[int], block_size: List[int], scale: List[int], zero_point: List[int], output_dtype: int) -> List[int]:
def gml〇quantize_affine〡shape(input: List[int], scale: List[int], zero_point: List[int], output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> List[int]:
# Quantize output has same shape as input
return input

@check_dtype_function(GML_QUANTIZE_AFFINE_TESTS)
def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size: List[int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], output_dtype: int) -> int:
def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> int:
# Output dtype is determined by output_dtype parameter
return output_dtype

Expand All @@ -5976,39 +5988,52 @@ def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size:
# Per-tensor dequantization
Invocation(
TensorOfShape(128, 512, dtype=torch.int8), # input
[128, 512], # block_size (same as input for per-tensor)
TensorOfShape(1, dtype=torch.float32), # scale
TensorOfShape(1, dtype=torch.int32), # zero_point
TensorOfShape(1, dtype=torch.float32), # zero_point
torch.int8, # input_dtype
torch.float32, # output_dtype
0, # axis
512, # block_replication
),
# Per-channel dequantization
Invocation(
TensorOfShape(256, 1024, dtype=torch.int8), # input
[1, 1024], # block_size (per-channel)
TensorOfShape(256, dtype=torch.float32), # scale
TensorOfShape(256, dtype=torch.int32), # zero_point
TensorOfShape(256, dtype=torch.float16), # zero_point
torch.int8, # input_dtype
torch.float16, # output_dtype
1, # axis
1024, # block_replication
),
# Block-wise dequantization
Invocation(
TensorOfShape(1024, 512, dtype=torch.uint8), # input
[128, 128], # block_size
TensorOfShape(8, 4, dtype=torch.float32), # scale
TensorOfShape(8, 4, dtype=torch.int32), # zero_point
TensorOfShape(8, 4, dtype=torch.bfloat16), # zero_point
torch.uint8, # input_dtype
torch.bfloat16, # output_dtype
1, # axis
2048, # block_replication
),
# Dequantization with optional params as None
Invocation(
TensorOfShape(64, 256, dtype=torch.int8), # input
TensorOfShape(1, dtype=torch.float32), # scale
TensorOfShape(1, dtype=torch.float32), # zero_point
torch.int8, # input_dtype
torch.float32, # output_dtype
None, # axis (optional)
None, # block_replication (optional)
),
]

@check_shape_function(GML_DEQUANTIZE_AFFINE_TESTS)
def gml〇dequantize_affine〡shape(input: List[int], block_size: List[int], scale: List[int], zero_point: List[int], input_dtype: int, output_dtype: int) -> List[int]:
def gml〇dequantize_affine〡shape(input: List[int], scale: List[int], zero_point: List[int], input_dtype: int, output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> List[int]:
# Dequantize output has same shape as input
return input

@check_dtype_function(GML_DEQUANTIZE_AFFINE_TESTS)
def gml〇dequantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size: List[int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], input_dtype: int, output_dtype: int) -> int:
def gml〇dequantize_affine〡dtype(input_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], input_dtype: int, output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> int:
# Output dtype is determined by output_dtype parameter
return output_dtype

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1306,9 +1306,11 @@ def emit_with_mutating_variants(key, **kwargs):
emit(
"gml::fused_moe : (Tensor, Tensor[], Tensor[], Tensor[], Tensor, Tensor, str) -> (Tensor)"
)
emit("gml::quantize_affine : (Tensor, int[], Tensor, Tensor, int) -> (Tensor)")
emit(
"gml::dequantize_affine : (Tensor, int[], Tensor, Tensor, int, int) -> (Tensor)"
"gml::quantize_affine : (Tensor, Tensor, Tensor, int, int?, int?) -> (Tensor)"
)
emit(
"gml::dequantize_affine : (Tensor, Tensor, Tensor, int, int, int?, int?) -> (Tensor)"
)


Expand Down
Loading