diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -905,6 +905,88 @@
       - !ScalarExpression
         scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nchw
+  cpp_class_name: Conv2DNchwOp
+  doc: |-
+    Performs 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s4, s1, s5, s6)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
+      -> (s0, s4, s7, s8)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s9, s10)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
+      s12] -> (s11, s12)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d1, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10, s11, s12] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_sum
   cpp_class_name: PoolingNhwcSumOp
@@ -1047,6 +1129,77 @@
       - !ScalarExpression
         scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nchw_max
+  cpp_class_name: PoolingNchwMaxOp
+  doc: |-
+    Performs max pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s4, s5)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s6, s7)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s10, s11)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2 * s8 + d4 * s10, d3 * s9 + d5 * s11)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d4, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: max
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: pooling_nhwc_min
   cpp_class_name: PoolingNhwcMinOp
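The three `static_indexing_maps` in the conv_2d_nchw entry above fully determine the computation: symbols `s0..s8` are the operand sizes (N, C, IH, IW, F, KH, KW, OH, OW) and `s9..s12` the strides and dilations. As a sanity reference, here is a minimal NumPy sketch of what `linalg.conv_2d_nchw` computes; the helper name and the use of NumPy are illustrative only, not part of the patch.

```python
import numpy as np

def conv_2d_nchw_ref(I, K, O, strides=(1, 1), dilations=(1, 1)):
  """Loop-nest reference: O[n,f,oh,ow] += I[n,c,...] * K[f,c,kh,kw]."""
  SH, SW = strides
  DH, DW = dilations
  N, F, OH, OW = O.shape
  _, C, KH, KW = K.shape
  # d0..d3 (n, f, oh, ow) are the parallel dims, d4..d6 (c, kh, kw) the
  # reductions, matching iterator_types in the YAML above.
  for n in range(N):
    for f in range(F):
      for oh in range(OH):
        for ow in range(OW):
          for c in range(C):
            for kh in range(KH):
              for kw in range(KW):
                O[n, f, oh, ow] += (
                    I[n, c, oh * SH + kh * DH, ow * SW + kw * DW] *
                    K[f, c, kh, kw])
  return O

# A 1x1 kernel degenerates to a per-pixel channel-mixing matmul:
I = np.ones((1, 2, 3, 3), np.float32)
K = np.ones((4, 2, 1, 1), np.float32)
O = conv_2d_nchw_ref(I, K, np.zeros((1, 4, 3, 3), np.float32))
assert np.all(O == 2.0)  # each output sums over C = 2 input channels
```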
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -125,12 +125,6 @@
       O(n, h, w, f), MulFOp(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
 }
 
-ods_def<ConvNCHWOp>:
-def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
-  O(n, f, h, w) = AddFOp(
-    O(n, f, h, w), MulFOp(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
-}
-
 ods_def<ConvDHWOp>:
 def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
   O(d, h, w) = AddFOp(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1177,7 +1177,7 @@
   populateVectorizationPatterns<ConvNHWCOp, 4>(
       tiling, promotion, vectorization, tileSizes);
-  populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
+  populateVectorizationPatterns<Conv2DNchwOp, 4>(tiling, promotion, vectorization,
                                                tileSizes);
 
   populateVectorizationPatterns<ConvDHWOp, 3>(
      tiling, promotion, vectorization, tileSizes);
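The removed `.tc` definition hard-coded unit strides and dilations; the YAML op above carries them as index attributes instead. Note that `s7`/`s8` (OH/OW) are free symbols in the shape maps: they are only consistent with the input sizes when the access `d2 * s9 + d5 * s11` stays within `s2` (IH), which yields the usual closed form. A small helper illustrating that relation (our sketch, not part of the patch), checked against the test shapes that appear further below:

```python
def conv_output_size(input_size, kernel_size, stride=1, dilation=1):
  # Largest OH such that (OH - 1)*stride + (kernel - 1)*dilation + 1 <= input.
  effective_kernel = (kernel_size - 1) * dilation + 1
  return (input_size - effective_kernel) // stride + 1

# Conv test below: 2x2x4x5 input, 3x3 kernel -> OH, OW = 2, 3 (2x4x2x3 output).
assert (conv_output_size(4, 3), conv_output_size(5, 3)) == (2, 3)
# Pooling test below: 1x1x4x4 input, 3x3 window -> 2x2 output.
assert (conv_output_size(4, 3), conv_output_size(4, 3)) == (2, 2)
```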
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -205,6 +205,23 @@
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c
           ]) * cast(U, K[D.kh, D.kw, D.c])
 
+@linalg_structured_op
+def conv_2d_nchw(
+    I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
+    K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
+  O[D.n, D.f, D.oh, D.ow] += cast(
+      U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
+           D.ow * S.SW + D.kw * S.DW]) * cast(U, K[D.f, D.c, D.kh, D.kw])
+
 
 @linalg_structured_op
 def pooling_nhwc_sum(
@@ -240,6 +257,22 @@
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nchw_max(
+    I=TensorDef(T1, S.N, S.C, S.H, S.W),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs max pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+      cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
+                D.ow * S.SW + D.kw * S.DW]))
+
 
 @linalg_structured_op
 def pooling_nhwc_min(
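Mirroring the sketch given for the convolution, `pooling_nchw_max` reduces with `max` over the window; note `index_dims` on `K`: that operand contributes only its shape, its values are never read. A minimal NumPy reference of the semantics (illustrative, not part of the patch):

```python
import numpy as np

def pooling_nchw_max_ref(I, window, O, strides=(1, 1), dilations=(1, 1)):
  SH, SW = strides
  DH, DW = dilations
  KH, KW = window  # only the window *shape* matters, like the K operand
  N, C, OH, OW = O.shape
  for n in range(N):
    for c in range(C):
      for oh in range(OH):
        for ow in range(OW):
          for kh in range(KH):
            for kw in range(KW):
              # Combines with the initial value of O, matching the yaml
              # assignment O = max(O, cast(I)).
              O[n, c, oh, ow] = max(
                  O[n, c, oh, ow],
                  I[n, c, oh * SH + kh * DH, ow * SW + kw * DW])
  return O

I = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
O = np.full((1, 1, 2, 2), -np.inf, dtype=np.float32)
print(pooling_nchw_max_ref(I, (3, 3), O))  # [[[[10. 11.] [14. 15.]]]]
```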
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -30,6 +30,24 @@
   return %0 : tensor<2x3x4x2x3xf32>
 }
 
+// CHECK-LABEL: func @conv_2d_nchw_tensor
+func @conv_2d_nchw_tensor(%input: tensor<2x2x4x5xf32>, %filter: tensor<4x2x3x3xf32>) -> tensor<2x4x2x3xf32> {
+  %cst = constant 0.000000e+00 : f32
+  %init = linalg.init_tensor [2, 4, 2, 3] : tensor<2x4x2x3xf32>
+  %fill = linalg.fill(%cst, %init) : f32, tensor<2x4x2x3xf32> -> tensor<2x4x2x3xf32>
+// CHECK: %{{.+}} = linalg.conv_2d_nchw
+// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
+// CHECK: return %{{.+}} : tensor<2x4x2x3xf32>
+// CHECK: }
+  %0 = linalg.conv_2d_nchw
+    {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+    ins(%input, %filter: tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
+    outs(%fill : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
+  return %0 : tensor<2x4x2x3xf32>
+}
+
 // CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
 func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
   // CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
@@ -381,6 +399,25 @@
   return %res : tensor<1x2x2x1xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @pooling_nchw_max_tensor
+// CHECK: %{{.+}} = linalg.pooling_nchw_max
+// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+// CHECK-SAME: strides = dense<1> : tensor<2xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
+func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
+  %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
+  %init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32>
+  %cst = constant 0.000000e+00 : f32
+  %fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32>
+  %res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+    ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>)
+    outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
+  return %res : tensor<1x1x2x2xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func @pooling_nhwc_max
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
--- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir
@@ -30,8 +30,10 @@
 }
 
 func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
-  linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
-                      outs (%arg2: memref<?x?x?x?xf32>)
+  linalg.conv_2d_nchw
+    {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+    ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
+    outs (%arg2: memref<?x?x?x?xf32>)
   return
 }
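The `named-ops.mlir` cases above only FileCheck the printed IR, and the integration test's input data is not shown here. As a quick plausibility check of the conv semantics on the test shapes: with all-ones input and filter and a zero-filled init, every output element should equal C * KH * KW = 18. A self-contained NumPy check (ours, not part of the patch), expressing each output pixel as a contraction over an input window:

```python
import numpy as np

I = np.ones((2, 2, 4, 5), np.float32)   # NCHW input, as in the conv test
K = np.ones((4, 2, 3, 3), np.float32)   # FCHW filter
O = np.zeros((2, 4, 2, 3), np.float32)  # NFHW output, pre-filled with 0.0
for oh in range(O.shape[2]):
  for ow in range(O.shape[3]):
    patch = I[:, :, oh:oh + 3, ow:ow + 3]          # (N, C, KH, KW) window
    O[:, :, oh, ow] += np.einsum('nchw,fchw->nf', patch, K)
assert np.all(O == 18.0)  # C * KH * KW = 2 * 3 * 3
```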