diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -89,6 +89,72 @@
   }
 };
 
+/// Fuse the producer of the source of `candidateSliceOp` by computing the
+/// required slice of the producer in-place.
+struct SCFFuseProducerOfSliceResult {
+  OpResult origProducer;       // Original untiled producer.
+  Value tiledAndFusedProducer; // Tiled and fused producer value.
+};
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                           tensor::ExtractSliceOp candidateSliceOp,
+                           MutableArrayRef<scf::ForOp> loops);
+
+/// Reconstruct the fused producer from within the tiled-and-fused code. Based
+/// on the slice of the producer computed in place it is possible that within
+/// the loop nest the same slice of the producer is computed repeatedly. It is
+/// in general not possible to recompute the value of the fused producer from
+/// the tiled loop code in such cases. For the cases where no slice of the
+/// producer is computed in a redundant fashion it is possible to reconstruct
+/// the value of the original producer from within the tiled loop. It is up to
+/// the caller to ensure that the producer is not computed redundantly within
+/// the tiled loop nest. For example, consider
+///
+/// ```mlir
+/// %0 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+/// %1 = linalg.matmul ins(%0, ..) outs(...) -> tensor<?x?xf32>
+/// ```
+///
+/// If `%1` is tiled in a 2D fashion and `%0` is fused with it, the resulting IR
+/// is,
+///
+/// ```mlir
+/// %t1_0 = scf.for .... iter_args(%arg0 = ...) {
+///   %t1_1 = scf.for ... iter_args(%arg1 = %arg0) {
+///     ...
+///     %t1_2 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///     %t1_3 = linalg.matmul ins(%t1_2, ...)
+///     %t1_4 = tensor.insert_slice %t1_3 into %arg1 ...
+///     scf.yield %t1_4
+///   }
+///   scf.yield %t1_1
+/// }
+/// ```
+///
+/// Here `%t1_2` is the same for all iterations of the inner `scf.for`. Instead
+/// if `%1` were tiled only along the rows, the resultant code would be
+///
+/// ```mlir
+/// %t2_0 = scf.for .... iter_args(%arg0 = ...) {
+///   ...
+///   %t2_1 = linalg.matmul ins(...) outs(...) -> tensor<?x?xf32>
+///   %t2_2 = linalg.matmul ins(%t2_1, ...)
+///   %t2_3 = tensor.insert_slice %t2_2 into %arg0 ...
+///   scf.yield %t2_3
+/// }
+/// ```
+///
+/// Here there is no intersection in the different slices of `%t2_1` computed
+/// across iterations of the `scf.for`. In such cases, the value of the original
+/// `%0` can be reconstructed from within the loop body. This is useful in cases
+/// where `%0` had other uses as well. If not reconstructed from within the loop
+/// body, uses of `%0` could not be replaced, making it still live and the
+/// fusion immaterial.
+void yieldReplacementForFusedProducer(
+    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+    MutableArrayRef<scf::ForOp> loops);
+
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
   /// List of untiled operations that were fused with the tiled consumer.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -505,10 +505,12 @@
   return {source->get().dyn_cast<OpResult>(), destinationIterArg};
 }
 
-static std::optional<Operation *>
-tileAndFuseProducerOfSlice(RewriterBase &rewriter,
-                           tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<scf::ForOp> loops) {
+/// Implementation of fusing the producer of a single slice by computing the
+/// slice of the producer in-place.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      tensor::ExtractSliceOp candidateSliceOp,
+                                      MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationIterArg] =
@@ -597,7 +599,40 @@
           innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
     }
   }
-  return fusedProducerValue->getDefiningOp();
+  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
+                                           fusedProducerValue.value()};
+}
+
+/// Reconstruct the fused producer from within the tiled-and-fused code by
+/// yielding the tiled producer value as an additional result of the loop nest.
+void mlir::scf::yieldReplacementForFusedProducer(
+    RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
+    scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
+    MutableArrayRef<scf::ForOp> loops) {
+  auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
+  FailureOr<Value> initValue = tensor::getOrCreateDestination(
+      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
+  // Without a destination no new `iter_arg` is added to the loop nest. Bail
+  // out instead of rewiring the tiled producer to an unrelated `iter_arg`.
+  if (failed(initValue))
+    return;
+
+  SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
+  yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
+                   resultOffsets, resultSizes, loops);
+
+  // The newly yielded value is carried by the last region iter_arg of the
+  // innermost loop. Make the tiled producer compute into that iter_arg so the
+  // value carried across iterations is the reconstructed producer.
+  if (auto dstStyleProducer =
+          fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
+    Value dstValue =
+        dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
+            ->get();
+    updateDestinationOperandsForTiledOp(
+        rewriter, dstValue, loops.back().getRegionIterArgs().back());
+  }
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -661,13 +696,17 @@
     // The operands of the fused producer might themselved be slices of
     // values produced by operations that implement the `TilingInterface`.
     // Add these operations to the worklist.
-    Optional<Operation *> fusedProducer = tileAndFuseProducerOfSlice(
-        rewriter, candidateSliceOp, tileAndFuseResult.loops);
+    std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+        tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+                                   tileAndFuseResult.loops);
     if (!fusedProducer)
       continue;
 
-    tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer.value());
-    addCandidateSlices(fusedProducer.value(), candidates);
+    if (Operation *tiledAndFusedOp =
+            fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
+      tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
+      addCandidateSlices(tiledAndFusedOp, candidates);
+    }
   }
   return tileAndFuseResult;
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -test-tiling-interface=tile-consumer-fuse-and-yield-producer-using-scf-for -cse -split-input-file %s | FileCheck %s
+
+func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %rhs1 : tensor<?x?xf32>,
+    %init0 : tensor<?x?xf32>, %init1 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %d0 = tensor.dim %lhs0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %rhs0, %c1 : tensor<?x?xf32>
+  %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm0 = linalg.matmul
+      ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
+      ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+//      CHECK: func.func @gemm_gemm_fusion_yield_both(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//      CHECK:   %[[RESULT:.+]]:2 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[FILL0_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT0_TILE]] :
+//      CHECK:     %[[GEMM0_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[FILL0_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[FILL1_TILE:.+]] = linalg.fill
+// CHECK-SAME:         outs(%[[INIT1_TILE]] :
+//      CHECK:     %[[GEMM1_TILE:.+]] = linalg.matmul
+// CHECK-SAME:         ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[FILL1_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include 
 #include 
+#include 
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include 
"mlir/Dialect/Func/IR/FuncOps.h"
@@ -239,6 +239,146 @@
   LinalgTransformationFilter filter;
 };
 
+/// Pattern to tile a consumer and fuse producers into it, while also
+/// reconstructing the value of each fused producer from within the loop nest
+/// so that any external uses of the producer can be replaced. In general,
+/// yielding the producer this way requires a guarantee that the slice of the
+/// producer is not computed redundantly within the tiled loops. An analysis
+/// that determines this has proven to be very complex, so it is left as a
+/// caller-side decision. This test pattern assumes that the tile sizes are
+/// selected such that no fused producer is computed redundantly within the
+/// tiled loops.
+struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
+    : public OpInterfaceRewritePattern<TilingInterface> {
+
+  TestTileConsumerFuseAndYieldProducerUsingSCFForOp(
+      MLIRContext *context, scf::SCFTilingOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+        options(std::move(options)), filter(std::move(filter)) {}
+
+  LogicalResult matchAndRewrite(TilingInterface rootOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(filter.checkAndNotify(rewriter, rootOp)))
+      return failure();
+
+    // Collect list of operations that can be tiled and fused.
+    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
+        collectTiledAndFusedOps(rootOp);
+    // Users that are fused, live inside the tiled loop nest, or are
+    // `tensor.dim` ops are not redirected to values yielded by the loop nest.
+    auto isIgnoredUser = [&](Operation *user, scf::ForOp outerMostTiledLoop) {
+      return tiledAndFusedOps.count(user) || isa<tensor::DimOp>(user) ||
+             outerMostTiledLoop->isAncestor(user);
+    };
+
+    // The rest of this method is similar to
+    // scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that it
+    // also yields replacements for values of the fused producer.
+
+    // 1. Tile the consumer.
+    SmallVector<Value> yieldedValuesToOrigValues;
+    FailureOr<scf::SCFTilingResult> tilingResult =
+        scf::tileUsingSCFForOp(rewriter, rootOp, options);
+    if (failed(tilingResult)) {
+      return rewriter.notifyMatchFailure(rootOp,
+                                         "failed to tile base operation");
+    }
+    yieldedValuesToOrigValues.append(rootOp->result_begin(),
+                                     rootOp->result_end());
+
+    // 2. Tiling each operation results in generation of slices. The source of
+    // these slices could be producers that can be fused into the tiled loops by
+    // computing the slices of these producers in-place. This results in more
+    // slices created for operands of the "fused producer". This opens up
+    // more opportunities for fusion. Use a worklist to fuse greedily.
+    auto addCandidateSlices =
+        [](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
+          for (Value operand : fusedOp->getOperands())
+            if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+              candidates.push_back(sliceOp);
+        };
+
+    std::deque<tensor::ExtractSliceOp> candidates;
+    addCandidateSlices(tilingResult->tiledOps.back(), candidates);
+    OpBuilder::InsertionGuard g(rewriter);
+    while (!candidates.empty()) {
+      // Traverse the slices in BFS fashion.
+      tensor::ExtractSliceOp candidateSliceOp = candidates.front();
+      candidates.pop_front();
+
+      // Materialize the slice of the producer in place.
+      std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
+          tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
+                                     tilingResult->loops);
+      if (!fusedProducer)
+        continue;
+
+      // Check if the fused producer has other uses that require the value
+      // to be yielded from within the tiled loop.
+      OpResult untiledProducer = fusedProducer->origProducer;
+      if (llvm::any_of(untiledProducer.getUsers(), [&](Operation *user) {
+            return !isIgnoredUser(user, tilingResult->loops.front());
+          })) {
+        // Only record a replacement if a new result was actually added to the
+        // loop nest, so that result indices stay in sync with this list.
+        unsigned numResultsBefore =
+            tilingResult->loops.front()->getNumResults();
+        yieldReplacementForFusedProducer(rewriter, candidateSliceOp,
+                                         fusedProducer.value(),
+                                         tilingResult->loops);
+        if (tilingResult->loops.front()->getNumResults() > numResultsBefore)
+          yieldedValuesToOrigValues.push_back(untiledProducer);
+      }
+
+      // Add more fusion candidates to the worklist.
+      if (auto fusedProducerOp =
+              fusedProducer->tiledAndFusedProducer.getDefiningOp())
+        addCandidateSlices(fusedProducerOp, candidates);
+    }
+
+    scf::ForOp outermostLoop = tilingResult->loops.front();
+    for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) {
+      Value replacement = outermostLoop.getResult(index);
+      rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+        return !isIgnoredUser(use.getOwner(), outermostLoop);
+      });
+    }
+    rewriter.eraseOp(rootOp);
+    filter.replaceLinalgTransformationFilter(rewriter,
+                                             tilingResult->tiledOps.back());
+    return success();
+  }
+
+private:
+  /// Starting from `op` walk all operands backwards to find all
+  /// potentially fusable operations, i.e. operations that implement
+  /// the `TilingInterface`.
+  llvm::SmallDenseSet<Operation *>
+  collectTiledAndFusedOps(Operation *op) const {
+    SmallVector<Operation *> worklist;
+    llvm::SmallDenseSet<Operation *> producers;
+    worklist.push_back(op);
+    producers.insert(op);
+    while (!worklist.empty()) {
+      Operation *current = worklist.pop_back_val();
+      for (OpOperand &operand : current->getOpOperands()) {
+        Operation *producer = operand.get().getDefiningOp();
+        if (!producer || !isa<TilingInterface>(producer) ||
+            producers.count(producer))
+          continue;
+        worklist.push_back(producer);
+        producers.insert(producer);
+      }
+    }
+    return producers;
+  }
+
+  scf::SCFTilingOptions options;
+  LinalgTransformationFilter filter;
+};
+
 /// Pattern to lower operations that implement the `TilingInterface` to
 /// loops/scalar IR using `scf.for`.
struct LowerToLoopsUsingSCFForOp
@@ -283,6 +417,13 @@
           "Test tiling using TilingInterface with scf.for operations"),
       llvm::cl::init(false)};
 
+  Option<bool> testTileConsumerFuseAndYieldProducer{
+      *this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
+      llvm::cl::desc(
+          "Test tile and fuse transformation while yielding fused producer "
+          "replacements using TilingInterface with scf.for operations"),
+      llvm::cl::init(false)};
+
   Option<bool> testTileConsumerAndFuseProducer{
       *this, "tile-consumer-and-fuse-producer-using-scf-for",
       llvm::cl::desc("Test tile and fuse transformation using TilingInterface "
@@ -314,6 +455,19 @@
   patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
 }
 
+static void addPatternForTileFuseAndYield(MLIRContext *context,
+                                          RewritePatternSet &patterns,
+                                          StringRef filterName,
+                                          ArrayRef<int64_t> tileSizes,
+                                          ArrayRef<int64_t> interchange = {}) {
+  scf::SCFTilingOptions tilingOptions;
+  tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+  LinalgTransformationFilter filter(StringAttr::get(context, filterName),
+                                    StringAttr::get(context, "tiled"));
+  patterns.add<TestTileConsumerFuseAndYieldProducerUsingSCFForOp>(
+      context, tilingOptions, filter);
+}
+
 static void addPatternForTileAndFuse(MLIRContext *context,
                                      RewritePatternSet &patterns,
                                      StringRef filterName,
@@ -375,6 +529,13 @@
                              {10});
     return;
   }
+  if (testTileConsumerFuseAndYieldProducer) {
+    // 1. Tile and fuse a sequence of GEMMs along M, yielding the fused
+    // producer so that its external uses can be replaced.
+    addPatternForTileFuseAndYield(context, patterns,
+                                  "gemm_sequence_fusion_and_yield", {10});
+    return;
+  }
   if (testLoweringToScalar) {
     patterns.add<LowerToLoopsUsingSCFForOp>(context);
   }