diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f9d3ce0a6b2a..4d5b5e6947e3 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -18478,13 +18478,14 @@ def Torch_GmlQuantizeAffineOp : Torch_Op<"gml.quantize_affine", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `gml::quantize_affine : (Tensor, int[], Tensor, Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `gml::quantize_affine : (Tensor, Tensor, Tensor, int, int?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$input, - AnyTorchListOfTorchIntType:$block_size, AnyTorchTensorType:$scale, AnyTorchTensorType:$zero_point, - Torch_IntType:$output_dtype + Torch_IntType:$output_dtype, + AnyTorchOptionalIntType:$axis, + AnyTorchOptionalIntType:$block_replication ); let results = (outs AnyTorchOptionalTensorType:$result @@ -18492,10 +18493,10 @@ def Torch_GmlQuantizeAffineOp : Torch_Op<"gml.quantize_affine", [ let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult GmlQuantizeAffineOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 5, 1); + return parseDefaultTorchOp(parser, result, 6, 1); } void GmlQuantizeAffineOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 5, 1); + printDefaultTorchOp(printer, *this, 6, 1); } }]; } @@ -18505,14 +18506,15 @@ def Torch_GmlDequantizeAffineOp : Torch_Op<"gml.dequantize_affine", [ HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `gml::dequantize_affine : (Tensor, int[], Tensor, Tensor, int, int) -> (Tensor)`"; + let summary = "Generated op for `gml::dequantize_affine : (Tensor, Tensor, Tensor, int, int, int?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$input, - AnyTorchListOfTorchIntType:$block_size, AnyTorchTensorType:$scale, AnyTorchTensorType:$zero_point, Torch_IntType:$input_dtype, - Torch_IntType:$output_dtype + Torch_IntType:$output_dtype, + AnyTorchOptionalIntType:$axis, + AnyTorchOptionalIntType:$block_replication ); let results = (outs AnyTorchOptionalTensorType:$result @@ -18520,10 +18522,10 @@ def Torch_GmlDequantizeAffineOp : Torch_Op<"gml.dequantize_affine", [ let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ ParseResult GmlDequantizeAffineOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 1); + return parseDefaultTorchOp(parser, result, 7, 1); } void GmlDequantizeAffineOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 1); + printDefaultTorchOp(printer, *this, 7, 1); } }]; } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bd9bf678aade..13cd0598c7b1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -16170,17 +16170,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.gml.quantize_affine\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.int) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.gml.quantize_affine\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.gml.quantize_affine\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.int) -> !torch.int {\n" -" return %arg4 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.gml.quantize_affine\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" +" return %arg3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.gml.dequantize_affine\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.int, %arg5: !torch.int) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.gml.dequantize_affine\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.gml.dequantize_affine\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.int, %arg5: !torch.int) -> !torch.int {\n" -" return %arg5 : !torch.int\n" +" func.func @\"__torch_mlir_dtype_fn.gml.dequantize_affine\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" return %arg4 : !torch.int\n" " }\n" "}\n" ""; diff --git a/projects/pt1/python/torch_mlir/gml_ops/gml_ops.py b/projects/pt1/python/torch_mlir/gml_ops/gml_ops.py index 7d88a14615fa..1a98a6365e70 100644 --- a/projects/pt1/python/torch_mlir/gml_ops/gml_ops.py +++ b/projects/pt1/python/torch_mlir/gml_ops/gml_ops.py @@ -45,50 +45,48 @@ def _fused_moe_cpu( # Signature: # input: Tensor (float32, float16, or bfloat16) -# block_size: List[int] - granularity of quantization # scale: Tensor - quantization scale parameter(s) # zero_point: Tensor - quantization zero point parameter(s) # output_dtype: int - requested dtype (e.g. torch.uint8) +# axis: int? - axis along which to quantize (optional) +# block_replication: int? - number of elements per block (optional) _lib.define( - "quantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, int output_dtype) -> Tensor" + "quantize_affine(Tensor input, Tensor scale, Tensor zero_point, int output_dtype, int? axis, int? block_replication) -> Tensor" ) @impl("gml::quantize_affine", "Meta") -def _quantize_affine_meta(input, block_size, scale, zero_point, output_dtype): +def _quantize_affine_meta(input, scale, zero_point, output_dtype, axis, block_replication): # Output has same shape as input but with output_dtype return torch.empty_like(input, dtype=output_dtype, device="meta") @impl("gml::quantize_affine", "CPU") -def _quantize_affine_cpu(input, block_size, scale, zero_point, output_dtype): +def _quantize_affine_cpu(input, scale, zero_point, output_dtype, axis, block_replication): # CPU stub for safety in case CPU is used; maintain shape with output_dtype. return torch.empty_like(input, dtype=output_dtype) # Signature: # input: Tensor (quantized tensor) -# block_size: List[int] - granularity of quantization # scale: Tensor - quantization scale parameter(s) # zero_point: Tensor - quantization zero point parameter(s) # input_dtype: int - dtype of input tensor # output_dtype: int - desired output dtype (default fp32) +# axis: int? - axis along which to dequantize (optional) +# block_replication: int? - number of elements per block (optional) _lib.define( - "dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, int input_dtype, int output_dtype) -> Tensor" + "dequantize_affine(Tensor input, Tensor scale, Tensor zero_point, int input_dtype, int output_dtype, int? axis, int? block_replication) -> Tensor" ) @impl("gml::dequantize_affine", "Meta") -def _dequantize_affine_meta( - input, block_size, scale, zero_point, input_dtype, output_dtype -): +def _dequantize_affine_meta(input, scale, zero_point, input_dtype, output_dtype, axis, block_replication): # Output has same shape as input but with output_dtype (fp32/fp16/bf16) return torch.empty_like(input, dtype=output_dtype, device="meta") @impl("gml::dequantize_affine", "CPU") -def _dequantize_affine_cpu( - input, block_size, scale, zero_point, input_dtype, output_dtype -): +def _dequantize_affine_cpu(input, scale, zero_point, input_dtype, output_dtype, axis, block_replication): # CPU stub for safety in case CPU is used; maintain shape with output_dtype. return torch.empty_like(input, dtype=output_dtype) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 791a21995200..c357eed9f508 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -5938,36 +5938,48 @@ def gml〇fused_moe〡dtype(input_rank_dtype: Tuple[int, int], gate_proj_rank_dt # Per-tensor quantization Invocation( TensorOfShape(128, 512, dtype=torch.float32), # input - [128, 512], # block_size (same as input for per-tensor) TensorOfShape(1, dtype=torch.float32), # scale - TensorOfShape(1, dtype=torch.int32), # zero_point + TensorOfShape(1, dtype=torch.float32), # zero_point torch.int8, # output_dtype + 0, # axis + 512, # block_replication ), # Per-channel quantization Invocation( TensorOfShape(256, 1024, dtype=torch.float16), # input - [1, 1024], # block_size (per-channel) TensorOfShape(256, dtype=torch.float32), # scale - TensorOfShape(256, dtype=torch.int32), # zero_point + TensorOfShape(256, dtype=torch.float16), # zero_point torch.int8, # output_dtype + 1, # axis + 1024, # block_replication ), # Block-wise quantization Invocation( TensorOfShape(1024, 512, dtype=torch.bfloat16), # input - [128, 128], # block_size - TensorOfShape(8, 4, dtype=torch.float32), # scale (1024/128=8, 512/128=4) - TensorOfShape(8, 4, dtype=torch.int32), # zero_point + TensorOfShape(8, 4, dtype=torch.float32), # scale + TensorOfShape(8, 4, dtype=torch.bfloat16), # zero_point torch.float8_e4m3fn, # output_dtype + 1, # axis + 2048, # block_replication + ), + # Quantization with optional params as None + Invocation( + TensorOfShape(64, 256, dtype=torch.float32), # input + TensorOfShape(1, dtype=torch.float32), # scale + TensorOfShape(1, dtype=torch.float32), # zero_point + torch.int8, # output_dtype + None, # axis (optional) + None, # block_replication (optional) ), ] @check_shape_function(GML_QUANTIZE_AFFINE_TESTS) -def gml〇quantize_affine〡shape(input: List[int], block_size: List[int], scale: List[int], zero_point: List[int], output_dtype: int) -> List[int]: +def gml〇quantize_affine〡shape(input: List[int], scale: List[int], zero_point: List[int], output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> List[int]: # Quantize output has same shape as input return input @check_dtype_function(GML_QUANTIZE_AFFINE_TESTS) -def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size: List[int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], output_dtype: int) -> int: +def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> int: # Output dtype is determined by output_dtype parameter return output_dtype @@ -5976,39 +5988,52 @@ def gml〇quantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size: # Per-tensor dequantization Invocation( TensorOfShape(128, 512, dtype=torch.int8), # input - [128, 512], # block_size (same as input for per-tensor) TensorOfShape(1, dtype=torch.float32), # scale - TensorOfShape(1, dtype=torch.int32), # zero_point + TensorOfShape(1, dtype=torch.float32), # zero_point torch.int8, # input_dtype torch.float32, # output_dtype + 0, # axis + 512, # block_replication ), # Per-channel dequantization Invocation( TensorOfShape(256, 1024, dtype=torch.int8), # input - [1, 1024], # block_size (per-channel) TensorOfShape(256, dtype=torch.float32), # scale - TensorOfShape(256, dtype=torch.int32), # zero_point + TensorOfShape(256, dtype=torch.float16), # zero_point torch.int8, # input_dtype torch.float16, # output_dtype + 1, # axis + 1024, # block_replication ), # Block-wise dequantization Invocation( TensorOfShape(1024, 512, dtype=torch.uint8), # input - [128, 128], # block_size TensorOfShape(8, 4, dtype=torch.float32), # scale - TensorOfShape(8, 4, dtype=torch.int32), # zero_point + TensorOfShape(8, 4, dtype=torch.bfloat16), # zero_point torch.uint8, # input_dtype torch.bfloat16, # output_dtype + 1, # axis + 2048, # block_replication + ), + # Dequantization with optional params as None + Invocation( + TensorOfShape(64, 256, dtype=torch.int8), # input + TensorOfShape(1, dtype=torch.float32), # scale + TensorOfShape(1, dtype=torch.float32), # zero_point + torch.int8, # input_dtype + torch.float32, # output_dtype + None, # axis (optional) + None, # block_replication (optional) ), ] @check_shape_function(GML_DEQUANTIZE_AFFINE_TESTS) -def gml〇dequantize_affine〡shape(input: List[int], block_size: List[int], scale: List[int], zero_point: List[int], input_dtype: int, output_dtype: int) -> List[int]: +def gml〇dequantize_affine〡shape(input: List[int], scale: List[int], zero_point: List[int], input_dtype: int, output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> List[int]: # Dequantize output has same shape as input return input @check_dtype_function(GML_DEQUANTIZE_AFFINE_TESTS) -def gml〇dequantize_affine〡dtype(input_rank_dtype: Tuple[int, int], block_size: List[int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], input_dtype: int, output_dtype: int) -> int: +def gml〇dequantize_affine〡dtype(input_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], input_dtype: int, output_dtype: int, axis: Optional[int], block_replication: Optional[int]) -> int: # Output dtype is determined by output_dtype parameter return output_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fea443e9443d..dd2dbb461001 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1306,9 +1306,11 @@ def emit_with_mutating_variants(key, **kwargs): emit( "gml::fused_moe : (Tensor, Tensor[], Tensor[], Tensor[], Tensor, Tensor, str) -> (Tensor)" ) - emit("gml::quantize_affine : (Tensor, int[], Tensor, Tensor, int) -> (Tensor)") emit( - "gml::dequantize_affine : (Tensor, int[], Tensor, Tensor, int, int) -> (Tensor)" + "gml::quantize_affine : (Tensor, Tensor, Tensor, int, int?, int?) -> (Tensor)" + ) + emit( + "gml::dequantize_affine : (Tensor, Tensor, Tensor, int, int, int?, int?) -> (Tensor)" )