diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -27,9 +27,10 @@
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   Canonicalizer() = default;
   Canonicalizer(const GreedyRewriteConfig &config,
                 ArrayRef<std::string> disabledPatterns,
-                ArrayRef<std::string> enabledPatterns) {
+                ArrayRef<std::string> enabledPatterns)
+      : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
@@ -41,30 +42,36 @@
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
+    // Set the config from possible pass options set in the meantime.
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.enableRegionSimplification = enableRegionSimplification;
+    config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
       op.getCanonicalizationPatterns(owningPatterns, context);
 
-    patterns = FrozenRewritePatternSet(std::move(owningPatterns),
-                                       disabledPatterns, enabledPatterns);
+    patterns = std::make_shared<FrozenRewritePatternSet>(
+        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
 
   void runOnOperation() override {
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
   }
 
-  FrozenRewritePatternSet patterns;
+  /// Greedy rewrite configuration: seeded by the constructor, then refreshed
+  /// from the pass options in initialize().
+  GreedyRewriteConfig config;
+  /// Patterns built once in initialize(). Copies of this pass share the same
+  /// immutable set through the (implicitly copied) shared_ptr.
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
 };
 } // namespace