diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -886,6 +886,13 @@ PatternRewriter &rewriter) const override; }; +/// Try to create a static bounding box around each operand of `opToPad`. +/// If successful, `paddedOp` will be updated to the cloned static form. +LogicalResult +rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, + const PaddingValueComputationFunction &paddingFunc, + LinalgOp &paddedOp); + using OptimizeCopyFn = std::function; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -126,7 +126,7 @@ /// Return failure if the operand cannot be padded to a static shape. static LogicalResult padOperandToSmallestStaticBoundingBox( PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, - const LinalgTilingOptions &options, Value &result) { + const PaddingValueComputationFunction &paddingFunc, Value &result) { // Already static shape, no need to pad. if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic)) return success(); @@ -148,7 +148,7 @@ opToPad, "No constant bounding box can be found for padding"); staticSizes.push_back(indexAttr.getInt()); } - Value pad = options.paddingValueComputationFunction(rewriter, *opOperand); + Value pad = paddingFunc(rewriter, *opOperand); auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); result = linalg::PadTensorOp::createPadHighOp( @@ -156,13 +156,10 @@ return success(); } -// Try to create a static bounding box around each operand of `res.op`. -// If successful, `res.op` is rewritten in static form with padded operands. -// `res.op` is updated to the cloned static form of the op on success. -static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter, - TiledLinalgOp &res, - const LinalgTilingOptions &options) { - LinalgOp opToPad = res.op; +LogicalResult +linalg::rewriteAsPaddedOp(PatternRewriter &rewriter, LinalgOp opToPad, + const PaddingValueComputationFunction &paddingFunc, + LinalgOp &paddedOp) { Location loc = opToPad->getLoc(); // If the op is fully static, it does not need padding. @@ -183,7 +180,7 @@ // If padding was requested but the shape cannot be bounded statically then // the pattern fails to apply. if (failed(padOperandToSmallestStaticBoundingBox( - rewriter, opToPad, opOperand, options, paddedOperand))) + rewriter, opToPad, opOperand, paddingFunc, paddedOperand))) return failure(); newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); } @@ -191,8 +188,7 @@ // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); - linalg::LinalgOp paddedOp = - opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); + paddedOp = opToPad.clone(rewriter, loc, resultTensorTypes, newOperands); // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results. @@ -218,8 +214,6 @@ rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) { return !newUsersOfOpToPad.contains(opOp.getOwner()); }); - - res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults}; return success(); } @@ -265,15 +259,19 @@ !linalgOp.hasTensorSemantics()) return success(); - // Try to pad on the fly by rewriting res->op as a padded op. - if (failed(rewriteAsPaddedOp(rewriter, *res, options))) { - // Set so RAII guard does not propagate TiledLinalgOp to `result`. - return failure(); + // Try to pad on the fly by rewriting res->op as a padded op. If successful, + // `res.op` is rewritten in static form with padded operands. + LinalgOp paddedOp; + if (succeeded(rewriteAsPaddedOp(rewriter, res->op, + options.paddingValueComputationFunction, + paddedOp))) { + res->op = paddedOp; + // Do not perform replacement of `linalgOp`, let the derived patterns + // do this as they see fit, from the resulting TiledLinalgOp. + return success(); } - - // Do not perform replacement of `linalgOp`, let the derived patterns - // do this as they see fit, from the resulting TiledLinalgOp. - return success(); + // Set so RAII guard does not propagate TiledLinalgOp to `result`. + return failure(); } static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) {