diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -157,8 +157,21 @@
+/// Result of fusing a producer `linalg.generic` into its consumer with
+/// `fuseElementwiseOps`.
+struct ElementwiseOpFusionResult {
+  /// The newly created fused operation.
+  Operation *fusedOp;
+  /// Maps each consumer result, and each producer result that the fused op
+  /// preserves, to the result of `fusedOp` that replaces it.
+  llvm::DenseMap<Value, Value> replacements;
+};
+
 /// Fuse two `linalg.generic` operations that have a producer-consumer
 /// relationship captured through `fusedOperand`. The method expects
 /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`.
-FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
-                                          OpOperand *fusedOperand);
+/// Callers are responsible for redirecting uses of the original values via
+/// `ElementwiseOpFusionResult::replacements` and for erasing the original
+/// operations as needed.
+FailureOr<ElementwiseOpFusionResult>
+fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand);
 
 /// Split the given `op` into two parts along the given iteration space
 /// `dimension` at the specified `splitPoint`, and return the two parts.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -23,8 +23,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <utility>
 #include <optional>
+#include <utility>
 
 namespace mlir {
 #define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMS
@@ -73,6 +73,9 @@
 
 /// Conditions for elementwise fusion of generic operations.
 bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
+  if (!fusedOperand)
+    return false;
+
   auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
   auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());
 
@@ -270,7 +273,7 @@
          "Ill-formed GenericOp region");
 }
 
-FailureOr<Operation *>
+FailureOr<ElementwiseOpFusionResult>
 mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                  OpOperand *fusedOperand) {
   assert(areElementwiseOpsFusable(fusedOperand) &&
@@ -390,7 +393,17 @@
   generateFusedElementwiseOpRegion(
       rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
       consumer.getNumLoops(), preservedProducerResults);
-  return fusedOp.getOperation();
+  // The fused op is expected to yield the preserved producer results (in
+  // order) followed by all consumer results; map the originals onto them.
+  ElementwiseOpFusionResult result;
+  result.fusedOp = fusedOp;
+  unsigned resultNum = 0;
+  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
+    if (preservedProducerResults.count(index))
+      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
+  for (auto consumerResult : consumer->getResults())
+    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
+  return result;
 }
 
 namespace {
@@ -411,13 +424,20 @@
       if (!controlFn(&opOperand))
         continue;
 
-      FailureOr<Operation *> fusedOp = fuseElementwiseOps(rewriter, &opOperand);
-      if (succeeded(fusedOp)) {
-        auto replacements =
-            (*fusedOp)->getResults().take_back(genericOp.getNumResults());
-        rewriter.replaceOp(genericOp, replacements);
-        return success();
+      Operation *producer = opOperand.get().getDefiningOp();
+      FailureOr<ElementwiseOpFusionResult> fusionResult =
+          fuseElementwiseOps(rewriter, &opOperand);
+      if (failed(fusionResult))
+        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+      for (auto [origVal, replacement] : fusionResult->replacements) {
+        rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) {
+          // Only redirect uses of the consumer's results; uses of producer
+          // results are left untouched.
+          return use.get().getDefiningOp() != producer;
+        });
       }
+      rewriter.eraseOp(genericOp);
+      return success();
     }
     return failure();
   }
diff --git a/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/fusion-multiuse-producer.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=fuse-multiuse-producer -split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @multi_use_producer(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
+    %arg2 : tensor<?x?xf32>, %arg3 : tensor<?x?xf32>, %arg4 : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+      %1 = arith.addf %b0, %b1 : f32
+      linalg.yield %1, %1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  %2 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%0#1, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%arg4 : tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %3 = arith.mulf %b0, %b1 : f32
+      linalg.yield %3 : f32
+  } -> tensor<?x?xf32>
+  return %0#0, %0#1, %2 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
+}
+//      CHECK: func @multi_use_producer(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG3:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME:     %[[ARG4:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
+//      CHECK:   %[[RESULT:.+]]:3 = linalg.generic
+//      CHECK:   return %[[RESULT]]#0, %[[RESULT]]#1, %[[RESULT]]#2
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
@@ -51,6 +51,38 @@
 }
 
 namespace {
+
+/// Test pattern that fuses a fusable producer into `genericOp` even when the
+/// producer has other users, which are redirected to the fused op.
+struct TestMultiUseProducerFusion : public OpRewritePattern<linalg::GenericOp> {
+  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    OpOperand *fusableOperand = nullptr;
+    for (OpOperand &operand : genericOp->getOpOperands()) {
+      if (linalg::areElementwiseOpsFusable(&operand)) {
+        fusableOperand = &operand;
+        break;
+      }
+    }
+    if (!fusableOperand)
+      return rewriter.notifyMatchFailure(genericOp, "no fusable operand found");
+    FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
+        linalg::fuseElementwiseOps(rewriter, fusableOperand);
+    if (failed(fusionResult))
+      return rewriter.notifyMatchFailure(genericOp, "fusion failed");
+    // Redirect every use outside `genericOp`, which is erased below.
+    for (auto [origValue, replacement] : fusionResult->replacements) {
+      rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) {
+        return use.getOwner() != genericOp.getOperation();
+      });
+    }
+    rewriter.eraseOp(genericOp);
+    return success();
+  }
+};
+
 struct TestLinalgElementwiseFusion
     : public PassWrapper<TestLinalgElementwiseFusion,
                          OperationPass<func::FuncOp>> {
@@ -105,6 +137,12 @@
                      "fusion patterns that "
                      "collapse the iteration space of the consumer"),
       llvm::cl::init(false)};
+
+  Option<bool> fuseMultiUseProducer{
+      *this, "fuse-multiuse-producer",
+      llvm::cl::desc("Test fusion of producer ops with multiple uses"),
+      llvm::cl::init(false)};
+
   ListOption<int64_t> collapseDimensions{
       *this, "collapse-dimensions-control",
       llvm::cl::desc("Test controlling dimension collapse pattern")};
@@ -117,8 +155,9 @@
       RewritePatternSet fusionPatterns(context);
       auto controlFn = [](OpOperand *operand) { return true; };
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -127,8 +166,9 @@
       linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
                                                    setFusedOpOperandLimit<4>);
 
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -172,8 +212,9 @@
       linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
                                                         controlReshapeFusionFn);
 
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(),
-                                         std::move(fusionPatterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(fusionPatterns))))
+        return signalPassFailure();
       return;
     }
 
@@ -181,7 +222,10 @@
       RewritePatternSet patterns(context);
       linalg::populateFoldReshapeOpsByCollapsingPatterns(
           patterns, [](OpOperand * /*fusedOperand */) { return true; });
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
 
     if (fuseWithReshapeByCollapsingWithControlFn) {
@@ -195,7 +239,19 @@
         return true;
       };
       linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
+    }
+
+    if (fuseMultiUseProducer) {
+      RewritePatternSet patterns(context);
+      patterns.add<TestMultiUseProducerFusion>(context);
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
 
     if (!collapseDimensions.empty()) {
@@ -209,7 +265,10 @@
       };
       RewritePatternSet patterns(context);
       linalg::populateCollapseDimensions(patterns, collapseFn);
-      (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
+      if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
+                                              std::move(patterns))))
+        return signalPassFailure();
+      return;
     }
   }
 };