diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -18,8 +18,9 @@
 
 namespace mlir {
 
-/// This struct allows control over how the GreedyPatternRewriteDriver works.
-struct GreedyRewriteConfig {
+/// This class allows control over how the GreedyPatternRewriteDriver works.
+class GreedyRewriteConfig {
+public:
   /// This specifies the order of initial traversal that populates the rewriters
   /// worklist. When set to true, it walks the operations top-down, which is
   /// generally more efficient in compile time. When set to false, its initial
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -23,6 +23,7 @@
 namespace mlir {
 
 class AffineForOp;
+class GreedyRewriteConfig;
 
 //===----------------------------------------------------------------------===//
 // Passes
@@ -60,9 +61,14 @@
 /// Creates a pass that converts memref function results to out-params.
 std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
 
-/// Creates an instance of the Canonicalizer pass.
+/// Creates an instance of the Canonicalizer pass, configured with default
+/// settings (which can be overridden by pass options on the command line).
 std::unique_ptr<Pass> createCanonicalizerPass();
 
+/// Creates an instance of the Canonicalizer pass with the specified config.
+std::unique_ptr<Pass>
+createCanonicalizerPass(const GreedyRewriteConfig &config);
+
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -365,6 +365,12 @@
   let options = [
     Option<"topDownProcessingEnabled", "top-down", "bool",
            /*default=*/"false",
-           "Seed the worklist in general top-down order">
+           "Seed the worklist in general top-down order">,
+    Option<"enableRegionSimplification", "region-simplify", "bool",
+           /*default=*/"true",
+           "Perform control flow optimizations to the region tree">,
+    Option<"maxIterations", "max-iterations", "unsigned",
+           /*default=*/"10",
+           "Max. iterations between applying patterns / simplifying regions">
   ];
 }
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -21,6 +21,19 @@
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
+  /// Default construction takes every setting from the pass options, which may
+  /// be overridden (e.g. on the command line) after the pass is constructed.
+  Canonicalizer() = default;
+
+  /// Constructs the pass from an explicit configuration. The option-backed
+  /// fields are mirrored into the pass options so that those options stay the
+  /// single source of truth for them (see runOnOperation()).
+  explicit Canonicalizer(const GreedyRewriteConfig &config) : config(config) {
+    topDownProcessingEnabled = config.useTopDownTraversal;
+    enableRegionSimplification = config.enableRegionSimplification;
+    maxIterations = config.maxIterations;
+  }
+
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
@@ -31,12 +44,20 @@
     return success();
   }
   void runOnOperation() override {
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = topDownProcessingEnabled;
+    // Pass options can be set after construction (e.g. from the command line),
+    // so the option-backed fields are read here instead of being cached in a
+    // constructor; all other fields keep the values of `config`.
+    GreedyRewriteConfig effectiveConfig = config;
+    effectiveConfig.useTopDownTraversal = topDownProcessingEnabled;
+    effectiveConfig.enableRegionSimplification = enableRegionSimplification;
+    effectiveConfig.maxIterations = maxIterations;
     (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns,
-                                       config);
+                                       effectiveConfig);
   }
 
+  /// Base configuration for the greedy driver. Fields that have a matching
+  /// pass option are overridden by that option in runOnOperation().
+  GreedyRewriteConfig config;
   FrozenRewritePatternSet patterns;
 };
 } // end anonymous namespace
@@ -45,3 +66,9 @@
 std::unique_ptr<Pass> mlir::createCanonicalizerPass() {
   return std::make_unique<Canonicalizer>();
 }
+
+/// Creates an instance of the Canonicalizer pass with the specified config.
+std::unique_ptr<Pass>
+mlir::createCanonicalizerPass(const GreedyRewriteConfig &config) {
+  return std::make_unique<Canonicalizer>(config);
+}