diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -245,6 +245,40 @@
   }
 }
 
+/// Helper method to yield the values of the tiled op, as well as
+/// update the destination operands of the tiled op, if it is
+/// a destination passing style op.
+///
+/// Passes the results of `tiledOp` to the other `yieldTiledValues` overload
+/// and returns the replacement values it produces; if that call fails, a
+/// match failure is reported on `tiledOp`.
+static FailureOr<SmallVector<Value>>
+yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
+                 Operation *tiledOp,
+                 ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+                 ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+                 MutableArrayRef<scf::ForOp> loops) {
+  FailureOr<SmallVector<Value>> replacement =
+      yieldTiledValues(rewriter, initValues, tiledOp->getResults(),
+                       tileOffsetsList, tileSizesList, loops);
+  if (failed(replacement))
+    return rewriter.notifyMatchFailure(tiledOp, "failed to yield replacement");
+
+  if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
+    auto innerMostLoop = loops.back();
+    SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands();
+    // Preserve the check made by the inlined code this helper replaces: the
+    // number of destination operands must equal the number of region
+    // iter_args of the innermost loop.
+    assert(tiledOpDestinationTensors.size() ==
+               innerMostLoop.getRegionIterArgs().size() &&
+           "unexpected number of outputs");
+    updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
+                                        innerMostLoop.getRegionIterArgs());
+  }
+  return replacement.value();
+}
+
 /// Implementation of tiling transformation of `op` that implements the
 /// `TilingInterface` using `scf.for` to iterate over the tiles.
 FailureOr<scf::SCFTilingResult>
@@ -258,12 +292,6 @@
         op, "missing tile size computation function");
   }
 
-  // Get destination tensors.
-  SmallVector<Value> destinationTensors;
-  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
-                                             destinationTensors)))
-    return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
   // 1. Get the range of the loops that are represented by the operation.
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
   size_t numLoops = iterationDomain.size();
@@ -362,22 +390,16 @@
     }
   }
 
+  SmallVector<Value> destinationTensors;
+  if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+                                             destinationTensors)))
+    return rewriter.notifyMatchFailure(op, "failed to get destinations");
+
   FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
+      rewriter, destinationTensors, tilingResult.tiledOps.back(),
       resultOffsetsList, resultSizesList, tilingResult.loops);
   if (failed(replacementOr))
-    return rewriter.notifyMatchFailure(op, "failed to yield replacement");
-
-  if (auto dstOp =
-          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
-    auto innerMostLoop = tilingResult.loops.back();
-    SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands();
-    assert(destinationTensors.size() ==
-               innerMostLoop.getRegionIterArgs().size() &&
-           "unexpected number of outputs");
-    updateDestinationOperandsForTiledOp(rewriter, destinationTensors,
-                                        innerMostLoop.getRegionIterArgs());
-  }
+    return failure();
 
   tilingResult.replacements = replacementOr.value();
 