diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -366,10 +366,7 @@
 /// shaped results. If padding fails, return failure.
 FailureOr<SmallVector<Value>>
 rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                  ArrayRef<int64_t> paddingDimensions,
-                  ArrayRef<int64_t> padToMultipleOf,
-                  ArrayRef<Attribute> paddingValues,
-                  ArrayRef<bool> packPaddings, LinalgOp &paddedOp);
+                  const LinalgPaddingOptions &options, LinalgOp &paddedOp);
 
 namespace detail {
 
@@ -455,7 +452,7 @@
 /// specified in `options`.
 FailureOr<LinalgOp>
 padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
-                    LinalgPaddingOptions options);
+                    const LinalgPaddingOptions &options);
 
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1583,15 +1583,18 @@
     transposePaddings.push_back(
         extractFromI64ArrayAttr(cast<ArrayAttr>(transposeVector)));
 
+  // Set up options and pad.
   LinalgOp paddedOp;
-  SmallVector<int64_t> paddingDimensions =
-      extractFromI64ArrayAttr(getPaddingDimensions());
-  SmallVector<int64_t> padToMultipleOf(paddingDimensions.size(), 1);
+  LinalgPaddingOptions options;
+  options.paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions());
+  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
   if (getPadToMultipleOf().has_value())
     padToMultipleOf = extractFromI64ArrayAttr(*getPadToMultipleOf());
+  options.padToMultipleOf = padToMultipleOf;
+  options.paddingValues = paddingValues;
+  options.packPaddings = packPaddings;
   FailureOr<SmallVector<Value>> result =
-      rewriteAsPaddedOp(rewriter, target, paddingDimensions, padToMultipleOf,
-                        paddingValues, packPaddings, paddedOp);
+      rewriteAsPaddedOp(rewriter, target, options, paddedOp);
   if (succeeded(result)) {
     // We need to perform our own replacement here because this API is still
    // used in patterns that "pad and hoist", for which the replacement values
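
For orientation, the net effect of the hunks above is that callers now populate a single LinalgPaddingOptions instead of threading four parallel lists. A minimal sketch of a direct caller after this change; `rewriter`, `target`, and `f32` are assumed to be in scope, and the concrete option values are illustrative only, not part of the patch:

    LinalgPaddingOptions options;
    options.paddingDimensions = {0, 1};                     // iteration dims to pad
    options.padToMultipleOf = SmallVector<int64_t>{16, 16}; // optional; defaults to 1
    options.paddingValues =
        SmallVector<Attribute>(3, rewriter.getZeroAttr(f32)); // one per operand
    options.packPaddings = {true, true, false};             // per-operand nofold flags
    LinalgOp paddedOp;
    FailureOr<SmallVector<Value>> replacements =
        rewriteAsPaddedOp(rewriter, target, options, paddedOp);
    if (succeeded(replacements))
      rewriter.replaceOp(target, *replacements); // replacement stays manual here
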
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -20,37 +20,30 @@
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 #define DBGSNL() (llvm::dbgs() << "\n")
 
-/// Pad the `opOperand` in the `paddingDimensions` using the padding value and
-/// the nofold flag found in `paddingValues` and `packPaddings`, respectively.
-///
-/// Exit early and return the `opOperand` value if it already has the requested
-/// shape. I.e.:
-///   - static shape
-///   - nofold is not set
-///   - dim sizes are multiples of `padToMultipleOf`
-///
-/// Otherwise, try to pad the shape dimensions that match the iterator
-/// dimensions `paddingDimensions` and return the tensor::PadOp result if
-/// padding succeeds or failure otherwise.
-static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
-    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
-    ArrayRef<int64_t> paddingDimensions, ArrayRef<int64_t> padToMultipleOf,
-    ArrayRef<Attribute> paddingValues, ArrayRef<bool> packPaddings) {
-  assert(padToMultipleOf.size() == paddingDimensions.size() &&
-         "invalid number of elements in padToMultipleOf");
-
+/// Compute the padded shape of the given operand. The operand is padded to a
+/// static bounding box according to the specified options.
+static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
+                                        OpOperand *opOperand,
+                                        const LinalgPaddingOptions &options,
+                                        SmallVector<int64_t> &paddedShape,
+                                        bool &alreadyHasRequestedShape) {
   AffineMap indexingMap = opToPad.getMatchingIndexingMap(opOperand);
   ArrayRef<int64_t> shape = opToPad.getShape(opOperand);
 
   // Collect the shape dimensions that are a function of "paddingDimensions",
   // along with the multiple that they should be padded to ("1" if none).
-  bool alreadyHasRequestedShape = true;
+  alreadyHasRequestedShape = true;
   DenseMap<int64_t, int64_t> shapeDimToMultiple;
-  for (const auto &dimEn : enumerate(paddingDimensions)) {
+  for (const auto &dimEn : enumerate(options.paddingDimensions)) {
     for (const auto &en : enumerate(indexingMap.getResults())) {
       if (en.value().isFunctionOfDim(dimEn.value())) {
         int64_t dimSize = shape[en.index()];
-        shapeDimToMultiple[en.index()] = padToMultipleOf[dimEn.index()];
+        if (options.padToMultipleOf.has_value()) {
+          shapeDimToMultiple[en.index()] =
+              (*options.padToMultipleOf)[dimEn.index()];
+        } else {
+          shapeDimToMultiple[en.index()] = 1;
+        }
         if (ShapedType::isDynamic(dimSize)) {
           alreadyHasRequestedShape = false;
         } else if (dimSize % shapeDimToMultiple[en.index()] != 0) {
@@ -60,29 +53,13 @@
     }
   }
 
-  // Return the unpadded operand if padding to a static shape is not needed and
-  // if the nofold flag is not set.
-  bool nofold = opOperand->getOperandNumber() < packPaddings.size()
-                    ? packPaddings[opOperand->getOperandNumber()]
-                    : false;
-  if (!nofold && alreadyHasRequestedShape)
-    return opOperand->get();
-
-  // Fail if `paddingValues` specifies no padding value.
-  if (opOperand->getOperandNumber() >= paddingValues.size()) {
-    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
-  }
-  Attribute paddingAttr = paddingValues[opOperand->getOperandNumber()];
-  Value paddingValue = rewriter.create<arith::ConstantOp>(
-      opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
-
   // Helper function to round a number up to a given multiple.
   auto ceil = [](int64_t val, int64_t multiple) {
     return ((val + multiple - 1) / multiple) * multiple;
   };
 
   // Upper bound the sizes to obtain a static bounding box.
-  SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
+  paddedShape.assign(shape.begin(), shape.end());
   for (int64_t i = 0, e = shape.size(); i < e; ++i) {
     LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
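
The `ceil` lambda retained above rounds an upper bound up to the requested multiple with pure integer arithmetic. A standalone check of the identity it relies on (values chosen arbitrarily):

    #include <cassert>
    #include <cstdint>

    // Same formula as the `ceil` lambda: round `val` up to a multiple.
    static int64_t roundUp(int64_t val, int64_t multiple) {
      return ((val + multiple - 1) / multiple) * multiple;
    }

    int main() {
      assert(roundUp(13, 8) == 16); // 13 rounds up to the next multiple of 8
      assert(roundUp(16, 8) == 16); // exact multiples are left unchanged
      assert(roundUp(1, 5) == 5);
      return 0;
    }
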
@@ -97,13 +74,58 @@
         /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "----count not compute a bounding box for padding");
-      return rewriter.notifyMatchFailure(
-          opToPad, "count not compute a bounding box for padding");
+      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
+      return failure();
     }
     paddedShape[i] = ceil(*upperBound, shapeDimToMultiple[i]);
     LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
   }
 
+  return success();
+}
+
+/// Pad the `opOperand` in the "paddingDimensions" using the padding value and
+/// the nofold flag found in "paddingValues" and "packPaddings", respectively.
+///
+/// Exit early and return the `opOperand` value if it already has the requested
+/// shape. I.e.:
+///   - static shape
+///   - nofold is not set
+///   - dim sizes are multiples of "padToMultipleOf"
+///
+/// Otherwise, try to pad the shape dimensions that match the iterator
+/// dimensions "paddingDimensions" and return the tensor::PadOp result if
+/// padding succeeds or failure otherwise.
+static FailureOr<Value> padOperandToSmallestStaticBoundingBox(
+    RewriterBase &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand,
+    const LinalgPaddingOptions &options) {
+  assert((!options.padToMultipleOf.has_value() ||
+          options.padToMultipleOf->size() ==
+              options.paddingDimensions.size()) &&
+         "invalid number of elements in padToMultipleOf");
+
+  // Compute padded shape.
+  SmallVector<int64_t> paddedShape;
+  bool alreadyHasRequestedShape = false;
+  if (failed(computePaddedShape(opToPad, opOperand, options, paddedShape,
+                                alreadyHasRequestedShape)))
+    return rewriter.notifyMatchFailure(opToPad,
+                                       "--failed to compute padded shape");
+
+  // Return the unpadded operand if padding to a static shape is not needed and
+  // if the nofold flag is not set.
+  bool nofold = opOperand->getOperandNumber() < options.packPaddings.size()
+                    ? options.packPaddings[opOperand->getOperandNumber()]
+                    : false;
+  if (!nofold && alreadyHasRequestedShape)
+    return opOperand->get();
+
+  // Fail if `paddingValues` specifies no padding value.
+  if (opOperand->getOperandNumber() >= options.paddingValues.size()) {
+    return rewriter.notifyMatchFailure(opToPad, "--no padding value specified");
+  }
+  Attribute paddingAttr = options.paddingValues[opOperand->getOperandNumber()];
+  Value paddingValue = rewriter.create<arith::ConstantOp>(
+      opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
+
   // Pad the operand to the bounding box defined by `paddedShape`.
   auto paddedTensorType = RankedTensorType::get(
       paddedShape, getElementTypeOrSelf(opOperand->get()));
@@ -115,10 +137,8 @@
 FailureOr<SmallVector<Value>>
 linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
-                          ArrayRef<int64_t> paddingDimensions,
-                          ArrayRef<int64_t> padToMultipleOf,
-                          ArrayRef<Attribute> paddingValues,
-                          ArrayRef<bool> packPaddings, LinalgOp &paddedOp) {
+                          const LinalgPaddingOptions &options,
+                          LinalgOp &paddedOp) {
   LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
   Location loc = opToPad->getLoc();
 
@@ -136,8 +156,7 @@
   newOperands.reserve(opToPad->getNumOperands());
   for (OpOperand &opOperand : opToPad->getOpOperands()) {
     FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox(
-        rewriter, opToPad, &opOperand, paddingDimensions, padToMultipleOf,
-        paddingValues, packPaddings);
+        rewriter, opToPad, &opOperand, options);
     // Exit if `paddingDimensions` cannot be bounded statically.
     if (failed(paddedOperand)) {
       LLVM_DEBUG(DBGS() << "--operand cannot be bound statically : "
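
One behavior worth calling out in padOperandToSmallestStaticBoundingBox: `paddingValues` and `packPaddings` are indexed by operand number, and a `packPaddings` list shorter than the operand count silently defaults the trailing operands to foldable. A standalone restatement of that lookup (the helper name is mine, not from the patch):

    #include <cstdio>
    #include <vector>

    // Mirrors the nofold lookup above: entries beyond the provided list
    // default to false, i.e. padding of that operand may be folded away.
    static bool getNofold(const std::vector<bool> &packPaddings,
                          unsigned operandNumber) {
      return operandNumber < packPaddings.size() ? packPaddings[operandNumber]
                                                 : false;
    }

    int main() {
      std::vector<bool> packPaddings = {true, false}; // only two entries given
      std::printf("%d %d %d\n", getNofold(packPaddings, 0),
                  getNofold(packPaddings, 1),
                  getNofold(packPaddings, 2)); // prints: 1 0 0
      return 0;
    }
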
@@ -183,20 +202,15 @@
 FailureOr<LinalgOp>
 mlir::linalg::padAndHoistLinalgOp(RewriterBase &rewriter, LinalgOp linalgOp,
-                                  LinalgPaddingOptions options) {
+                                  const LinalgPaddingOptions &options) {
   if (!linalgOp.hasTensorSemantics())
     return rewriter.notifyMatchFailure(
         linalgOp, "only applies to Linalg ops with tensor semantics");
 
   // Pad the operation.
   LinalgOp paddedOp;
-  SmallVector<int64_t> padToMultipleOf(options.paddingDimensions.size(), 1);
-  if (options.padToMultipleOf.has_value())
-    padToMultipleOf.assign(options.padToMultipleOf->begin(),
-                           options.padToMultipleOf->end());
-  FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp(
-      rewriter, linalgOp, options.paddingDimensions, padToMultipleOf,
-      options.paddingValues, options.packPaddings, paddedOp);
+  FailureOr<SmallVector<Value>> newResults =
+      rewriteAsPaddedOp(rewriter, linalgOp, options, paddedOp);
   if (failed(newResults))
     return rewriter.notifyMatchFailure(linalgOp,
                                        "failed to rewrite as a padded op");
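
The lines deleted above used to expand an unset `options.padToMultipleOf` into a vector of ones before calling rewriteAsPaddedOp; after this change, computePaddedShape performs the same defaulting internally (the `has_value()` branch in the first Padding.cpp hunk), so behavior is preserved. A standalone check of that equivalence using std::optional (a sketch of the defaulting rule, not MLIR code):

    #include <cassert>
    #include <cstdint>
    #include <optional>
    #include <vector>

    int main() {
      // Unset optional, as when a caller never sets padToMultipleOf.
      std::optional<std::vector<int64_t>> padToMultipleOf;
      std::vector<int64_t> expanded(3, 1); // what the removed code built
      for (size_t i = 0; i < 3; ++i) {
        // The new in-place defaulting...
        int64_t multiple = padToMultipleOf ? (*padToMultipleOf)[i] : 1;
        assert(multiple == expanded[i]); // ...matches the old expansion.
      }
      return 0;
    }
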