diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -173,7 +173,7 @@ /// } /// ``` /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. -static FailureOr> +static SmallVector yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, ValueRange yieldedValues, ArrayRef> tileOffsetsList, @@ -245,6 +245,27 @@ } } +/// Helper method to yield the values of the tiled op, as well as +/// update the destination operands of the tiled op, if it is +/// a destination passing style op. +static SmallVector +yieldTiledValues(RewriterBase &rewriter, ArrayRef initValues, + Operation *tiledOp, + ArrayRef> tileOffsetsList, + ArrayRef> tileSizesList, + MutableArrayRef loops) { + SmallVector replacements = + yieldTiledValues(rewriter, initValues, tiledOp->getResults(), + tileOffsetsList, tileSizesList, loops); + if (auto dstOp = dyn_cast(tiledOp)) { + auto innerMostLoop = loops.back(); + SmallVector tiledOpDestinationTensors = dstOp.getDpsInitOperands(); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, + innerMostLoop.getRegionIterArgs()); + } + return replacements; +} + /// Implementation of tiling transformation of `op` that implements the /// `TilingInterface` using `scf.for` to iterate over the tiles. FailureOr @@ -258,12 +279,6 @@ op, "missing tile size computation function"); } - // Get destination tensors. - SmallVector destinationTensors; - if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, - destinationTensors))) - return rewriter.notifyMatchFailure(op, "failed to get destinations"); - // 1. Get the range of the loops that are represented by the operation. SmallVector iterationDomain = op.getIterationDomain(rewriter); size_t numLoops = iterationDomain.size(); @@ -362,24 +377,14 @@ } } - FailureOr> replacementOr = yieldTiledValues( - rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(), - resultOffsetsList, resultSizesList, tilingResult.loops); - if (failed(replacementOr)) - return rewriter.notifyMatchFailure(op, "failed to yield replacement"); - - if (auto dstOp = - dyn_cast(tilingResult.tiledOps.back())) { - auto innerMostLoop = tilingResult.loops.back(); - SmallVector destinationTensors = dstOp.getDpsInitOperands(); - assert(destinationTensors.size() == - innerMostLoop.getRegionIterArgs().size() && - "unexpected number of outputs"); - updateDestinationOperandsForTiledOp(rewriter, destinationTensors, - innerMostLoop.getRegionIterArgs()); - } + SmallVector destinationTensors; + if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, + destinationTensors))) + return rewriter.notifyMatchFailure(op, "failed to get destinations"); - tilingResult.replacements = *replacementOr; + tilingResult.replacements = yieldTiledValues( + rewriter, destinationTensors, tilingResult.tiledOps.back(), + resultOffsetsList, resultSizesList, tilingResult.loops); LLVM_DEBUG({ if (!tilingResult.loops.empty()) { @@ -449,11 +454,9 @@ resultSizesList.push_back( b.createOrFold(loc, parallelOp->getResult(0), i)); SmallVector outOffsets(offsets.size(), b.getIndexAttr(0)); - FailureOr> replacementOr = yieldTiledValues( + SmallVector replacements = yieldTiledValues( b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, resultSizesList, loops); - if (failed(replacementOr)) - return b.notifyMatchFailure(op, "failed to yield replacement"); auto dstOp = cast(parallelOp); auto innerMostLoop = loops.back(); @@ -466,7 +469,7 @@ // 4. Apply the merge reduction to combine all the partial values. b.setInsertionPointAfter(*loops.begin()); - Operation *mergeOp = op.mergeReductions(b, loc, *replacementOr, reductionDim); + Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); b.replaceOp(op, mergeOp->getResults()); SCFReductionTilingResult results;