diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -90,17 +90,25 @@ }; /// Fuse the producer of the source of `candidateSliceOp` by computing the -/// required slice of the producer in-place. `yieldFusedProducerReplacement` -/// results in reconstruction of the fused producer from within the -/// tiled-and-fused code. Based on the slice of the producer computed in place -/// it is possible that within the loop nest same slice of the producer is -/// computed multiple times. It is in general not possible to recompute the -/// value of the fused producer from the tiled loop code in such cases. For the -/// cases where no slice of the producer is computed in a redundant fashion it -/// is possible to reconstruct the value of the original producer from within -/// the tiled loop. It is upto the caller to ensure that when -/// `yieldFusedProducerReplacement` is set to `true`, the producer is not -/// computed redundantly within the tiled loop nest. For example, consider +/// required slice of the producer in-place. +struct SCFFuseProducerOfSliceResult { + OpResult origProducer; // Original untiled producer. + Value tiledAndFusedProducer; // Tiled and fused producer value. +}; +std::optional +tileAndFuseProducerOfSlice(RewriterBase &rewriter, + tensor::ExtractSliceOp candidateSliceOp, + MutableArrayRef loops); + +/// Reconstruct the fused producer from within the tiled-and-fused code. Based +/// on the slice of the producer computed in place it is possible that within +/// the loop nest same slice of the producer is computed multiple times. It is +/// in general not possible to recompute the value of the fused producer from +/// the tiled loop code in such cases. 
For the cases where no slice of the +/// producer is computed in a redundant fashion it is possible to reconstruct +/// the value of the original producer from within the tiled loop. It is up to +/// the caller to ensure that the producer is not computed redundantly within +/// the tiled loop nest. For example, consider /// /// ```mlir /// %0 = linalg.matmul ins(...) outs(...) -> tensor @@ -142,9 +150,10 @@ /// where `%0` had other uses as well. If not reconstructed from within the loop /// body, uses of `%0` could not be replaced, making it still live and the /// fusion immaterial. -std::optional tileAndFuseProducerOfSlice( - RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, - MutableArrayRef loops, bool yieldFusedProducerReplacement); +void yieldReplacementForFusedProducer( + RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, + scf::SCFFuseProducerOfSliceResult fusedProducerInfo, + MutableArrayRef loops); /// Transformation information returned after tile and fuse. struct SCFTileAndFuseResult { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -507,9 +507,10 @@ /// Implementation of fusing producer of a single slice by computing the /// slice of the producer in-place. -std::optional mlir::scf::tileAndFuseProducerOfSlice( - RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp, - MutableArrayRef loops, bool yieldFusedProducerReplacement) { +std::optional +mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, + tensor::ExtractSliceOp candidateSliceOp, + MutableArrayRef loops) { // 1. 
Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) auto [fusableProducer, destinationIterArg] = @@ -598,30 +599,34 @@ innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); } } + return scf::SCFFuseProducerOfSliceResult{fusableProducer, + fusedProducerValue.value()}; +} - if (yieldFusedProducerReplacement) { - SmallVector initValues; - FailureOr initValue = tensor::getOrCreateDestination( - rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); - if (succeeded(initValue)) { - SmallVector resultOffsets = - candidateSliceOp.getMixedOffsets(); - SmallVector resultSizes = candidateSliceOp.getMixedSizes(); - SmallVector yieldedVals = yieldTiledValues( - rewriter, initValue.value(), fusedProducerValue.value(), - resultOffsets, resultSizes, loops); - } - if (auto dstStyleProducer = - fusedProducerValue.value() - .getDefiningOp()) { - Value dstValue = - dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) - ->get(); - updateDestinationOperandsForTiledOp( - rewriter, dstValue, loops.back().getRegionIterArgs().back()); - } +/// Reconstruct the fused producer from within the tiled-and-fused code. 
+void mlir::scf::yieldReplacementForFusedProducer( + RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, + scf::SCFFuseProducerOfSliceResult fusedProducerInfo, + MutableArrayRef loops) { + auto [fusableProducer, fusedProducerValue] = fusedProducerInfo; + SmallVector initValues; + FailureOr initValue = tensor::getOrCreateDestination( + rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); + if (succeeded(initValue)) { + SmallVector resultOffsets = sliceOp.getMixedOffsets(); + SmallVector resultSizes = sliceOp.getMixedSizes(); + SmallVector yieldedVals = + yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, + resultOffsets, resultSizes, loops); + } + if (auto dstStyleProducer = + fusedProducerValue.getDefiningOp()) { + Value dstValue = + dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) + ->get(); + updateDestinationOperandsForTiledOp( + rewriter, dstValue, loops.back().getRegionIterArgs().back()); } - return fusedProducerValue->getDefiningOp(); } /// Implementation of tile consumer and fuse producer greedily. @@ -685,13 +690,17 @@ // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. 
- Optional fusedProducer = tileAndFuseProducerOfSlice( - rewriter, candidateSliceOp, tileAndFuseResult.loops, false); + std::optional fusedProducer = + tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, + tileAndFuseResult.loops); if (!fusedProducer) continue; - tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value()); - addCandidateSlices(fusedProducer.value(), candidates); + if (Operation *tiledAndFusedOp = + fusedProducer->tiledAndFusedProducer.getDefiningOp()) { + tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); + addCandidateSlices(tiledAndFusedOp, candidates); + } } return tileAndFuseResult; } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -11,8 +11,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -307,28 +307,29 @@ tensor::ExtractSliceOp candidateSliceOp = candidates.front(); candidates.pop_front(); - // Check if the fused producer has other uses that require the value - // to be yielded from within the tiled loop. - OpResult untiledProducer = getProducerOfSlice( - tilingResult->loops, &candidateSliceOp->getOpOperand(0)); - bool yieldResult = - untiledProducer && - llvm::any_of( - untiledProducer.getOwner()->getUsers(), [&](Operation *user) { - return !isIgnoredUser(user, tilingResult->loops.front()); - }); - // Materialize the slice of the producer in place. 
- Optional fusedProducer = tileAndFuseProducerOfSlice( - rewriter, candidateSliceOp, tilingResult->loops, yieldResult); + std::optional fusedProducer = + tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, + tilingResult->loops); if (!fusedProducer) continue; - if (yieldResult) { + + // Check if the fused producer has other uses that require the value + // to be yielded from within the tiled loop. + OpResult untiledProducer = fusedProducer->origProducer; + if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) { + return !isIgnoredUser(user, tilingResult->loops.front()); + })) { + yieldReplacementForFusedProducer(rewriter, candidateSliceOp, + fusedProducer.value(), + tilingResult->loops); yieldedValuesToOrigValues.push_back(untiledProducer); } // Add more fusion candidates to the worklist. - addCandidateSlices(fusedProducer.value(), candidates); + if (auto fusedProducerOp = + fusedProducer->tiledAndFusedProducer.getDefiningOp()) + addCandidateSlices(fusedProducerOp, candidates); } scf::ForOp outermostLoop = tilingResult->loops.front(); @@ -368,22 +369,6 @@ return producers; } - /// Get the producer for the source of slice. When using `scf.for` - /// for tile and fuse, might need to walk the `iter_args` to get - /// to the actual producer. - OpResult getProducerOfSlice(ArrayRef loops, - OpOperand *source) const { - auto loopIt = loops.rbegin(); - while (auto iterArg = source->get().dyn_cast()) { - scf::ForOp loop = *loopIt; - if (iterArg.getOwner()->getParentOp() != loop) - break; - source = &loop.getOpOperandForRegionIterArg(iterArg); - loopIt++; - } - return source->get().dyn_cast(); - } - scf::SCFTilingOptions options; LinalgTransformationFilter filter; };