diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -905,6 +905,202 @@
       - !ScalarExpression
         scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv2D_nchw
+  cpp_class_name: DepthwiseConv2DNchwOp
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: depthwise_conv2D_nchw_q
+  cpp_class_name: DepthwiseConv2DNchwQOp
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. The input and kernel
+    zero points (IZp, KZp) are subtracted from the operands after promotion.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s5, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    usage: InputOperand
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s7, s8, s3, s6)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_apply:
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  symbolic_cast:
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw
   cpp_class_name: Conv2DNchwOp
@@ -1700,4 +1896,3 @@
       operands:
       - !ScalarExpression
         scalar_arg: I
-
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -946,7 +946,7 @@
     return success();
   }
 
-  if (isa<tosa::DepthwiseConv2DOp>(op) && !isQuantized) {
+  if (isa<tosa::DepthwiseConv2DOp>(op)) {
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
@@ -954,11 +954,23 @@
     Value biasReshape = rewriter.create<linalg::TensorExpandShapeOp>(
        loc, linalgConvTy, biasBroadcast);
 
-    Value conv = rewriter
-                     .create<linalg::DepthwiseConv2DInputNhwcFilterHwcfOp>(
-                         loc, linalgConvTy, ValueRange{input, weight},
-                         ValueRange{biasReshape}, dilationAttr, strideAttr)
-                     .getResult(0);
+    Value conv;
+    if (!isQuantized) {
+      conv = rewriter
+                 .create<linalg::DepthwiseConv2DNchwOp>(
+                     loc, linalgConvTy, ValueRange{input, weight},
+                     ValueRange{biasReshape}, dilationAttr, strideAttr)
+                 .getResult(0);
+    } else {
+      auto iZpVal = rewriter.create<ConstantOp>(loc, iZp);
+      auto kZpVal = rewriter.create<ConstantOp>(loc, kZp);
+      conv = rewriter
+                 .create<linalg::DepthwiseConv2DNchwQOp>(
+                     loc, linalgConvTy,
+                     ValueRange{input, weight, iZpVal, kZpVal},
+                     ValueRange{biasReshape}, dilationAttr, strideAttr)
+                 .getResult(0);
+    }
 
     Value reshape =
         rewriter.create<linalg::TensorCollapseShapeOp>(loc, resultTy, conv);
     rewriter.replaceOp(op, reshape);
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -223,6 +223,43 @@
       ]) * cast(U, K[D.f, D.c, D.kh, D.kw])
 
 
+@linalg_structured_op
+def depthwise_conv2D_nchw(  # TODO: Fix name.
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm)
+  O[D.n, D.oh, D.ow, D.ic, D.cm] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
+           D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm])
+
+
+@linalg_structured_op
+def depthwise_conv2D_nchw_q(  # TODO: Fix name.
+    I=TensorDef(T1, S.N, S.IH, S.IW, S.IC),
+    K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM),
+    IZp=ScalarDef(I32),
+    KZp=ScalarDef(I32),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output. The input and kernel
+  zero points (IZp, KZp) are subtracted from the operands after promotion.
+ """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + O[D.n, D.oh, D.ow, D.ic, D.cm] += ( + (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.ic]) - cast(U, IZp)) * + (cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp))) + + @linalg_structured_op def pooling_nhwc_sum( I=TensorDef(T1, S.N, S.H, S.W, S.C), diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1219,7 +1219,7 @@ // CHECK: linalg.yield %arg3 : i32 // CHECK: %[[C128:.+]] = constant -128 // CHECK: %[[C42:.+]] = constant 42 - // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>) + // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] : tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, i32, i32) outs(%1 : tensor<1x10x10x1024xi32>) %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1]} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x10x10x1024xi32> return } @@ -1237,7 +1237,7 @@ // CHECK: linalg.yield %arg3 : f32 // CHECK: } -> tensor<1x5x5x33xf32> // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] - // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>) + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2D_nchw {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[DBIAS]] : tensor<1x5x5x3x11xf32>) // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]] %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) { pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1] } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> (tensor<1x5x5x33xf32>) return @@ -1245,6 +1245,27 @@ // ----- +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: @depthwise_conv_quant +func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3x4x128xi8>, %arg2 : tensor<512xi32>) -> () { + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1, 10, 10, 512] + // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<512xi32>) outs([[INIT]] : tensor<1x10x10x512xi32>) { + // CHECK: ^bb0(%arg3: i32, %arg4: i32): // no predecessors + // CHECK: linalg.yield %arg3 : i32 + // CHECK: } -> tensor<1x10x10x512xi32> + // CHECK: [[DBIAS:%.+]] = linalg.tensor_expand_shape [[BIAS]] {{\[}}[0], [1], [2], [3, 4]] + // CHECK: %[[C128:.+]] = constant -128 + // CHECK: %[[C42:.+]] = constant 42 + // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv2D_nchw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %[[C128]], %[[C42]] 
: tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, i32, i32) outs([[DBIAS]] : tensor<1x10x10x4x128xi32>) + // CHECK: linalg.tensor_collapse_shape %3 {{\[}}[0], [1], [2], [3, 4]] + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [1, 1] } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + return +} + +// ----- + // CHECK-LABEL: @transpose_conv func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () { // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0]
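
For reviewers who want to sanity-check the indexing maps and the sub/mul/add scalar assignment of the new depthwise_conv2D_nchw_q op, the following standalone NumPy sketch spells out the computation it encodes. It is illustrative only and not part of the patch: the helper name depthwise_conv2d_q_ref is invented for this note, and the shapes and zero points are copied from the new @depthwise_conv_quant test. Note that, despite the "nchw" in the op name, the data layout is NHWC input / HWCM kernel, which is what the TODO in core_named_ops.py refers to.

import numpy as np

def depthwise_conv2d_q_ref(I, K, i_zp, k_zp, strides=(1, 1), dilations=(1, 1)):
  """I: (N, IH, IW, IC), K: (KH, KW, IC, CM) -> O: (N, OH, OW, IC, CM) int32."""
  N, IH, IW, IC = I.shape
  KH, KW, _, CM = K.shape
  SH, SW = strides
  DH, DW = dilations
  # Output spatial dims for an unpadded (VALID) convolution.
  OH = (IH - (KH - 1) * DH - 1) // SH + 1
  OW = (IW - (KW - 1) * DW - 1) // SW + 1
  O = np.zeros((N, OH, OW, IC, CM), dtype=np.int32)
  for n in range(N):
    for oh in range(OH):
      for ow in range(OW):
        for kh in range(KH):
          for kw in range(KW):
            # Promote to the accumulator type first, then subtract the
            # zero points, mirroring the sub-of-casts in the YAML spec.
            i = I[n, oh * SH + kh * DH, ow * SW + kw * DW, :].astype(np.int32) - i_zp
            k = K[kh, kw, :, :].astype(np.int32) - k_zp
            # Each input channel ic convolves against its CM kernel columns.
            O[n, oh, ow, :, :] += i[:, None] * k
  return O

# Shapes from @depthwise_conv_quant: input_zp = -128, weight_zp = 42.
rng = np.random.default_rng(0)
I = rng.integers(-128, 128, size=(1, 12, 12, 4), dtype=np.int8)
K = rng.integers(-128, 128, size=(3, 3, 4, 128), dtype=np.int8)
print(depthwise_conv2d_q_ref(I, K, -128, 42).shape)  # (1, 10, 10, 4, 128)

The zero points are subtracted after promotion to the accumulator type, which is why the op takes them as i32 scalar operands (scalar indexing maps "-> ()") rather than folding them into the i8 inputs, where the subtraction could overflow.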