diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -1945,14 +1945,14 @@
                 scalar_arg: K
 --- !LinalgOpConfig
 metadata: !LinalgOpMetadata
-  name: conv_2d_ngchw_fgchw
-  cpp_class_name: Conv2DNgchwFgchwOp
+  name: conv_2d_ngchw_gfchw
+  cpp_class_name: Conv2DNgchwGfchwOp
   doc: |-
     Performs 2-D grouped convolution.
 
     Layout:
       * Input: NGCHW.
-      * Kernel: FGCHW.
+      * Kernel: GFCHW.
 
     Numeric casting is performed on the operands to the inner multiply,
     promoting them to the same data type as the accumulator/output.
@@ -1971,13 +1971,13 @@
     kind: input_tensor
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s11, s1, s2, s5, s9)>
+      (s1, s11, s2, s5, s9)>
   - !LinalgOperandDefConfig
     name: O
     kind: output_tensor
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s11, s1, s3, s7)>
+      (s0, s1, s11, s3, s7)>
   - !LinalgOperandDefConfig
     name: strides
     kind: index_attr
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -406,18 +406,18 @@
 
 
 @linalg_structured_op
-def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C,
+def conv_2d_ngchw_gfchw(I=TensorDef(T1, S.N, S.G, S.C,
                                     S.OH * S.SH + S.KH * S.DH,
                                     S.OW * S.SW + S.KW * S.DW),
-                        K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-                        O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+                        K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
+                        O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
                         strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
                         dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])):
   """Performs 2-D grouped convolution.
 
   Layout:
     * Input: NGCHW.
-    * Kernel: FGCHW.
+    * Kernel: GFCHW.
 
   Numeric casting is performed on the operands to the inner multiply,
   promoting them to the same data type as the accumulator/output.
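For reference, the G/F swap above changes the kernel's leading dimensions from (F, G, C, KH, KW) to (G, F, C, KH, KW) and the output from (N, F, G, OH, OW) to (N, G, F, OH, OW), so the group dimension now sits in the same relative position across input, kernel, and output. Below is a minimal NumPy sketch of the semantics the reordered opdsl definition encodes; it is not part of the patch, and the function name conv_2d_ngchw_gfchw_ref is hypothetical.

import numpy as np

def conv_2d_ngchw_gfchw_ref(inp, ker, strides=(1, 1), dilations=(1, 1)):
    # Illustrative reference only: NGCHW input, GFCHW kernel, NGFHW output.
    n, g, c, h, w = inp.shape          # NGCHW
    gk, f, ck, kh, kw = ker.shape      # GFCHW (was FGCHW before this change)
    assert g == gk and c == ck
    sh, sw = strides
    dh, dw = dilations
    oh = (h - (kh - 1) * dh - 1) // sh + 1
    ow = (w - (kw - 1) * dw - 1) // sw + 1
    out = np.zeros((n, g, f, oh, ow), dtype=inp.dtype)  # NGFHW
    for i in range(kh):
        for j in range(kw):
            # Strided, dilated input window seen by kernel tap (i, j).
            win = inp[:, :, :, i * dh : i * dh + oh * sh : sh,
                               j * dw : j * dw + ow * sw : sw]
            # Contract over the per-group input channels c; g is elementwise.
            out += np.einsum("ngchw,gfc->ngfhw", win, ker[:, :, :, i, j])
    return out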
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -301,14 +301,14 @@
 
 // -----
 
-// CHECK-LABEL: func @conv_2d_ngchw_fgchw
-func.func @conv_2d_ngchw_fgchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
-  // CHECK: %{{.+}} = linalg.conv_2d_ngchw_fgchw
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw
+func.func @conv_2d_ngchw_gfchw(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
+  // CHECK: %{{.+}} = linalg.conv_2d_ngchw_gfchw
   // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
   // CHECK-SAME: strides = dense<1> : tensor<2xi64>
   // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
   // CHECK-SAME: outs(%{{.+}} : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
-  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+  %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
                                    strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
     outs (%init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
@@ -317,6 +317,22 @@
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw_static
+func.func @conv_2d_ngchw_gfchw_static(%input: tensor<1x2x32x128x256xf32>, %filter: tensor<2x64x32x3x5xf32>, %init: tensor<1x2x64x126x252xf32>) -> tensor<1x2x64x126x252xf32> {
+  // CHECK: %{{.+}} = linalg.conv_2d_ngchw_gfchw
+  // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x2x32x128x256xf32>, tensor<2x64x32x3x5xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<1x2x64x126x252xf32>) -> tensor<1x2x64x126x252xf32>
+  %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+                                   strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<1x2x32x128x256xf32>, tensor<2x64x32x3x5xf32>)
+    outs (%init: tensor<1x2x64x126x252xf32>) -> tensor<1x2x64x126x252xf32>
+  return %0 : tensor<1x2x64x126x252xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_2d_nhwc_fhwc
 func.func @conv_2d_nhwc_fhwc(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %init: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
   // CHECK: %{{.+}} = linalg.conv_2d_nhwc_fhwc
@@ -365,14 +381,14 @@
 
 // -----
 
-// CHECK-LABEL: func @conv_2d_ngchw_fgchw
-func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
-  // CHECK: linalg.conv_2d_ngchw_fgchw
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw
+func.func @conv_2d_ngchw_gfchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
+  // CHECK: linalg.conv_2d_ngchw_gfchw
   // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
   // CHECK-SAME: strides = dense<1> : tensor<2xi64>
   // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
   // CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
-  linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+  linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
                               strides = dense<1> : tensor<2xi64>}
     ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
     outs (%output: memref<?x?x?x?x?xf32>)
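As a quick shape check on the new static test, using the illustrative sketch after the opdsl diff above: with unit strides and dilations, OH = 128 - (3 - 1) - 1 + 1 = 126 and OW = 256 - (5 - 1) - 1 + 1 = 252, matching the expected tensor<1x2x64x126x252xf32> result.

inp = np.ones((1, 2, 32, 128, 256), dtype=np.float32)  # NGCHW: N=1, G=2, C=32
ker = np.ones((2, 64, 32, 3, 5), dtype=np.float32)     # GFCHW: G=2, F=64
out = conv_2d_ngchw_gfchw_ref(inp, ker)
assert out.shape == (1, 2, 64, 126, 252)               # NGFHW, as in the test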