diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -441,11 +441,11 @@ /// Note: These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -bool applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. -bool applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -272,7 +272,7 @@ // which need to be lowered further, which is not supported by a single // conversion pass. populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsGreedily(m, patterns); + applyPatternsAndFoldGreedily(m, patterns); patterns.clear(); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -173,7 +173,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); - applyPatternsGreedily(getOperation()->getRegions(), patterns); + applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1134,7 +1134,7 @@ OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getModule(), patterns); + applyPatternsAndFoldGreedily(getModule(), patterns); } // Convert to the LLVM IR dialect. diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -368,7 +368,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); patterns.insert(context); - applyPatternsGreedily(fn, patterns); + applyPatternsAndFoldGreedily(fn, patterns); } std::unique_ptr> @@ -385,7 +385,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); patterns.insert(context); - applyPatternsGreedily(fn, patterns); + applyPatternsAndFoldGreedily(fn, patterns); } std::unique_ptr> diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -576,7 +576,7 @@ OwningRewritePatternList patterns; Operation *op = getOperation(); patterns.insert(op->getContext()); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -689,7 +689,7 @@ AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. - applyPatternsGreedily(op, patterns); + applyPatternsAndFoldGreedily(op, patterns); } namespace { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -102,7 +102,7 @@ auto func = getFunction(); auto *context = &getContext(); patterns.insert(context); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); } std::unique_ptr> mlir::quant::createConvertConstPass() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -135,7 +135,7 @@ auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); if (hadFailure) signalPassFailure(); } diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -57,7 +57,7 @@ patterns.insert, RemoveIdentityOpRewrite, RemoveIdentityOpRewrite>(context); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); } std::unique_ptr> diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -35,7 +35,7 @@ op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -551,7 +551,7 @@ // Apply the canonicalization patterns to this region. auto *node = nodesToCanonicalize[index]; - applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); + applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); // Make sure to reset the order ID for the diagnostic handler, as this // thread may be used in a different context. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -40,8 +40,8 @@ worklist.reserve(64); } - /// Perform the rewrites. Return true if the rewrite converges in - /// `maxIterations`. + /// Perform the rewrites while folding and erasing any dead ops. Return true + /// if the rewrite converges in `maxIterations`. bool simplify(MutableArrayRef regions, int maxIterations); void addToWorklist(Operation *op) { @@ -136,7 +136,7 @@ }; } // end anonymous namespace -/// Perform the rewrites. +/// Perform the rewrites while folding and erasing any dead ops. bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, int maxIterations) { // Add the given operation to the worklist. @@ -215,14 +215,14 @@ /// the result operation regions. /// Note: This does not apply patterns to the top-level operation itself. /// -bool mlir::applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns) { - return applyPatternsGreedily(op->getRegions(), patterns); +bool mlir::applyPatternsAndFoldGreedily( + Operation *op, const OwningRewritePatternList &patterns) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } /// Rewrite the given regions, which must be isolated from above. -bool mlir::applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns) { +bool mlir::applyPatternsAndFoldGreedily( + MutableArrayRef regions, const OwningRewritePatternList &patterns) { if (regions.empty()) return true; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -46,7 +46,7 @@ // Verify named pattern is generated with expected name. patterns.insert(&getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -22,7 +22,7 @@ void runOnModule() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsGreedily(getModule(), patterns); + applyPatternsAndFoldGreedily(getModule(), patterns); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -39,7 +39,7 @@ // Add the generated patterns to the list. linalg::populateWithGenerated(&getContext(), &patterns); - applyPatternsGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. funcOp.walk([](LinalgOp op) { diff --git a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp --- a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp @@ -23,7 +23,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); populateVectorToAffineLoopsConversionPatterns(context, patterns); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -28,7 +28,7 @@ populateWithGenerated(context, &patterns); populateVectorToVectorCanonicalizationPatterns(patterns, context); populateVectorToVectorTransformationPatterns(patterns, context); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -37,7 +37,7 @@ void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -57,7 +57,7 @@ VectorTransformsOptions options{ /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } };