diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -452,6 +452,10 @@ using PaddingValueComputationFunction = std::function<Value(OpBuilder &, OpOperand &)>; +/// Callback returning true if the pad tensor operation defining the given +/// OpOperand shall be marked as nofold to enable packing. +using PaddingNoFoldComputationFunction = std::function<bool(OpOperand &)>; + struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. /// Delayed construction of constant tile sizes should occur to interoperate @@ -526,6 +530,18 @@ return *this; } + /// Callback returning true if the pad tensor operation defining the given + /// OpOperand shall be marked as nofold to enable packing. A padding operation + /// is only marked nofold if `paddingNoFoldComputationFunction` is set and + /// returns true. Otherwise, the nofold attribute is set to false. + PaddingNoFoldComputationFunction paddingNoFoldComputationFunction = nullptr; + + LinalgTilingOptions & + setPaddingNoFoldComputationFunction(PaddingNoFoldComputationFunction fun) { + paddingNoFoldComputationFunction = std::move(fun); + return *this; + } + /// Peel the specified loops. SmallVector<int64_t> peeledLoops; @@ -999,6 +1015,7 @@ LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, + const PaddingNoFoldComputationFunction &nofoldFunc, LinalgOp &paddedOp); using OptimizeCopyFn = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -153,7 +153,8 @@ /// padded to a static shape. 
static LogicalResult padOperandToSmallestStaticBoundingBox( PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, - const PaddingValueComputationFunction &paddingFunc, Value &result) { + const PaddingValueComputationFunction &paddingFunc, + const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) { // Can't pad scalars. if (opToPad.getShape(opOperand).empty()) return success(); @@ -181,15 +182,17 @@ } auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); + bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false; result = linalg::PadTensorOp::createPadHighOp( staticTensorType, opOperand->get(), paddingValue.getValue(), - /*nofold=*/true, opToPad->getLoc(), rewriter); + /*nofold=*/nofold, opToPad->getLoc(), rewriter); return success(); } LogicalResult linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, + const PaddingNoFoldComputationFunction &nofoldFunc, LinalgOp &paddedOp) { Location loc = opToPad->getLoc(); @@ -208,7 +211,8 @@ // If padding was requested but the shape cannot be bounded statically then // the pattern fails to apply. if (failed(padOperandToSmallestStaticBoundingBox( - rewriter, opToPad, opOperand, paddingFunc, paddedOperand))) + rewriter, opToPad, opOperand, paddingFunc, nofoldFunc, + paddedOperand))) return failure(); newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); } @@ -341,9 +345,9 @@ // Try to pad on the fly by rewriting res->op as a padded op. If successful, // `res.op` is rewritten in static form with padded operands. 
LinalgOp paddedOp; - if (succeeded(rewriteAsPaddedOp(rewriter, res->op, - options.paddingValueComputationFunction, - paddedOp))) { + if (succeeded(rewriteAsPaddedOp( + rewriter, res->op, options.paddingValueComputationFunction, + options.paddingNoFoldComputationFunction, paddedOp))) { filter.replaceLinalgTransformationFilter(rewriter, paddedOp); res->op = paddedOp; result = *res; diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 nofold-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 nofold-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE // CHECK-LABEL: func @matmul_tensors( // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8> @@ -24,7 +24,7 @@ // CHECK: : tensor<?x?xi8> to tensor<2x4xi8> // CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor<?x?xi8> to tensor<4x3xi8> -// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] +// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor<?x?xi32> to tensor<2x3xi32> // CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>) // CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> diff --git 
a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -113,6 +113,10 @@ *this, "padded-operands", llvm::cl::desc("Operands to pad when test-tile-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption<int64_t> nofoldOperands{ + *this, "nofold-operands", + llvm::cl::desc("Operands to set nofold when test-tile-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption<int64_t> peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled when test-tile-pattern"), @@ -581,6 +585,7 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType, ArrayRef<int64_t> tileSizes, ArrayRef<int64_t> paddedOperands, + ArrayRef<int64_t> nofoldOperands, ArrayRef<int64_t> peeledLoops, bool scalarizeDynamicDims) { MLIRContext *context = funcOp.getContext(); @@ -608,7 +613,13 @@ return failure(); return getNeutralOfLinalgOp(b, opOperand); }; + auto nofoldFunc = [&](OpOperand &opOperand) { + if (llvm::count(nofoldOperands, opOperand.getOperandNumber()) != 0) + return true; + return false; + }; linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); + linalgTilingOptions.setPaddingNoFoldComputationFunction(nofoldFunc); } tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>>( @@ -743,9 +754,11 @@ skipPartial); if (testTilePattern) return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, - peeledLoops, /*scalarizeDynamicDims=*/false); + nofoldOperands, peeledLoops, + /*scalarizeDynamicDims=*/false); if (testTileScalarizeDynamicDims) return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, + nofoldOperands, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); if (testHoistPadding) { getFunction().walk([&](linalg::PadTensorOp padTensorOp) {