diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -23,6 +23,7 @@ class RewriterBase; namespace linalg { +struct ForallTilingResult; class GenericOp; class LinalgOp; } // namespace linalg @@ -48,12 +49,13 @@ namespace transform { /// Implementation of tiling operations using `scf.forall`. -DiagnosedSilenceableFailure tileToForallOpImpl( - RewriterBase &rewriter, transform::TransformState &state, - TransformOpInterface transformOp, ArrayRef targets, - ArrayRef mixedNumThreads, - ArrayRef mixedTileSizes, std::optional mapping, - SmallVector &tileOps, SmallVector &tiledOps); +DiagnosedSilenceableFailure +tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state, + TransformOpInterface transformOp, Operation *target, + ArrayRef mixedNumThreads, + ArrayRef mixedTileSizes, + std::optional mapping, + linalg::ForallTilingResult &tilingResult); } // namespace transform } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2506,15 +2506,11 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl( RewriterBase &rewriter, transform::TransformState &state, - TransformOpInterface transformOp, ArrayRef targets, + TransformOpInterface transformOp, Operation *target, ArrayRef mixedNumThreads, ArrayRef mixedTileSizes, std::optional mapping, - SmallVector &tileOps, SmallVector &tiledOps) { - if (targets.empty()) - return DiagnosedSilenceableFailure::success(); - + linalg::ForallTilingResult &tilingResult) { // Transform all targets one by one. - for (Operation *target : targets) { auto tileableOp = dyn_cast(target); if (!tileableOp) { DiagnosedSilenceableFailure diag = @@ -2524,23 +2520,21 @@ return diag; } rewriter.setInsertionPoint(tileableOp); - FailureOr tilingResult = failure(); + FailureOr maybeTilingResult = failure(); if (!mixedNumThreads.empty()) { - tilingResult = linalg::tileToForallOp(rewriter, tileableOp, - mixedNumThreads, mapping); + maybeTilingResult = linalg::tileToForallOp(rewriter, tileableOp, + mixedNumThreads, mapping); } else { - tilingResult = linalg::tileToForallOpUsingTileSizes( + maybeTilingResult = linalg::tileToForallOpUsingTileSizes( rewriter, tileableOp, mixedTileSizes, mapping); } - if (failed(tilingResult)) + if (failed(maybeTilingResult)) return transformOp.emitDefaultSilenceableFailure(tileableOp); - rewriter.replaceOp(tileableOp, tilingResult->tileOp->getResults()); + rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults()); - tileOps.push_back(tilingResult->tileOp); - tiledOps.push_back(tilingResult->tiledOp); - } - return DiagnosedSilenceableFailure::success(); + tilingResult = *maybeTilingResult; + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure @@ -2577,12 +2571,16 @@ if (!status.succeeded()) return status; - DiagnosedSilenceableFailure diag = - tileToForallOpImpl(rewriter, state, transformOp, targets, mixedNumThreads, - mixedTileSizes, getMapping(), tileOps, tiledOps); - - if (!diag.succeeded()) - return diag; + for (Operation *target : targets) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) + return diag; + tileOps.push_back(tilingResult.tileOp); + tiledOps.push_back(tilingResult.tiledOp); + } transformResults.set(getForallOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps);