diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1626,18 +1626,19 @@ } }; -class MaxPool2dConverter : public OpRewritePattern { +template +class Pool2dConverter : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, + LogicalResult matchAndRewrite(SrcOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value input = op.input(); ShapedType inputTy = input.getType().cast(); Type inElementTy = inputTy.getElementType(); - ShapedType resultTy = op.getType().cast(); + ShapedType resultTy = op.getType().template cast(); Type outElementTy = inputTy.getElementType(); int64_t rank = inputTy.getRank(); @@ -1646,17 +1647,20 @@ // Determine what the initial value needs to be for the max pool op. Attribute initialAttr; - if (outElementTy.isF32()) + if (isa(op) && outElementTy.isF32()) initialAttr = rewriter.getFloatAttr( outElementTy, APFloat::getLargest( outElementTy.cast().getFloatSemantics(), true)); - if (outElementTy.isa()) + if (isa(op) && outElementTy.isa()) initialAttr = rewriter.getIntegerAttr( outElementTy, APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth())); + if (isa(op) && outElementTy.isa()) + initialAttr = rewriter.getZeroAttr(outElementTy); + if (!initialAttr) return rewriter.notifyMatchFailure( op, "Unsupported initial value for tosa.maxpool_2d op"); @@ -1670,6 +1674,7 @@ Attribute strideAttr = rewriter.getI64VectorAttr(stride); Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1}); + int64_t kernelSize = kernel[0] * kernel[1]; // If non-zero padding we need to pad the input if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) { @@ -1716,34 +1721,46 @@ .getOperation()); }; - if (inElementTy.isF32()) { + if (isa(op) && inElementTy.isF32()) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(8)) { + if (isa(op) && inElementTy.isInteger(8)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(16)) { + if (isa(op) && inElementTy.isInteger(16)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } - if (inElementTy.isInteger(32)) { + if (isa(op) && inElementTy.isInteger(32)) { linalg::LinalgOp poolingOp = createOp(static_cast(nullptr)); rewriter.replaceOp(op, poolingOp->getResult(0)); return success(); } + if (isa(op) && inElementTy.isF32()) { + linalg::LinalgOp poolingOp = + createOp(static_cast(nullptr)); + auto constAttr = DenseElementsAttr::get( + resultTy, static_cast(1.0 / kernelSize)); + auto constant = rewriter.create(loc, constAttr); + auto mul = rewriter.create( + loc, resultTy, poolingOp->getResult(0), constant, 0); + rewriter.replaceOp(op, mul.output()); + return success(); + } + return failure(); } }; @@ -1805,7 +1822,8 @@ TileConverter, TransposeConverter, MatMulConverter, - MaxPool2dConverter, + Pool2dConverter, + Pool2dConverter, FullyConnectedConverter>(patterns->getContext()); - // clang-format on + // clang-format on } diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -923,6 +923,21 @@ %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>) return } +// ----- + +// CHECK-LABEL: @avg_pool +func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () { + // CHECK-DAG: [[CONST:%.+]] = constant 0 + // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62] + // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]]) + // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4] + // CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>) + // CHECK: constant dense<6.250000e-02> + // CHECK: linalg.generic + // CHECK: mulf + %0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x3x31x62xf32>) + return +} // -----