diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1311,6 +1311,104 @@
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwc_fhwc
+  cpp_class_name: Conv2DNhwcFhwcOp
+  doc: |-
+    Performs 2-D convolution.
+
+    Layout:
+      * Input: NHWC.
+      * Kernel: FHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, s3,
+      s7, s9)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
+      s1, s5, s10)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
+      (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d3, d4, d5, d6)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
+      s9, s10] -> (d0, d1, d2, d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nhwc_hwcf_q
   cpp_class_name: Conv2DNhwcHwcfQOp
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -178,6 +178,38 @@
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc
+func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
+    outs (%init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_nhwc_fhwc_static
+func.func @conv_2d_nhwc_fhwc_static(%input: tensor<?x128x128x32xf32>, %filter: tensor<64x3x3x32xf32>, %init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32> {
+  // CHECK:      %{{.+}} = linalg.conv_2d_nhwc_fhwc
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<?x128x128x32xf32>, tensor<64x3x3x32xf32>)
+    outs (%init: tensor<?x126x126x64xf32>) -> tensor<?x126x126x64xf32>
+  return %0 : tensor<?x126x126x64xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_2d_nhwc_hwcf
 func.func @conv_2d_nhwc_hwcf(%input: memref<?x?x?x?xf32>, %filter: memref<?x?x?x?xf32>, %output: memref<?x?x?x?xf32>) {
   // CHECK:      linalg.conv_2d_nhwc_hwcf