diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -185,6 +185,13 @@ newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); } + SmallVector> reifiedResultShapes; + if (failed(cast(opToPad.getOperation()) + .reifyResultShapes(rewriter, reifiedResultShapes))) + return failure(); + if (reifiedResultShapes.size() != opToPad->getNumResults()) + return failure(); + // Clone `opToPad` to operate on the statically padded shapes. auto resultTensorTypes = ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes(); @@ -192,28 +199,22 @@ // Recover the slice out of the new static results. This keeps the original // linalg op around because it uses the dims of the original results. - // This later folds away. SmallVector paddedSubviewResults; paddedSubviewResults.reserve(opToPad->getNumResults()); SetVector newUsersOfOpToPad; - for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) { - auto rank = std::get<0>(it).getType().cast().getRank(); + for (auto en : llvm::enumerate(paddedOp->getResults())) { + Value paddedResult = en.value(); + int64_t resultNumber = en.index(); + int64_t rank = paddedResult.getType().cast().getRank(); SmallVector offsets(rank, rewriter.getIndexAttr(0)); - auto sizes = llvm::to_vector<4>(llvm::map_range( - llvm::seq(0, rank), [&](unsigned d) -> OpFoldResult { - auto dimOp = rewriter.create(loc, std::get<0>(it), d); - newUsersOfOpToPad.insert(dimOp); - return dimOp.getResult(); - })); + SmallVector sizes; + for (Value v : reifiedResultShapes[resultNumber]) + sizes.push_back(v); SmallVector strides(rank, rewriter.getIndexAttr(1)); paddedSubviewResults.push_back(rewriter.create( - loc, std::get<1>(it), offsets, sizes, strides)); + loc, paddedResult, offsets, sizes, strides)); } - // Replace the transient `opToPad` locally, except for uses that we just - // created for the purpose of extracting the dims. - rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) { - return !newUsersOfOpToPad.contains(opOp.getOwner()); - }); + rewriter.replaceOp(opToPad, paddedSubviewResults); return success(); } @@ -244,14 +245,16 @@ return failure(); // Setup RAII guard to return properly. + LinalgOp paddedOp; LinalgOp tiledOp = res->op; auto guard = llvm::make_scope_exit([&]() { // Return relevant information to derived pattern. result = *res; - // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary. - filter.replaceLinalgTransformationFilter(rewriter, tiledOp); - if (tiledOp != res->op) - filter.replaceLinalgTransformationFilter(rewriter, res->op); + // Update filter. + if (paddedOp) + filter.replaceLinalgTransformationFilter(rewriter, paddedOp); + else + filter.replaceLinalgTransformationFilter(rewriter, tiledOp); }); // Consider padding on the fly only if the op has tensor semantics. @@ -261,7 +264,6 @@ // Try to pad on the fly by rewriting res->op as a padded op. If successful, // `res.op` is rewritten in static form with padded operands. - LinalgOp paddedOp; if (succeeded(rewriteAsPaddedOp(rewriter, res->op, options.paddingValueComputationFunction, paddedOp))) { diff --git a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir --- a/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir +++ b/mlir/test/Dialect/Linalg/tile-pad-tensor-op.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -resolve-shaped-type-result-dims -cse -split-input-file | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3" -cse -split-input-file | \ // RUN: FileCheck %s -check-prefix=TILE2 // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,3" -resolve-shaped-type-result-dims -cse -split-input-file | \ // RUN: FileCheck %s -check-prefix=TILE1