diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -352,6 +352,18 @@
                        ow * strides[1] + kw * dilations[1], c));
 }
 
+ods_def:
+def pooling_ndhwc_sum
+  (I: f32(N, D, H, W, C), K: f32(KD, KH, KW))
+  -> (O: f32(N, OD, OH, OW, C))
+  attr(strides: 3xi64, dilations: 3xi64)
+{
+  O(n, od, oh, ow, c) = AddFOp(
+      O(n, od, oh, ow, c), I(n, od * strides[0] + kd * dilations[0],
+                             oh * strides[1] + kh * dilations[1],
+                             ow * strides[2] + kw * dilations[2], c));
+}
+
 ods_def:
 def pooling_nhwc_i8_max
   (I: i8(N, H, W, C), K: i8(KH, KW))
@@ -412,6 +424,23 @@
                          O(n, oh, ow, c));
 }
 
+ods_def:
+def pooling_ndhwc_max
+  (I: f32(N, D, H, W, C), K: f32(KD, KH, KW))
+  -> (O: f32(N, OD, OH, OW, C))
+  attr(strides: 3xi64, dilations: 3xi64)
+{
+  O(n, od, oh, ow, c) =
+      SelectOp(CmpFOpOGT(I(n, od * strides[0] + kd * dilations[0],
+                           oh * strides[1] + kh * dilations[1],
+                           ow * strides[2] + kw * dilations[2], c),
+                         O(n, od, oh, ow, c)),
+               I(n, od * strides[0] + kd * dilations[0],
+                 oh * strides[1] + kh * dilations[1],
+                 ow * strides[2] + kw * dilations[2], c),
+               O(n, od, oh, ow, c));
+}
+
 ods_def:
 def pooling_nhwc_min
   (I: f32(N, H, W, C), K: f32(KH, KW))
@@ -426,3 +455,20 @@
                          ow * strides[1] + kw * dilations[1], c),
                O(n, oh, ow, c));
 }
+
+ods_def:
+def pooling_ndhwc_min
+  (I: f32(N, D, H, W, C), K: f32(KD, KH, KW))
+  -> (O: f32(N, OD, OH, OW, C))
+  attr(strides: 3xi64, dilations: 3xi64)
+{
+  O(n, od, oh, ow, c) =
+      SelectOp(CmpFOpOLT(I(n, od * strides[0] + kd * dilations[0],
+                           oh * strides[1] + kh * dilations[1],
+                           ow * strides[2] + kw * dilations[2], c),
+                         O(n, od, oh, ow, c)),
+               I(n, od * strides[0] + kd * dilations[0],
+                 oh * strides[1] + kh * dilations[1],
+                 ow * strides[2] + kw * dilations[2], c),
+               O(n, od, oh, ow, c));
+}
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -518,3 +518,105 @@
     outs(%output: memref<1x2x2x1xf32>)
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_sum_tensor
+// CHECK: %{{.+}} = linalg.pooling_ndhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+func @pooling_ndhwc_sum_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
+  %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
+  %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
+  %cst = constant 0.000000e+00 : f32
+  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %res = linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+    outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+  return %res : tensor<1x2x2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_sum
+// CHECK: linalg.pooling_ndhwc_sum
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x2x1xf32>)
+func @pooling_ndhwc_sum(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
+  linalg.pooling_ndhwc_sum {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+    outs(%output: memref<1x2x2x2x1xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_max_tensor
+// CHECK: %{{.+}} = linalg.pooling_ndhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+func @pooling_ndhwc_max_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
+  %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
+  %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
+  %cst = constant 0.000000e+00 : f32
+  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %res = linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+    outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+  return %res : tensor<1x2x2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_max
+// CHECK: linalg.pooling_ndhwc_max
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x2x1xf32>)
+func @pooling_ndhwc_max(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
+  linalg.pooling_ndhwc_max {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+    outs(%output: memref<1x2x2x2x1xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_min_tensor
+// CHECK: %{{.+}} = linalg.pooling_ndhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+func @pooling_ndhwc_min_tensor(%input: tensor<1x4x4x4x1xf32>) -> tensor<1x2x2x2x1xf32> {
+  %fake = linalg.init_tensor [3, 3, 3] : tensor<3x3x3xf32>
+  %init = linalg.init_tensor [1, 2, 2, 2, 1] : tensor<1x2x2x2x1xf32>
+  %cst = constant 0.000000e+00 : f32
+  %fill = linalg.fill(%cst, %init) : f32, tensor<1x2x2x2x1xf32> -> tensor<1x2x2x2x1xf32>
+  %res = linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: tensor<1x4x4x4x1xf32>, tensor<3x3x3xf32>)
+    outs(%fill: tensor<1x2x2x2x1xf32>) -> tensor<1x2x2x2x1xf32>
+  return %res : tensor<1x2x2x2x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @pooling_ndhwc_min
+// CHECK: linalg.pooling_ndhwc_min
+// CHECK-SAME: dilations = dense<1> : tensor<3xi64>
+// CHECK-SAME: strides = dense<1> : tensor<3xi64>
+// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x2x1xf32>)
+func @pooling_ndhwc_min(%input: memref<1x4x4x4x1xf32>, %fake: memref<3x3x3xf32>, %output: memref<1x2x2x2x1xf32>) {
+  linalg.pooling_ndhwc_min {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+    ins(%input, %fake: memref<1x4x4x4x1xf32>, memref<3x3x3xf32>)
+    outs(%output: memref<1x2x2x2x1xf32>)
+  return
+}
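
For context, a quick illustration of the semantics these definitions encode (not part of the patch): with the unit strides and dilations used in the tests, the memref form of linalg.pooling_ndhwc_sum should lower (e.g. via -convert-linalg-to-loops) to a scalar loop nest along the lines of the hand-written sketch below. This is an assumption-laden sketch, not output of the patch: the function name is invented, the N and C loops (both of extent 1 in the test shapes) are folded into constant indices, and the loop order the lowering actually picks may differ. Note that the kernel operand only supplies the iteration-domain shape (KD, KH, KW) and is never read, which is why %fake does not appear in the body.

// Hypothetical hand-written lowering sketch for the @pooling_ndhwc_sum test
// case above; illustrative only, not produced by this patch.
func @pooling_ndhwc_sum_loops(%input: memref<1x4x4x4x1xf32>,
                              %output: memref<1x2x2x2x1xf32>) {
  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  // Output window loops (OD = OH = OW = 2).
  scf.for %od = %c0 to %c2 step %c1 {
    scf.for %oh = %c0 to %c2 step %c1 {
      scf.for %ow = %c0 to %c2 step %c1 {
        // Kernel reduction loops (KD = KH = KW = 3).
        scf.for %kd = %c0 to %c3 step %c1 {
          scf.for %kh = %c0 to %c3 step %c1 {
            scf.for %kw = %c0 to %c3 step %c1 {
              // od * strides[0] + kd * dilations[0] reduces to od + kd when
              // strides and dilations are all 1 (same for the H and W axes).
              %id = addi %od, %kd : index
              %ih = addi %oh, %kh : index
              %iw = addi %ow, %kw : index
              %in = memref.load %input[%c0, %id, %ih, %iw, %c0] : memref<1x4x4x4x1xf32>
              %acc = memref.load %output[%c0, %od, %oh, %ow, %c0] : memref<1x2x2x2x1xf32>
              // O(n, od, oh, ow, c) = AddFOp(O(...), I(...)) from the ods_def.
              %sum = addf %acc, %in : f32
              memref.store %sum, %output[%c0, %od, %oh, %ow, %c0] : memref<1x2x2x2x1xf32>
            }
          }
        }
      }
    }
  }
  return
}

The max and min variants follow the same pattern, with the addf replaced by a cmpf/select pair mirroring the SelectOp(CmpFOpOGT(...)) and SelectOp(CmpFOpOLT(...)) bodies in the ods_defs.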