diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -62,6 +62,98 @@ - !ScalarExpression scalar_arg: B --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: quantized_matmul + cpp_class_name: QuantizedMatmulOp + doc: |- + Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)> + - !LinalgOperandDefConfig + name: B + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2] -> (s1, s2)> + - !LinalgOperandDefConfig + name: AZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: BZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: C + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> ()> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> ()> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + iterator_types: + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: sub +
operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: AZp + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: B + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: BZp +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d cpp_class_name: Mmt4DOp @@ -198,6 +290,99 @@ - !ScalarExpression scalar_arg: B --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: quantized_batch_matmul + cpp_class_name: QuantizedBatchMatmulOp + doc: |- + Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul.
+structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: A + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)> + - !LinalgOperandDefConfig + name: B + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s2, s3)> + - !LinalgOperandDefConfig + name: AZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: BZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: C + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> ()> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> ()> + - affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: AZp + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: B + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: BZp +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec cpp_class_name: MatvecOp diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1019,9 +1019,24 @@ loc, outputTy.getShape(), outputTy.getElementType()); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); - rewriter.replaceOpWithNewOp( - op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, - ValueRange{zeroTensor}); + if (!op.quantization_info()) { + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b()}, + ValueRange{zeroTensor}); + return success(); + } + + auto quantizationInfo = op.quantization_info().getValue(); + auto aZp = rewriter.create( + loc, rewriter.getI32IntegerAttr( + quantizationInfo.a_zp().getValue().getSExtValue())); + auto bZp = rewriter.create( + loc, rewriter.getI32IntegerAttr( + quantizationInfo.b_zp().getValue().getSExtValue())); + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, + ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); + return success(); } }; @@ -1040,13 +1055,8 @@ auto bias = op.bias(); auto weightTy = weight.getType().cast(); - auto biasTy = bias.getType().cast(); - auto weightShape = weightTy.getShape(); - if (op.quantization_info()) - return failure(); - // Creating maps for the output of MatMul and the bias SmallVector indexingMaps; @@ -1081,14 +1091,29 @@ SmallVector newWeightShape{weightShape[1], weightShape[0]}; Type newWeightTy = - RankedTensorType::get(newWeightShape, biasTy.getElementType()); + RankedTensorType::get(newWeightShape, weightTy.getElementType()); Value transposedWeight = rewriter.create( loc, newWeightTy, weight, permutationValue); - rewriter.replaceOpWithNewOp( - op, TypeRange{op.getType()}, ValueRange{input, transposedWeight}, - linalgOp); + if (!op.quantization_info()) { + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, ValueRange{input, transposedWeight}, + linalgOp); + return success(); + } + + auto
quantizationInfo = op.quantization_info().getValue(); + auto inputZp = rewriter.create( + loc, rewriter.getI32IntegerAttr( + quantizationInfo.input_zp().getValue().getSExtValue())); + auto weightZp = rewriter.create( + loc, rewriter.getI32IntegerAttr( + quantizationInfo.weight_zp().getValue().getSExtValue())); + rewriter.replaceOpWithNewOp( + op, TypeRange{op.getType()}, + ValueRange{input, transposedWeight, inputZp, weightZp}, linalgOp); + return success(); } }; diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -20,6 +20,22 @@ implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul.
+ """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.k, D.n]) - cast(U, BZp)) @linalg_structured_op def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), @@ -40,7 +56,6 @@ implements(ContractionOpInterface) accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) - @linalg_structured_op def batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), @@ -55,6 +70,23 @@ implements(ContractionOpInterface) C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) +@linalg_structured_op +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul.
+ """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) + @linalg_structured_op def matvec( diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -855,6 +855,21 @@ // ----- + // CHECK-LABEL: @matmul_quantized +func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) { + // CHECK: [[C0:%.+]] = constant 0 + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 5, 6] + // CHECK: [[FILLED:%.+]] = linalg.fill([[C0]], [[INIT]]) : i32, tensor<1x5x6xi32> -> tensor<1x5x6xi32> + // CHECK: [[ONE:%.+]] = constant 1 + // CHECK: [[TWO:%.+]] = constant 2 + // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32> + %0 = "tosa.matmul"(%arg0, %arg1) {quantization_info = {a_zp = 1 : i32, b_zp = 2 : i32}} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) + return %0 : tensor<1x5x6xi32> +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)> @@ -876,6 +891,29 @@ // ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d1, d0)> + +// CHECK-LABEL: @quantized_fully_connected +func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) { + // CHECK: [[INITB:%.+]] = linalg.init_tensor [5, 6] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 :
tensor<6xi32>) outs([[INITB]] : tensor<5x6xi32>) { + // CHECK: ^bb0([[IN:%.+]]: i32, [[UNUSED:%.+]]: i32): + // CHECK: linalg.yield [[IN]] : i32 + // CHECK: [[INITT:%.+]] = linalg.init_tensor [3, 6] + // CHECK: [[TRANSPOSE:%.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg1 : tensor<6x3xi8>) outs([[INITT]] + // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8): + // CHECK: linalg.yield [[IN]] : i8 + // CHECK: [[ONE:%.+]] = constant 1 + // CHECK: [[TWO:%.+]] = constant 2 + // CHECK: linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[GENERIC]] : tensor<5x6xi32>) -> tensor<5x6xi32> + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = 1:i32, weight_zp = 2:i32}} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> (tensor<5x6xi32>) + return %0 : tensor<5x6xi32> +} + +// ----- + func @pad_float(%arg0 : tensor<1x2xf32>) -> (tensor<4x9xf32>) { %0 = constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32> // TODO: Output contains multiple "constant 1 : index".