diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -329,3 +329,44 @@ w * strides[2] + kw * dilations[2]), K(kd, kh, kw, c, f))); } + +ods_def<PoolingNHWCSumOp>: +def pooling_nhwc_sum + (I: f32(N, H, W, C), K: f32(KH, KW)) + -> (O: f32(N, OH, OW, C)) + attr(strides: 2xi64, dilations: 2xi64) +{ + O(n, oh, ow, c) = std_addf(O(n, oh, ow, c), + I(n, oh * strides[0] + kh * dilations[0], + ow * strides[1] + kw * dilations[1], c)); +} + +ods_def<PoolingNHWCMaxOp>: +def pooling_nhwc_max + (I: f32(N, H, W, C), K: f32(KH, KW)) + -> (O: f32(N, OH, OW, C)) + attr(strides: 2xi64, dilations: 2xi64) +{ + O(n, oh, ow, c) = + std_select(std_cmpf_ogt(I(n, oh * strides[0] + kh * dilations[0], + ow * strides[1] + kw * dilations[1], c), + O(n, oh, ow, c)), + I(n, oh * strides[0] + kh * dilations[0], + ow * strides[1] + kw * dilations[1], c), + O(n, oh, ow, c)); +} + +ods_def<PoolingNHWCMinOp>: +def pooling_nhwc_min + (I: f32(N, H, W, C), K: f32(KH, KW)) + -> (O: f32(N, OH, OW, C)) + attr(strides: 2xi64, dilations: 2xi64) +{ + O(n, oh, ow, c) = + std_select(std_cmpf_olt(I(n, oh * strides[0] + kh * dilations[0], + ow * strides[1] + kw * dilations[1], c), + O(n, oh, ow, c)), + I(n, oh * strides[0] + kh * dilations[0], + ow * strides[1] + kw * dilations[1], c), + O(n, oh, ow, c)); +} diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -60,6 +60,16 @@ using std_sexti32 = SExtiValueBuilder<32>; +template <CmpFPredicate Predicate> +struct CmpFValueBuilder : public ValueBuilder<CmpFOp> { + using ValueBuilder<CmpFOp>::ValueBuilder; + template <typename... Args> + CmpFValueBuilder(Args... args) : ValueBuilder<CmpFOp>(Predicate, args...) 
{} +}; + +using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>; +using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>; + /// Branches into `block` with `operands`. BranchOp std_br(Block *block, ValueRange operands); diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -259,3 +259,80 @@ // CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32 // CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32 // CHECK-NEXT: linalg.yield %[[ADD]] : f32 + +// ----- + +func @pooling_nhwc_sum(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) { + linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>) + outs(%init: memref<?x?x?x?xf32>) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_sum + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32) +// CHECK-NEXT: %[[RES:.+]] = addf %[[BBARG2]], %[[BBARG0]] : f32 +// CHECK-NEXT: linalg.yield %[[RES]] : f32 + +// ----- + +func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) { + linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>} + ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>) + outs(%init: memref<?x?x?x?xf32>) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = 
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_max + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32) +// CHECK-NEXT: %[[CMP:.+]] = cmpf ogt, %[[BBARG0]], %[[BBARG2]] : f32 +// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32 +// CHECK-NEXT: linalg.yield %[[RES]] : f32 + +// ----- + +func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) { + linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} + ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>) + outs(%init: memref<?x?x?x?xf32>) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4 * 3, d2 * 2 + d5 * 3, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_min + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xf32>, memref<2x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xf32>) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32) +// CHECK-NEXT: %[[CMP:.+]] = cmpf olt, %[[BBARG0]], %[[BBARG2]] : f32 +// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : f32 +// 
CHECK-NEXT: linalg.yield %[[RES]] : f32 diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -246,3 +246,105 @@ outs (%output: memref) return } + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_sum_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> +func @pooling_nhwc_sum_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> + %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> + %cst = constant 0.000000e+00 : f32 + %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32> + %res = linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>) + outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> + return %res : tensor<1x2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_sum +// CHECK: linalg.pooling_nhwc_sum +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>) +func @pooling_nhwc_sum(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) { + linalg.pooling_nhwc_sum {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>) + outs(%output: memref<1x2x2x1xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_max +// CHECK-SAME: 
dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> +func @pooling_nhwc_max_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> + %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xf32> + %cst = constant 0.000000e+00 : f32 + %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32> + %res = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>) + outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> + return %res : tensor<1x2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_max +// CHECK: linalg.pooling_nhwc_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>) +func @pooling_nhwc_max(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) { + linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>) + outs(%output: memref<1x2x2x1xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_min_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x6x6x1xf32>, tensor<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> +func @pooling_nhwc_min_tensor(%input: tensor<1x6x6x1xf32>) -> tensor<1x2x2x1xf32> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xf32> + %init = linalg.init_tensor 
[1, 2, 2, 1] : tensor<1x2x2x1xf32> + %cst = constant 0.000000e+00 : f32 + %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xf32>, f32 -> tensor<1x2x2x1xf32> + %res = linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x6x6x1xf32>, tensor<3x3xf32>) + outs(%fill: tensor<1x2x2x1xf32>) -> tensor<1x2x2x1xf32> + return %res : tensor<1x2x2x1xf32> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_min +// CHECK: linalg.pooling_nhwc_min +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x6x6x1xf32>, memref<3x3xf32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xf32>) +func @pooling_nhwc_min(%input: memref<1x6x6x1xf32>, %fake: memref<3x3xf32>, %output: memref<1x2x2x1xf32>) { + linalg.pooling_nhwc_min {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x6x6x1xf32>, memref<3x3xf32>) + outs(%output: memref<1x2x2x1xf32>) + return +}