diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -452,6 +452,10 @@ using PaddingValueComputationFunction = std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>; +/// Callback returning true if the pad tensor operation defining the given +/// OpOperand shall be marked as nofold to enable packing. +using PackPaddingComputationFunction = std::function<bool(OpOperand &)>; + struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. /// Delayed construction of constant tile sizes should occur to interoperate @@ -526,6 +530,18 @@ return *this; } + /// Callback returning true if the pad tensor operation defining the given + /// OpOperand shall be marked as nofold to enable packing. A padding operation + /// is only marked nofold if `packPaddingComputationFunction` is set and + /// returns true. Otherwise, the nofold attribute is set to false. + PackPaddingComputationFunction packPaddingComputationFunction = nullptr; + + LinalgTilingOptions & + setPackPaddingComputationFunction(PackPaddingComputationFunction fun) { + packPaddingComputationFunction = std::move(fun); + return *this; + } + /// Peel the specified loops. SmallVector<int64_t> peeledLoops; @@ -999,6 +1015,7 @@ LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, + const PackPaddingComputationFunction &packFunc, LinalgOp &paddedOp); using OptimizeCopyFn = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -153,7 +153,8 @@ /// padded to a static shape. 
static LogicalResult padOperandToSmallestStaticBoundingBox( PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, - const PaddingValueComputationFunction &paddingFunc, Value &result) { + const PaddingValueComputationFunction &paddingFunc, + const PackPaddingComputationFunction &packFunc, Value &result) { // Can't pad scalars. if (opToPad.getShape(opOperand).empty()) return success(); @@ -181,15 +182,17 @@ } auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); + bool packPadding = packFunc ? packFunc(*opOperand) : false; result = linalg::PadTensorOp::createPadHighOp( staticTensorType, opOperand->get(), paddingValue.getValue(), - /*nofold=*/true, opToPad->getLoc(), rewriter); + /*nofold=*/packPadding, opToPad->getLoc(), rewriter); return success(); } LogicalResult linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, + const PackPaddingComputationFunction &packFunc, LinalgOp &paddedOp) { Location loc = opToPad->getLoc(); @@ -207,8 +210,9 @@ Value paddedOperand; // If padding was requested but the shape cannot be bounded statically then // the pattern fails to apply. - if (failed(padOperandToSmallestStaticBoundingBox( - rewriter, opToPad, opOperand, paddingFunc, paddedOperand))) + if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, + opOperand, paddingFunc, + packFunc, paddedOperand))) return failure(); newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); } @@ -341,9 +345,9 @@ // Try to pad on the fly by rewriting res->op as a padded op. If successful, // `res.op` is rewritten in static form with padded operands. 
LinalgOp paddedOp; - if (succeeded(rewriteAsPaddedOp(rewriter, res->op, - options.paddingValueComputationFunction, - paddedOp))) { + if (succeeded(rewriteAsPaddedOp( + rewriter, res->op, options.paddingValueComputationFunction, + options.packPaddingComputationFunction, paddedOp))) { filter.replaceLinalgTransformationFilter(rewriter, paddedOp); res->op = paddedOp; result = *res; diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 packed-operands=0,1 tile-sizes=2,3,4" -canonicalize | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 packed-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE // CHECK-LABEL: func @matmul_tensors( // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8> @@ -24,7 +24,7 @@ // CHECK: : tensor<?x?xi8> to tensor<2x4xi8> // CHECK: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor<?x?xi8> to tensor<4x3xi8> -// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] nofold low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] +// CHECK: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK: : tensor<?x?xi32> to tensor<2x3xi32> // CHECK: %[[pD:.*]] = linalg.matmul ins(%[[pA]], %[[pB]] : tensor<2x4xi8>, tensor<4x3xi8>) // CHECK-SAME: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> diff --git 
a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -113,6 +113,10 @@ *this, "padded-operands", llvm::cl::desc("Operands to pad when test-tile-pattern"), llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + ListOption<int64_t> packedOperands{ + *this, "packed-operands", + llvm::cl::desc("Operands to pack when test-tile-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption<int64_t> peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled when test-tile-pattern"), @@ -581,6 +585,7 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType, ArrayRef<int64_t> tileSizes, ArrayRef<int64_t> paddedOperands, + ArrayRef<int64_t> packedOperands, ArrayRef<int64_t> peeledLoops, bool scalarizeDynamicDims) { MLIRContext *context = funcOp.getContext(); @@ -608,7 +613,13 @@ return failure(); return getNeutralOfLinalgOp(b, opOperand); }; + auto packFunc = [&](OpOperand &opOperand) { + if (llvm::count(packedOperands, opOperand.getOperandNumber()) != 0) + return true; + return false; + }; linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); + linalgTilingOptions.setPackPaddingComputationFunction(packFunc); } tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, linalg::LinalgTilingPattern<linalg::GenericOp>>( @@ -743,9 +754,11 @@ skipPartial); if (testTilePattern) return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, - peeledLoops, /*scalarizeDynamicDims=*/false); + packedOperands, peeledLoops, + /*scalarizeDynamicDims=*/false); if (testTileScalarizeDynamicDims) return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, + packedOperands, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); if (testHoistPadding) { getFunction().walk([&](linalg::PadTensorOp padTensorOp) {