diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1085,10 +1085,13 @@ PatternRewriter &rewriter) const override; }; -/// Try to create a static bounding box around each operand of `opToPad`. -/// If successful, `paddedOp` will be updated to the cloned static form. -LogicalResult -rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, +/// Pad the operands of `opToPad` to a static bounding box. Use `paddingFunc` +/// and `nofoldFunc` to set the padding value and the nofold attribute of the +/// introduced PadTensorOps, respectively. Update `paddedOp` to the cloned +/// statically shaped operation and return the extracted dynamically shaped +/// results. If padding fails, return failure. +FailureOr<SmallVector<Value>> +rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, LinalgOp &paddedOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -152,14 +152,14 @@ /// result of the created PadTensorOp or return failure if the operand cannot be /// padded to a static shape. static LogicalResult padOperandToSmallestStaticBoundingBox( - PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, + OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) { // Can't pad scalars. if (opToPad.getShape(opOperand).empty()) return success(); // Can't pad if no padding value is known. 
- FailureOr<Value> paddingValue = paddingFunc(rewriter, *opOperand); + FailureOr<Value> paddingValue = paddingFunc(b, *opOperand); if (failed(paddingValue)) return success(); auto sliceOp = opOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); @@ -175,9 +175,10 @@ : linalg::getSmallestBoundingIndex(size.get()); // SmallestBoundingIndex must exist for all sizes. // For now return an error if we can't find it. - if (!indexAttr) - return rewriter.notifyMatchFailure( - opToPad, "No constant bounding box can be found for padding"); + if (!indexAttr) { + LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding"); + return failure(); + } staticSizes.push_back(indexAttr.getInt()); } auto staticTensorType = RankedTensorType::get( @@ -185,12 +186,12 @@ bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false; result = linalg::PadTensorOp::createPadHighOp( staticTensorType, opOperand->get(), paddingValue.getValue(), - /*nofold=*/nofold, opToPad->getLoc(), rewriter); + /*nofold=*/nofold, opToPad->getLoc(), b); return success(); } -LogicalResult -linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, +FailureOr<SmallVector<Value>> +linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, LinalgOp &paddedOp) { @@ -200,9 +201,9 @@ assert(opToPad.hasTensorSemantics() && "expected operation to have tensor semantics"); - OpBuilder::InsertionGuard g(rewriter); + OpBuilder::InsertionGuard g(b); // Set IP after op because we also take the dims of the original output. - rewriter.setInsertionPointAfter(opToPad); + b.setInsertionPointAfter(opToPad); // Make a copy of the shaped operands and update it. SmallVector<Value> newOperands; newOperands.reserve(opToPad.getNumInputsAndOutputs()); @@ -211,15 +212,14 @@ // If padding was requested but the shape cannot be bounded statically then // the pattern fails to apply. 
if (failed(padOperandToSmallestStaticBoundingBox( - rewriter, opToPad, opOperand, paddingFunc, nofoldFunc, - paddedOperand))) + b, opToPad, opOperand, paddingFunc, nofoldFunc, paddedOperand))) return failure(); newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); } SmallVector<SmallVector<Value>> reifiedResultShapes; if (failed(cast<ReifyRankedShapedTypeOpInterface>(opToPad.getOperation()) - .reifyResultShapes(rewriter, reifiedResultShapes))) + .reifyResultShapes(b, reifiedResultShapes))) return failure(); assert(reifiedResultShapes.size() == opToPad->getNumResults() && "expected same number of results"); @@ -227,7 +227,7 @@ // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); - paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); + paddedOp = opToPad.clone(b, loc, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results. @@ -237,16 +237,15 @@ Value paddedResult = en.value(); int64_t resultNumber = en.index(); int64_t rank = paddedResult.getType().cast<RankedTensorType>().getRank(); - SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); + SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0)); SmallVector<OpFoldResult> sizes; for (Value v : reifiedResultShapes[resultNumber]) sizes.push_back(v); - SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); - paddedSubviewResults.push_back(rewriter.create<tensor::ExtractSliceOp>( + SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1)); + paddedSubviewResults.push_back(b.create<tensor::ExtractSliceOp>( loc, paddedResult, offsets, sizes, strides)); } - rewriter.replaceOp(opToPad, paddedSubviewResults); - return success(); + return paddedSubviewResults; } /// Linalg base tiling pattern. @@ -345,9 +344,11 @@ // Try to pad on the fly by rewriting res->op as a padded op. If successful, // `res.op` is rewritten in static form with padded operands. 
LinalgOp paddedOp; - if (succeeded(rewriteAsPaddedOp( - rewriter, res->op, options.paddingValueComputationFunction, - options.paddingNoFoldComputationFunction, paddedOp))) { + FailureOr<SmallVector<Value>> newResults = rewriteAsPaddedOp( + rewriter, res->op, options.paddingValueComputationFunction, + options.paddingNoFoldComputationFunction, paddedOp); + if (succeeded(newResults)) { + rewriter.replaceOp(res->op, newResults.getValue()); filter.replaceLinalgTransformationFilter(rewriter, paddedOp); res->op = paddedOp; result = *res;