diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -444,11 +444,11 @@ /// Note: These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -bool applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(Operation *op, + const OwningRewritePatternList &patterns); /// Rewrite the given regions, which must be isolated from above. -bool applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns); +bool applyPatternsAndFoldGreedily(MutableArrayRef regions, + const OwningRewritePatternList &patterns); } // end namespace mlir #endif // MLIR_PATTERN_MATCH_H diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -268,7 +268,7 @@ // which need to be lowered further, which is not supported by a single // conversion pass. populateGpuRewritePatterns(m.getContext(), patterns); - applyPatternsGreedily(m, patterns); + applyPatternsAndFoldGreedily(m, patterns); patterns.clear(); populateStdToLLVMConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -170,7 +170,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); populateStdLegalizationPatternsForSPIRVLowering(context, patterns); - applyPatternsGreedily(getOperation()->getRegions(), patterns); + applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns); } std::unique_ptr mlir::createLegalizeStdOpsForSPIRVLoweringPass() { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1130,7 +1130,7 @@ OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); populateVectorContractLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), patterns); } // Convert to the LLVM IR dialect. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -276,6 +276,6 @@ OwningRewritePatternList patterns; AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext()); AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext()); - applyPatternsGreedily(f, std::move(patterns)); + applyPatternsAndFoldGreedily(f, std::move(patterns)); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -572,7 +572,7 @@ OwningRewritePatternList patterns; Operation *op = getOperation(); patterns.insert(op->getContext()); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); }; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -688,7 +688,7 @@ AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.insert(context); // Just apply the patterns greedily. - applyPatternsGreedily(op, patterns); + applyPatternsAndFoldGreedily(op, patterns); } namespace { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -98,7 +98,7 @@ auto func = getFunction(); auto *context = &getContext(); patterns.insert(context); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); } std::unique_ptr> mlir::quant::createConvertConstPass() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -131,7 +131,7 @@ auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); - applyPatternsGreedily(func, patterns); + applyPatternsAndFoldGreedily(func, patterns); if (hadFailure) signalPassFailure(); } diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -32,7 +32,7 @@ op->getCanonicalizationPatterns(patterns, context); Operation *op = getOperation(); - applyPatternsGreedily(op->getRegions(), patterns); + applyPatternsAndFoldGreedily(op->getRegions(), patterns); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -551,7 +551,7 @@ // Apply the canonicalization patterns to this region. auto *node = nodesToCanonicalize[index]; - applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); + applyPatternsAndFoldGreedily(*node->getCallableRegion(), canonPatterns); // Make sure to reset the order ID for the diagnostic handler, as this // thread may be used in a different context. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -40,8 +40,8 @@ worklist.reserve(64); } - /// Perform the rewrites. Return true if the rewrite converges in - /// `maxIterations`. + /// Perform the rewrites while folding and erasing any dead ops. Return true + /// if the rewrite converges in `maxIterations`. bool simplify(MutableArrayRef regions, int maxIterations); void addToWorklist(Operation *op) { @@ -136,7 +136,7 @@ }; } // end anonymous namespace -/// Perform the rewrites. +/// Perform the rewrites while folding and erasing any dead ops. bool GreedyPatternRewriteDriver::simplify(MutableArrayRef regions, int maxIterations) { // Add the given operation to the worklist. @@ -216,14 +216,14 @@ /// the result operation regions. /// Note: This does not apply patterns to the top-level operation itself. /// -bool mlir::applyPatternsGreedily(Operation *op, - const OwningRewritePatternList &patterns) { - return applyPatternsGreedily(op->getRegions(), patterns); +bool mlir::applyPatternsAndFoldGreedily( + Operation *op, const OwningRewritePatternList &patterns) { + return applyPatternsAndFoldGreedily(op->getRegions(), patterns); } /// Rewrite the given regions, which must be isolated from above. -bool mlir::applyPatternsGreedily(MutableArrayRef regions, - const OwningRewritePatternList &patterns) { +bool mlir::applyPatternsAndFoldGreedily( + MutableArrayRef regions, const OwningRewritePatternList &patterns) { if (regions.empty()) return true; diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -46,7 +46,7 @@ // Verify named pattern is generated with expected name. patterns.insert(&getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // end anonymous namespace diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp --- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp +++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp @@ -22,7 +22,7 @@ void runOnOperation() override { OwningRewritePatternList patterns; populateGpuRewritePatterns(&getContext(), patterns); - applyPatternsGreedily(getOperation(), patterns); + applyPatternsAndFoldGreedily(getOperation(), patterns); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -40,7 +40,7 @@ // Add the generated patterns to the list. linalg::populateWithGenerated(&getContext(), &patterns); - applyPatternsGreedily(funcOp, patterns); + applyPatternsAndFoldGreedily(funcOp, patterns); // Drop the marker. funcOp.walk([](LinalgOp op) { diff --git a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp --- a/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp +++ b/mlir/test/lib/Transforms/TestVectorToLoopsConversion.cpp @@ -23,7 +23,7 @@ OwningRewritePatternList patterns; auto *context = &getContext(); populateVectorToAffineLoopsConversionPatterns(context, patterns); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -28,7 +28,7 @@ populateWithGenerated(context, &patterns); populateVectorToVectorCanonicalizationPatterns(patterns, context); populateVectorToVectorTransformationPatterns(patterns, context); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -37,7 +37,7 @@ void runOnFunction() override { OwningRewritePatternList patterns; populateVectorSlicesLoweringPatterns(patterns, &getContext()); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } }; @@ -57,7 +57,7 @@ VectorTransformsOptions options{ /*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); - applyPatternsGreedily(getFunction(), patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); } };