diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -15,6 +15,7 @@
 #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_
 
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir {
 
@@ -32,6 +33,10 @@
   // patterns.
   bool enableRegionSimplification = true;
 
+  /// Selects which region simplifications are run. Only consulted when
+  /// `enableRegionSimplification` is true.
+  SimplifyRegionsConfig simplifyRegionsConfig;
+
   /// This specifies the maximum number of times the rewriter will iterate
   /// between applying patterns and simplifying regions.
   unsigned maxIterations = 10;
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -17,6 +17,19 @@
 namespace mlir {
 class RewriterBase;
 
+/// This class allows control over how `simplifyRegions` works. Each flag
+/// enables one simplification step; all steps are enabled by default.
+class SimplifyRegionsConfig {
+public:
+  /// Erase the unreachable blocks within the provided regions.
+  bool eraseUnreachableBlocks = true;
+  /// Run a simple dead code elimination (`runRegionDCE`) over the regions.
+  bool eliminateDeadOpsOrArgs = true;
+  /// Identify identical blocks within the given regions and merge them,
+  /// inserting new block arguments as necessary.
+  bool mergeIdenticalBlocks = true;
+};
+
 /// Check if all values in the provided range are defined above the `limit`
 /// region. That is, if they are defined in a region that is a proper ancestor
 /// of `limit`.
@@ -57,7 +70,8 @@
 /// of the regions were simplified, failure otherwise. The provided rewriter is
 /// used to notify callers of operation and block deletion.
 LogicalResult simplifyRegions(RewriterBase &rewriter,
-                              MutableArrayRef<Region> regions);
+                              MutableArrayRef<Region> regions,
+                              const SimplifyRegionsConfig &config = {});
 
 } // namespace mlir
 
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -221,7 +221,8 @@
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
     if (config.enableRegionSimplification)
-      changed |= succeeded(simplifyRegions(*this, regions));
+      changed |= succeeded(
+          simplifyRegions(*this, regions, config.simplifyRegionsConfig));
   } while (changed && ++iteration < config.maxIterations);
 
   // Whether the rewrite converges, i.e. wasn't changed in the last iteration.
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -706,11 +706,20 @@
 /// elimination, as well as some other DCE. This function returns success if any
 /// of the regions were simplified, failure otherwise.
 LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
-                                    MutableArrayRef<Region> regions) {
-  bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
-  bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
-  bool mergedIdenticalBlocks =
-      succeeded(mergeIdenticalBlocks(rewriter, regions));
+                                    MutableArrayRef<Region> regions,
+                                    const SimplifyRegionsConfig &config) {
+  // Steps disabled by `config` must not count as a change; defaulting these
+  // flags to true would make every call report success.
+  bool eliminatedBlocks = false;
+  bool eliminatedOpsOrArgs = false;
+  bool mergedIdenticalBlocks = false;
+
+  if (config.eraseUnreachableBlocks)
+    eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
+  if (config.eliminateDeadOpsOrArgs)
+    eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
+  if (config.mergeIdenticalBlocks)
+    mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
                  mergedIdenticalBlocks);
 }