diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -438,10 +438,11 @@ using TileSizeComputationFunction = std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>; -/// Specify the padding value for an OpOperand. This should be a function of -/// both the operation and the operand type. +/// Callback returning the padding value to use for a given OpOperand or failure +/// for no padding. This should be a function of both the operation and the +/// operand type. using PaddingValueComputationFunction = - std::function<Value(OpBuilder &, OpOperand &)>; + std::function<FailureOr<Value>(OpBuilder &, OpOperand &)>; struct LinalgTilingOptions { /// Computation function that returns the tile sizes for each operation. @@ -504,10 +505,11 @@ return *this; } - /// Computation function that returns a padding value to use when padding to - /// force static sizes. When `paddingValueComputationFunction` is set, padding - /// operations are introduced, that guarantee the underlying op is statically - /// shaped and can thus be vectorized. + /// Callback returning the padding value to use for a given OpOperand or + /// failure for no padding. Padding operations are introduced if + /// `paddingValueComputationFunction` is set and does not return failure. + /// Padding all operands guarantees the operation is statically shaped and + /// thus can be vectorized. PaddingValueComputationFunction paddingValueComputationFunction = nullptr; LinalgTilingOptions & diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -145,16 +145,21 @@ return *this; } -/// Try to compute a static bounding box for `operand`. The padding happens -/// even if the operand already has static shape. 
`result` is the result of a -/// freshly created PadTensorOp. Return failure if the operand cannot be padded -/// to a static shape. +/// Helper function that tries to pad `opOperand`. Exit early and return success +/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to +/// pad the operand even if it already has a static shape. Set `result` to the +/// result of the created PadTensorOp or return failure if the operand cannot be +/// padded to a static shape. static LogicalResult padOperandToSmallestStaticBoundingBox( PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, const PaddingValueComputationFunction &paddingFunc, Value &result) { // Can't pad scalars. if (opToPad.getShape(opOperand).empty()) return success(); + // Can't pad if no padding value is known. + FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand); + if (failed(paddingValue)) + return success(); auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); // Not a slice op, cannot construct a static bounding box. 
if (!sliceOp) @@ -173,12 +178,11 @@ opToPad, "No constant bounding box can be found for padding"); staticSizes.push_back(indexAttr.getInt()); } - Value pad = paddingFunc(rewriter, *opOperand); auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); result = linalg::PadTensorOp::createPadHighOp( - staticTensorType, opOperand->get(), pad, /*packing=*/true, - opToPad->getLoc(), rewriter); + staticTensorType, opOperand->get(), paddingValue.getValue(), + /*packing=*/true, opToPad->getLoc(), rewriter); return success(); } diff --git a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-pad-tensors.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3,4" -canonicalize | FileCheck %s -// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern pad-tiles tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1,2 tile-sizes=2,3,4" -canonicalize | FileCheck %s +// RUN: mlir-opt %s -test-linalg-transform-patterns="test-tile-pattern padded-operands=0,1 tile-sizes=2,3" -canonicalize | FileCheck %s -check-prefix=CHECK-1DIM-TILE // CHECK-LABEL: func @matmul_tensors( // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor<?x?xi8> @@ -97,6 +97,7 @@ // CHECK: linalg.matmul_i8_i8_i32 ins({{.*}}, {{.*}} : tensor<2x4xi8>, tensor<4x3xi8>) outs({{.*}} : tensor<2x3xi32>) -> tensor<2x3xi32> +// Check only the input operands are padded. 
// CHECK-1DIM-TILE: func @matmul_partially_padded_tensors( // CHECK-1DIM-TILE-SAME: %[[TA:[0-9a-z]+]]: tensor<?x8xi8> // CHECK-1DIM-TILE-SAME: %[[TB:[0-9a-z]+]]: tensor<8x?xi8> @@ -111,10 +112,8 @@ // CHECK-1DIM-TILE: : tensor<?x8xi8> to tensor<2x8xi8> // CHECK-1DIM-TILE: %[[pB:.*]] = linalg.pad_tensor %[[sTB]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] // CHECK-1DIM-TILE: : tensor<8x?xi8> to tensor<8x3xi8> -// CHECK-1DIM-TILE: %[[pC:.*]] = linalg.pad_tensor %[[sTC]] packing low[%[[C0]], %[[C0]]] high[%{{.*}}, %{{.*}}] -// CHECK-1DIM-TILE: : tensor<?x?xi32> to tensor<2x3xi32> -// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) -// CHECK-1DIM-TILE: outs(%[[pC]] : tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK-1DIM-TILE: %[[pD:.*]] = linalg.matmul_i8_i8_i32 ins(%[[pA]], %[[pB]] : tensor<2x8xi8>, tensor<8x3xi8>) +// CHECK-1DIM-TILE: outs(%[[sTC]] : tensor<?x?xi32>) -> tensor<?x?xi32> // Check that the tile-and-pad transformation actually introduces the padding // as requested, even if original operation already operates on static diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -108,9 +108,10 @@ llvm::cl::desc("Test rewrite of subtensor(pad_tensor) into " "pad_tensor(subtensor)"), llvm::cl::init(false)}; - Option<bool> padTiles{*this, "pad-tiles", - llvm::cl::desc("Pad tiles when test-tile-pattern"), - llvm::cl::init(false)}; + ListOption<int64_t> paddedOperands{ + *this, "padded-operands", + llvm::cl::desc("Operands to pad when test-tile-pattern"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; ListOption<int64_t> peeledLoops{ *this, "peeled-loops", llvm::cl::desc("Loops to be peeled when test-tile-pattern"), @@ -576,7 +577,8 @@ static void applyTilePattern(FuncOp funcOp, std::string loopType, - ArrayRef<int64_t> tileSizes, bool padTiles, + ArrayRef<int64_t> tileSizes, + 
ArrayRef<int64_t> paddedOperands, ArrayRef<int64_t> peeledLoops, bool scalarizeDynamicDims) { MLIRContext *context = funcOp.getContext(); @@ -597,10 +599,15 @@ } else { linalgTilingOptions.setTileSizes(tileSizes); } - if (padTiles) - linalgTilingOptions.setPaddingValueComputationFunction( - getNeutralOfLinalgOp); - + if (!paddedOperands.empty()) { + auto paddingFunc = [&](OpBuilder &b, + OpOperand &opOperand) -> FailureOr<Value> { + if (llvm::count(paddedOperands, opOperand.getOperandNumber()) == 0) + return failure(); + return getNeutralOfLinalgOp(b, opOperand); + }; + linalgTilingOptions.setPaddingValueComputationFunction(paddingFunc); + } tilingPattern.add<linalg::LinalgTilingPattern<linalg::MatmulOp>, linalg::LinalgTilingPattern<linalg::MatmulI8I8I32Op>, linalg::LinalgTilingPattern<linalg::GenericOp>>( @@ -734,10 +741,10 @@ return applyTiledLoopPeelingPattern(getFunction(), testTiledLoopPeeling, skipPartial); if (testTilePattern) - return applyTilePattern(getFunction(), loopType, tileSizes, padTiles, + return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, peeledLoops, /*scalarizeDynamicDims=*/false); if (testTileScalarizeDynamicDims) - return applyTilePattern(getFunction(), loopType, tileSizes, padTiles, + return applyTilePattern(getFunction(), loopType, tileSizes, paddedOperands, /*peeledLoops=*/{}, /*scalarizeDynamicDims=*/true); if (testHoistPadding) { getFunction().walk([&](linalg::PadTensorOp padTensorOp) {