diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -709,6 +709,121 @@ - !ScalarExpression scalar_arg: K --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: conv_2d_input_nhwc_filter_ohwi_poly_q + cpp_class_name: Conv2DInputNhwcFilterOhwiPolyQOp + doc: |- + Performs a 2-D quantized convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Includes zero point + adjustment for quantization. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + usage: InputOperand + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12] + -> (s0, s1, s2, s3)> + - !LinalgOperandDefConfig + name: K + usage: InputOperand + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12] + -> (s4, s5, s6, s3)> + - !LinalgOperandDefConfig + name: IZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: KZp + usage: InputOperand + type_var: I32 + - !LinalgOperandDefConfig + name: O + usage: OutputOperand + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12] + -> (s0, s7, s8, s4)> + - !LinalgOperandDefConfig + name: strides + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + s12] -> (s9, s10)> + - !LinalgOperandDefConfig + name: dilations + usage: IndexAttribute + type_var: I64 + attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, + s12] -> (s11, s12)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, 
+ s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d6)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10, s11, s12] -> (d5, d3, d4, d6)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10, s11, s12] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10, s11, s12] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10, s11, s12] -> (d0, d1, d2, d5)> + iterator_types: + - parallel + - parallel + - parallel + - reduction + - reduction + - parallel + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: I + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: IZp + - !ScalarExpression + scalar_apply: + fn_name: sub + operands: + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: K + - !ScalarExpression + symbolic_cast: + type_var: U + operands: + - !ScalarExpression + scalar_arg: KZp +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_input_nhwc_filter_hwc_poly cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -862,14 +862,24 @@ ShapedType resultTy = op->getResult(0).getType().cast(); Type inputETy = inputTy.getElementType(); - Type weightETy = weightTy.getElementType(); - Type biasETy = biasTy.getElementType(); 
Type resultETy = resultTy.getElementType(); auto padAttr = op->getAttr("pad").cast(); auto strideTosaAttr = op->getAttr("stride").cast(); auto dilationTosaAttr = op->getAttr("dilation").cast(); + bool isQuantized = op->hasAttr("quantization_info"); + IntegerAttr iZp; + IntegerAttr kZp; + if (isQuantized) { + auto quantizationInfo = + op->getAttr("quantization_info").cast(); + iZp = rewriter.getI32IntegerAttr( + quantizationInfo.input_zp().getValue().getSExtValue()); + kZp = rewriter.getI32IntegerAttr( + quantizationInfo.weight_zp().getValue().getSExtValue()); + } + if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) return rewriter.notifyMatchFailure(op, @@ -878,11 +888,6 @@ auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); - // TODO(suderman): Support other types. - if (!inputETy.isF32() || !weightETy.isF32() || !biasETy.isF32() || - !resultETy.isF32()) - return failure(); - // Apply padding as necessary. 
Attribute zeroAttr = rewriter.getZeroAttr(inputETy); llvm::SmallVector pad; @@ -924,14 +929,23 @@ auto dilationAttr = DenseIntElementsAttr::get( RankedTensorType::get({2}, rewriter.getI64Type()), dilation); - if (isa(op)) { + if (isa(op) && !isQuantized) { rewriter.replaceOpWithNewOp( op, resultTy, ValueRange{input, weight}, ValueRange{biasBroadcast}, strideAttr, dilationAttr); return success(); } - if (isa(op)) { + if (isa(op) && isQuantized) { + auto iZpVal = rewriter.create(loc, iZp); + auto kZpVal = rewriter.create(loc, kZp); + rewriter.replaceOpWithNewOp( + op, resultTy, ValueRange{input, weight, iZpVal, kZpVal}, + ValueRange{biasBroadcast}, strideAttr, dilationAttr); + return success(); + } + + if (isa(op) && !isQuantized) { ShapedType linalgConvTy = RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2], weightShape[2], weightShape[3]}, diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -164,6 +164,30 @@ D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic]) +@linalg_structured_op +def conv_2d_input_nhwc_filter_ohwi_poly_q( + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs a 2-D quantized convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Includes zero point + adjustment for quantization. 
+ """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic) + O[D.n, D.oh, D.ow, D.oc] += ((cast( + U, I[D.n, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic]) - cast(U, IZp)) * + (cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp))) + + @linalg_structured_op def depthwise_conv_2d_input_nhwc_filter_hwc_poly( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1184,17 +1184,21 @@ // ----- -// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: @conv2d_f32 func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 45, 40, 28] - // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP2]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) + // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<1x45x40x28xf32>) %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> 
+ (tensor<1x45x40x28xf32>) return } +// ----- + +// CHECK-LABEL: @conv2d_padded_f32 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () { // CHECK: linalg.pad_tensor %arg0 // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly @@ -1204,6 +1208,24 @@ // ----- +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: @conv2d_quant +func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1xi8>, %arg2 : tensor<1024xi32>) -> () { + // CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 10, 10, 1024] + // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1024xi32>) outs(%[[INIT]] : tensor<1x10x10x1024xi32>) + // CHECK: ^bb0(%arg3: i32, %arg4: i32): + // CHECK: linalg.yield %arg3 : i32 + // CHECK: %[[C128:.+]] = constant -128 + // CHECK: %[[C42:.+]] = constant 42 + // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x10x10x1024xi32>) + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32> + return +} + +// ----- + // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>