diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -62,8 +62,13 @@
 std::unique_ptr<Pass> createCanonicalizerPass();
 
 /// Creates an instance of the Canonicalizer pass with the specified config.
+/// `disabledPatterns` is a set of labels used to filter out input patterns with
+/// a label in this set. `enabledPatterns` is a set of labels used to filter out
+/// input patterns that do not have one of the labels in this set.
 std::unique_ptr<Pass>
-createCanonicalizerPass(const GreedyRewriteConfig &config);
+createCanonicalizerPass(const GreedyRewriteConfig &config,
+                        ArrayRef<std::string> disabledPatterns = llvm::None,
+                        ArrayRef<std::string> enabledPatterns = llvm::None);
 
 /// Creates a pass to perform common sub expression elimination.
 std::unique_ptr<Pass> createCSEPass();
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -21,7 +21,13 @@
 namespace {
 /// Canonicalize operations in nested regions.
 struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
-  Canonicalizer(const GreedyRewriteConfig &config) : config(config) {}
+  Canonicalizer(const GreedyRewriteConfig &config,
+                ArrayRef<std::string> disabledPatterns,
+                ArrayRef<std::string> enabledPatterns)
+      : config(config) {
+    this->disabledPatterns = disabledPatterns;
+    this->enabledPatterns = enabledPatterns;
+  }
 
   Canonicalizer() {
     // Default constructed Canonicalizer takes its settings from command line
@@ -61,6 +67,10 @@
 
 /// Creates an instance of the Canonicalizer pass with the specified config.
+/// Default arguments are declared in Passes.h and must not be repeated here.
 std::unique_ptr<Pass>
-mlir::createCanonicalizerPass(const GreedyRewriteConfig &config) {
-  return std::make_unique<Canonicalizer>(config);
+mlir::createCanonicalizerPass(const GreedyRewriteConfig &config,
+                              ArrayRef<std::string> disabledPatterns,
+                              ArrayRef<std::string> enabledPatterns) {
+  return std::make_unique<Canonicalizer>(config, disabledPatterns,
+                                         enabledPatterns);
 }