diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -199,9 +199,12 @@
 
   /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
-  /// fused producer of fails if fusion is not possible.
-  // TODO: add replace uses callback to support passes and patterns.
+  /// fused producer or fails if fusion is not possible.
   FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);
 
+  /// Returns the replacement results for the original untiled root operation,
+  /// i.e. the results of the first tile loop. Expects a non-empty loop nest.
+  ValueRange getRootOpReplacementResults();
+
   /// Returns the tiled root operation.
   LinalgOp getRootOp() { return rootOp; }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -245,10 +245,15 @@
                       .setLoopType(LinalgTilingLoopType::Loops);
   Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);
 
-  // Replace all uses of the root operation.
+  // Exit if tiling the root operation fails.
   if (!tiledRootOp.hasValue())
     return failure();
-  rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
+
+  // Replace all uses of the root operation if it has been tiled before. All
+  // uses of the original untiled root operation are updated by the calling pass
+  // or pattern.
+  if (!isEmpty())
+    rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
 
   // Update the root operation and append the loops and tile loop dimensions.
   rootOp = tiledRootOp->op;
@@ -323,6 +328,11 @@
   return clonedOp;
 }
 
+ValueRange TileLoopNest::getRootOpReplacementResults() {
+  assert(!isEmpty() && "expect tile loop nest to be non-empty");
+  return loopOps.front()->getOpResults();
+}
+
 //===----------------------------------------------------------------------===//
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//
@@ -433,9 +443,13 @@
            "expect the tile interchange permutes the root loops");
 
     // Tile `rootOp` and fuse its producers.
-    if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
-                                            rootInterchange)))
+    FailureOr<TileLoopNest> tileLoopNest =
+        tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
+    if (failed(tileLoopNest))
       return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
+
+    // Replace all uses of the tiled loop operation.
+    rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
   }
 };
 } // namespace