diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc
@@ -352,6 +352,51 @@
                       ow * strides[1] + kw * dilations[1], c));
 }
 
+ods_def<PoolingNHWCI8MaxOp>:
+def pooling_nhwc_i8_max
+  (I: i8(N, H, W, C), K: i8(KH, KW))
+  -> (O: i8(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) =
+      std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                ow * strides[1] + kw * dilations[1], c),
+                              O(n, oh, ow, c)),
+                 I(n, oh * strides[0] + kh * dilations[0],
+                   ow * strides[1] + kw * dilations[1], c),
+                 O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCI16MaxOp>:
+def pooling_nhwc_i16_max
+  (I: i16(N, H, W, C), K: i16(KH, KW))
+  -> (O: i16(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) =
+      std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                ow * strides[1] + kw * dilations[1], c),
+                              O(n, oh, ow, c)),
+                 I(n, oh * strides[0] + kh * dilations[0],
+                   ow * strides[1] + kw * dilations[1], c),
+                 O(n, oh, ow, c));
+}
+
+ods_def<PoolingNHWCI32MaxOp>:
+def pooling_nhwc_i32_max
+  (I: i32(N, H, W, C), K: i32(KH, KW))
+  -> (O: i32(N, OH, OW, C))
+  attr(strides: 2xi64, dilations: 2xi64)
+{
+  O(n, oh, ow, c) =
+      std_select(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
+                                ow * strides[1] + kw * dilations[1], c),
+                              O(n, oh, ow, c)),
+                 I(n, oh * strides[0] + kh * dilations[0],
+                   ow * strides[1] + kw * dilations[1], c),
+                 O(n, oh, ow, c));
+}
+
 ods_def<PoolingNHWCMaxOp>:
 def pooling_nhwc_max
   (I: f32(N, H, W, C), K: f32(KH, KW))
diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
--- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -59,6 +59,15 @@
 using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>;
 using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>;
 
+template <CmpIPredicate Predicate>
+struct CmpIValueBuilder : public ValueBuilder<CmpIOp> {
+  using ValueBuilder<CmpIOp>::ValueBuilder;
+  template <typename... Args>
+  CmpIValueBuilder(Args... args) : ValueBuilder<CmpIOp>(Predicate, args...) {}
+};
+
+using std_cmpi_sgt = CmpIValueBuilder<CmpIPredicate::sgt>;
+
 /// Branches into `block` with `operands`.
 BranchOp std_br(Block *block, ValueRange operands);
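For context on how the new intrinsic reads at a use site: `std_cmpi_sgt` bakes `CmpIPredicate::sgt` into the builder, mirroring the existing `std_cmpf_ogt`/`std_cmpf_olt` aliases, so the Linalg emitter can expand the `std_select(std_cmpi_sgt(...), ...)` expressions in the .tc specs above. A minimal sketch, assuming an active `edsc::ScopedContext` as all EDSC intrinsics require; the helper name is illustrative and not part of the patch:

#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;

// Illustrative helper: emit max(lhs, rhs) on signed integers -- the same
// cmpi/select combinator the new pooling_nhwc_*_max specs expand to.
static Value emitSignedMax(Value lhs, Value rhs) {
  Value isGreater = std_cmpi_sgt(lhs, rhs); // emits: cmpi sgt, %lhs, %rhs
  return std_select(isGreater, lhs, rhs);   // emits: select %cond, %lhs, %rhs
}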
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1230,6 +1230,22 @@
           "Pad converter requires static shaped input / padding values.");
     }
 
+    Attribute constantAttr;
+    if (elementTy.isa<FloatType>())
+      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
+    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
+      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
+    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
+      auto value = padOp.quantization_info().getValue().input_zp().getValue();
+      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
+    }
+
+    if (!constantAttr) {
+      return rewriter.notifyMatchFailure(
+          padOp,
+          "tosa.pad to linalg lowering encountered an unknown element type");
+    }
+
     Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
     Value highIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
 
@@ -1256,22 +1272,6 @@
       highValues.push_back(highVal);
     }
 
-    Attribute constantAttr;
-    if (elementTy.isa<FloatType>())
-      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
-      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
-      auto value = padOp.quantization_info().getValue().input_zp().getValue();
-      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
-    }
-
-    if (!constantAttr) {
-      return rewriter.notifyMatchFailure(
-          padOp,
-          "tosa.pad to linalg lowering encountered an unknown element type");
-    }
-
     Value constant = rewriter.create<ConstantOp>(loc, constantAttr);
 
     auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
@@ -1523,6 +1523,128 @@
   }
 };
 
+class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
+public:
+  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
+                                PatternRewriter &rewriter) const final {
+    Location loc = op.getLoc();
+    Value input = op.input();
+    ShapedType inputTy = input.getType().cast<ShapedType>();
+    Type inElementTy = inputTy.getElementType();
+
+    ShapedType resultTy = op.getType().cast<ShapedType>();
+    Type outElementTy = inputTy.getElementType();
+    int64_t rank = inputTy.getRank();
+
+    if (!inputTy.hasStaticShape())
+      return failure();
+
+    // Determine what the initial value needs to be for the max pool op.
+    Attribute initialAttr;
+    if (outElementTy.isF32())
+      initialAttr = rewriter.getFloatAttr(
+          outElementTy,
+          APFloat::getLargest(
+              outElementTy.cast<FloatType>().getFloatSemantics(), true));
+
+    if (outElementTy.isa<IntegerType>())
+      initialAttr = rewriter.getIntegerAttr(
+          outElementTy,
+          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
+
+    if (!initialAttr)
+      return rewriter.notifyMatchFailure(
+          op, "Unsupported initial value for tosa.max_pool2d op");
+
+    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
+
+    SmallVector<int64_t> kernel, stride, pad;
+    getValuesFromIntArrayAttribute(op.kernel(), kernel);
+    getValuesFromIntArrayAttribute(op.stride(), stride);
+    getValuesFromIntArrayAttribute(op.pad(), pad);
+
+    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
+    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
+
+    // If there is non-zero padding, pad the input with the initial value.
+    if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
+      SmallVector<int64_t> paddedShape;
+      for (int64_t i = 0; i < rank; i++)
+        paddedShape.push_back(inputTy.getDimSize(i));
+
+      paddedShape[1] += pad[0] + pad[1];
+      paddedShape[2] += pad[2] + pad[3];
+
+      OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
+      OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
+      OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
+      OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
+      OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);
+
+      SmallVector<OpFoldResult> lowIndices = {zeroIndex, heightLowPadIndex,
+                                              widthLowPadIndex, zeroIndex};
+      SmallVector<OpFoldResult> highIndices = {zeroIndex, heightHighPadIndex,
+                                               widthHighPadIndex, zeroIndex};
+
+      input = linalg::PadTensorOp::createPadScalarOp(
+                  RankedTensorType::get(paddedShape, inElementTy), input,
+                  initialValue, lowIndices, highIndices, loc, rewriter)
+                  .result();
+    }
+
+    Value initTensor = rewriter.create<linalg::InitTensorOp>(
+        loc, resultTy.getShape(), resultTy.getElementType());
+
+    Value filledInitTensor =
+        rewriter.create<linalg::FillOp>(loc, initTensor, initialValue)
+            .result();
+
+    Value fakeWindowDims =
+        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);
+
+    auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
+      return cast<linalg::LinalgOp>(
+          rewriter
+              .create<std::remove_pointer_t<decltype(typePtr)>>(
+                  loc, ArrayRef<Type>{resultTy},
+                  ValueRange{input, fakeWindowDims}, filledInitTensor,
+                  dilationAttr, strideAttr)
+              .getOperation());
+    };
+
+    if (inElementTy.isF32()) {
+      linalg::LinalgOp poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCMaxOp *>(nullptr));
+      rewriter.replaceOp(op, poolingOp->getResult(0));
+      return success();
+    }
+
+    if (inElementTy.isInteger(8)) {
+      linalg::LinalgOp poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCI8MaxOp *>(nullptr));
+      rewriter.replaceOp(op, poolingOp->getResult(0));
+      return success();
+    }
+
+    if (inElementTy.isInteger(16)) {
+      linalg::LinalgOp poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCI16MaxOp *>(nullptr));
+      rewriter.replaceOp(op, poolingOp->getResult(0));
+      return success();
+    }
+
+    if (inElementTy.isInteger(32)) {
+      linalg::LinalgOp poolingOp =
+          createOp(static_cast<linalg::PoolingNHWCI32MaxOp *>(nullptr));
+      rewriter.replaceOp(op, poolingOp->getResult(0));
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1579,6 +1701,7 @@
       TileConverter,
       TransposeConverter,
       MatMulConverter,
+      MaxPool2dConverter,
       FullyConnectedConverter>(patterns->getContext());
   // clang-format on
 }
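As a sanity check on the converter's shape handling: TOSA's `pad` attribute is ordered [top, bottom, left, right] (hence `pad[0] + pad[1]` on the height dim and `pad[2] + pad[3]` on the width dim above), dilation is fixed at {1, 1}, and each pooled extent follows the usual windowed-reduction formula. A small standalone sketch, not part of the patch, cross-checking the shapes used in the tests below:

#include <cassert>
#include <cstdint>

// Pooled output extent for one spatial dimension, with dilation fixed at 1
// as in MaxPool2dConverter above (illustrative helper, not from the patch).
static int64_t pooledDim(int64_t in, int64_t padLo, int64_t padHi,
                         int64_t kernel, int64_t stride) {
  return (in + padLo + padHi - kernel) / stride + 1;
}

int main() {
  // @max_pool below: 6x34 spatial input, 3x3 kernel, stride 1, no padding.
  assert(pooledDim(6, 0, 0, 3, 1) == 4);   // -> tensor<1x4x32x62xf32>
  assert(pooledDim(34, 0, 0, 3, 1) == 32);
  // @max_pool_padded: one extra column of width-high (right) padding.
  assert(pooledDim(34, 0, 1, 3, 1) == 33); // -> tensor<1x4x33x62xf32>
  return 0;
}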
@@ %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>) return } + +// ----- + +// CHECK-LABEL: @max_pool +func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () { + // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 + // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62] + // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]]) + // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] + // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>) + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>) + return +} + +// CHECK-LABEL: @max_pool_padded +func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () { + // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 : f32 + // CHECK-DAG: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0] + // CHECK-DAG: linalg.yield [[CONST]] + // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62] + // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]]) + // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3] + // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>) + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>) + return +} + +// CHECK-LABEL: @max_pool_i8 +func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () { + // CHECK: constant -128 + // CHECK: linalg.pooling_nhwc_i8_max + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>) + return +} + +// CHECK-LABEL: @max_pool_i16 +func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () { + // CHECK: constant -32768 + // CHECK: linalg.pooling_nhwc_i16_max + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>) + return +} + +// CHECK-LABEL: @max_pool_i32 +func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () { + // CHECK: constant -2147483648 + // CHECK: linalg.pooling_nhwc_i32_max + %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>) + return +} diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir @@ -340,6 +340,84 @@ // ----- +func @pooling_nhwc_i8_max(%input: memref, %fake: memref<2x3xi8>, %init: memref) { + linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>} + ins(%input, %fake: memref, memref<2x3xi8>) + outs(%init: memref) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_i8_max + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = 
["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref, memref<2x3xi8>) +// CHECK-SAME: outs(%{{.+}} : memref) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8) +// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i8 +// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i8 +// CHECK-NEXT: linalg.yield %[[RES]] : i8 + +// ----- + +func @pooling_nhwc_i16_max(%input: memref, %fake: memref<2x3xi16>, %init: memref) { + linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>} + ins(%input, %fake: memref, memref<2x3xi16>) + outs(%init: memref) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_i16_max + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref, memref<2x3xi16>) +// CHECK-SAME: outs(%{{.+}} : memref) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i16, %[[BBARG1:.+]]: i16, %[[BBARG2:.+]]: i16) +// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i16 +// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i16 +// CHECK-NEXT: linalg.yield %[[RES]] : i16 + +// ----- + +func @pooling_nhwc_i32_max(%input: memref, %fake: memref<2x3xi32>, %init: memref) { + linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>} + ins(%input, %fake: memref, memref<2x3xi32>) + outs(%init: memref) + return +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)> + +// CHECK: func @pooling_nhwc_i32_max + +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref, memref<2x3xi32>) +// CHECK-SAME: outs(%{{.+}} : memref) + +// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32) +// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i32 +// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i32 +// CHECK-NEXT: linalg.yield %[[RES]] : i32 + +// ----- + func @pooling_nhwc_min(%input: memref, %fake: memref<2x3xf32>, %init: memref) { linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%input, %fake: memref, memref<2x3xf32>) diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -344,6 +344,109 @@ return } +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i8_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_i8_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, 
tensor<3x3xi8>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> +func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8> + %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8> + %cst = constant 0 : i8 + %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi8>, i8 -> tensor<1x2x2x1xi8> + %res = linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>) + outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8> + return %res : tensor<1x2x2x1xi8> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i8_max +// CHECK: linalg.pooling_nhwc_i8_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi8>, memref<3x3xi8>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi8>) +func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8>, %output: memref<1x2x2x1xi8>) { + linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x4x4x1xi8>, memref<3x3xi8>) + outs(%output: memref<1x2x2x1xi8>) + return +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i16_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_i16_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> +func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16> + %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16> + %cst = constant 0 : i16 + %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi16>, i16 -> tensor<1x2x2x1xi16> + %res = linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>) + outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16> + return %res : tensor<1x2x2x1xi16> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i16_max +// CHECK: linalg.pooling_nhwc_i16_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi16>, memref<3x3xi16>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi16>) +func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi16>, %output: memref<1x2x2x1xi16>) { + linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x4x4x1xi16>, memref<3x3xi16>) + outs(%output: memref<1x2x2x1xi16>) + return +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i32_max_tensor +// CHECK: %{{.+}} = linalg.pooling_nhwc_i32_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>) +// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> +func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> { + %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32> + %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32> + %cst = constant 0 : i32 + %fill = linalg.fill(%init, %cst) : 
tensor<1x2x2x1xi32>, i32 -> tensor<1x2x2x1xi32> + %res = linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>) + outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32> + return %res : tensor<1x2x2x1xi32> +} + +// ----- + +// CHECK-LABEL: func @pooling_nhwc_i32_max +// CHECK: linalg.pooling_nhwc_i32_max +// CHECK-SAME: dilations = dense<1> : tensor<2xi64> +// CHECK-SAME: strides = dense<1> : tensor<2xi64> +// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi32>, memref<3x3xi32>) +// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi32>) +func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi32>, %output: memref<1x2x2x1xi32>) { + linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} + ins(%input, %fake: memref<1x4x4x1xi32>, memref<3x3xi32>) + outs(%output: memref<1x2x2x1xi32>) + return +} + + // ----- // CHECK-LABEL: func @pooling_nhwc_min_tensor