diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -62,8 +62,10 @@
 
 /// Transformation information returned after tiling.
 struct SCFTilingResult {
-  /// The tiled operation generated.
-  Operation *tiledOp;
+  /// Tiled operations that are generated during tiling. The order does not
+  /// matter except for the last op: the replacements are expected to be the
+  /// results of the last op.
+  SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
   SmallVector<scf::ForOp> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -931,7 +931,7 @@
   if (failed(maybeTilingResult))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
-  results.push_back(maybeTilingResult->tiledOp);
+  results.push_back(maybeTilingResult->tiledOps.back());
   return DiagnosedSilenceableFailure(success());
 }
 
@@ -1251,7 +1251,7 @@
     rewriter.replaceOp(linalgOp,
                        maybeTilingResult->loops.front()->getResults());
 
-    tiled.push_back(maybeTilingResult->tiledOp);
+    tiled.append(maybeTilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
       loops[en2.index()].push_back(en2.value());
   }
@@ -1609,7 +1609,7 @@
 
     rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
 
-    tiled.push_back(tilingResult->tiledOp);
+    tiled.append(tilingResult->tiledOps);
     for (const auto &en2 : llvm::enumerate(tilingResult->loops))
       loops[en2.index()].push_back(en2.value());
   }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -360,11 +360,11 @@
         tilingResult.loops.back().getBody()->getTerminator());
   SmallVector<Operation *> tiledImplementation =
       op.getTiledImplementation(rewriter, offsets, sizes);
-  if (tiledImplementation.size() != 1) {
+  if (tiledImplementation.empty()) {
     return rewriter.notifyMatchFailure(
-        op, "expected tiled implementation to return a single op");
+        op, "expected tiled implementation to return at least one op");
   }
-  tilingResult.tiledOp = tiledImplementation[0];
+  tilingResult.tiledOps.append(tiledImplementation);
   if (op->getNumResults() == 0) {
     // nothing more to do.
     return tilingResult;
@@ -396,13 +396,13 @@
   }
 
   FailureOr<SmallVector<Value>> replacementOr = yieldTiledValues(
-      rewriter, destinationTensors, tilingResult.tiledOp->getResults(),
+      rewriter, destinationTensors, tilingResult.tiledOps.back()->getResults(),
       resultOffsetsList, resultSizesList, tilingResult.loops);
   if (failed(replacementOr))
     return rewriter.notifyMatchFailure(op, "failed to yield replacement");
 
   if (auto dstOp =
-          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOp)) {
+          dyn_cast<DestinationStyleOpInterface>(tilingResult.tiledOps.back())) {
     auto innerMostLoop = tilingResult.loops.back();
     SmallVector<OpOperand *> destinationTensors = dstOp.getDpsInitOperands();
     assert(destinationTensors.size() ==
@@ -554,13 +554,14 @@
         tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
     if (failed(tilingResult))
       return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
-    tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
+    for (Operation *tiledOp : tilingResult->tiledOps)
+      tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
     tileAndFuseResult.loops = std::move(tilingResult->loops);
     for (const auto &result : llvm::enumerate(
              llvm::zip(consumer->getResults(), tilingResult->replacements))) {
       tileAndFuseResult.replacements[std::get<0>(result.value())] =
           std::get<1>(result.value());
-      yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+      yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
           result.index())] = result.index();
     }
   }
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -193,7 +193,8 @@
       rewriter.eraseOp(op);
     }
 
-    filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
+    for (Operation *tiledOp : tilingResult->tiledOps)
+      filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
     return success();
   }
 