diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -715,18 +715,21 @@
   //===--------------------------------------------------------------------===//
 
   /// This class represents a StringSwitch like class that is useful for parsing
-  /// expected keywords. On construction, it invokes `parseKeyword` and
-  /// processes each of the provided cases statements until a match is hit. The
-  /// provided `ResultT` must be assignable from `failure()`.
+  /// expected keywords. On construction, unless a non-empty keyword is
+  /// provided, it invokes `parseKeywordOrCompletion` and processes each of the
+  /// provided cases statements until a match is hit. The provided `ResultT`
+  /// must be assignable from `failure()`.
   template <typename ResultT = ParseResult>
   class KeywordSwitch {
   public:
-    KeywordSwitch(AsmParser &parser)
+    KeywordSwitch(AsmParser &parser, StringRef *keyword = nullptr)
         : parser(parser), loc(parser.getCurrentLocation()) {
-      if (failed(parser.parseKeywordOrCompletion(&keyword)))
+      if (keyword && !keyword->empty())
+        this->keyword = *keyword;
+      else if (failed(parser.parseKeywordOrCompletion(&this->keyword)))
         result = failure();
     }
 
     /// Case that uses the provided value when true.
     KeywordSwitch &Case(StringLiteral str, ResultT value) {
       return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -29,7 +29,8 @@
   Canonicalizer() = default;
   Canonicalizer(const GreedyRewriteConfig &config,
                 ArrayRef<std::string> disabledPatterns,
-                ArrayRef<std::string> enabledPatterns) {
+                ArrayRef<std::string> enabledPatterns)
+      : config(config) {
     this->topDownProcessingEnabled = config.useTopDownTraversal;
     this->enableRegionSimplification = config.enableRegionSimplification;
     this->maxIterations = config.maxIterations;
@@ -41,30 +42,36 @@
   /// Initialize the canonicalizer by building the set of patterns used during
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
+    // Set the config from possible pass options set in the meantime.
+    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.enableRegionSimplification = enableRegionSimplification;
+    config.maxIterations = maxIterations;
+    config.maxNumRewrites = maxNumRewrites;
+
     RewritePatternSet owningPatterns(context);
     for (auto *dialect : context->getLoadedDialects())
       dialect->getCanonicalizationPatterns(owningPatterns);
     for (RegisteredOperationName op : context->getRegisteredOperations())
       op.getCanonicalizationPatterns(owningPatterns, context);
 
-    patterns = FrozenRewritePatternSet(std::move(owningPatterns),
-                                       disabledPatterns, enabledPatterns);
+    patterns = std::make_shared<FrozenRewritePatternSet>(
+        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
 
   void runOnOperation() override {
-    GreedyRewriteConfig config;
-    config.useTopDownTraversal = topDownProcessingEnabled;
-    config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
-    config.maxNumRewrites = maxNumRewrites;
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
   }
 
-  FrozenRewritePatternSet patterns;
+  /// Greedy driver config: seeded by the constructor, then refreshed from the
+  /// pass options in `initialize`.
+  GreedyRewriteConfig config;
+  /// Canonicalization patterns frozen in `initialize` and applied in
+  /// `runOnOperation`.
+  std::shared_ptr<const FrozenRewritePatternSet> patterns;
 };
 } // namespace