diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -21,22 +21,22 @@
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
+  /// Default-constructed canonicalizers take their settings from the pass
+  /// options (e.g. as set on the command line or in a textual pipeline).
+  Canonicalizer() = default;
+
+  /// Seed the pass options from `config`; runOnOperation() rebuilds the
+  /// greedy rewrite config from these options.
   Canonicalizer(const GreedyRewriteConfig &config,
                 ArrayRef<std::string> disabledPatterns,
-                ArrayRef<std::string> enabledPatterns)
-      : config(config) {
+                ArrayRef<std::string> enabledPatterns) {
+    this->topDownProcessingEnabled = config.useTopDownTraversal;
+    this->enableRegionSimplification = config.enableRegionSimplification;
+    this->maxIterations = config.maxIterations;
     this->disabledPatterns = disabledPatterns;
     this->enabledPatterns = enabledPatterns;
   }
 
-  Canonicalizer() {
-    // Default constructed Canonicalizer takes its settings from command line
-    // options.
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-  }
-
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
@@ -51,11 +51,14 @@
     return success();
   }
   void runOnOperation() override {
-    (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), patterns,
-                                       config);
+    // The pass options are the single source of truth for the config.
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.enableRegionSimplification = enableRegionSimplification;
+    config.maxIterations = maxIterations;
+    (void)applyPatternsAndFoldGreedily(getOperation(), patterns, config);
   }
 
-  GreedyRewriteConfig config;
   FrozenRewritePatternSet patterns;
 };
 } // namespace
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -pass-pipeline='func.func(canonicalize)' | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func.func(canonicalize{region-simplify=false})' | FileCheck %s --check-prefixes=CHECK,NO-RS
 
 // CHECK-LABEL: func @remove_op_with_inner_ops_pattern
 func.func @remove_op_with_inner_ops_pattern() {
@@ -89,3 +90,15 @@
   // CHECK: return %[[CST]]
   return %0 : i32
 }
+
+// Check that the option to control region simplification actually works.
+// CHECK-LABEL: func @test_region_simplify
+func.func @test_region_simplify() {
+  // CHECK-NEXT: return
+  // NO-RS-NEXT: ^bb1
+  // NO-RS-NEXT: return
+  // CHECK-NEXT: }
+  return
+^bb1:
+  return
+}