diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
--- a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
+++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h
@@ -40,9 +40,13 @@
 
   /// Freeze the patterns held in `patterns`, and take ownership.
   /// `disabledPatternLabels` is a set of labels used to filter out input
-  /// patterns with a label in this set. `enabledPatternLabels` is a set of
-  /// labels used to filter out input patterns that do not have one of the
-  /// labels in this set.
+  /// patterns with a debug label or debug name in this set.
+  /// `enabledPatternLabels` is a set of labels used to filter out input
+  /// patterns that do not have one of the labels in this set. Debug labels must
+  /// be set explicitly on patterns or when adding them with
+  /// `RewritePatternSet::addWithLabel`. Debug names may be empty, but patterns
+  /// created with `RewritePattern::create` have their default debug name set to
+  /// their type name.
   FrozenRewritePatternSet(
       RewritePatternSet &&patterns,
       ArrayRef<std::string> disabledPatternLabels = llvm::None,
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -62,8 +62,17 @@
 std::unique_ptr<Pass> createCanonicalizerPass();
 
 /// Creates an instance of the Canonicalizer pass with the specified config.
+/// `disabledPatterns` is a set of labels used to filter out input patterns with
+/// a debug label or debug name in this set. `enabledPatterns` is a set of
+/// labels used to filter out input patterns that do not have one of the labels
+/// in this set. Debug labels must be set explicitly on patterns or when adding
+/// them with `RewritePatternSet::addWithLabel`. Debug names may be empty, but
+/// patterns created with `RewritePattern::create` have their default debug name
+/// set to their type name.
std::unique_ptr -createCanonicalizerPass(const GreedyRewriteConfig &config); +createCanonicalizerPass(const GreedyRewriteConfig &config, + ArrayRef disabledPatterns = llvm::None, + ArrayRef enabledPatterns = llvm::None); /// Creates a pass to perform common sub expression elimination. std::unique_ptr createCSEPass(); diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -21,7 +21,13 @@ namespace { /// Canonicalize operations in nested regions. struct Canonicalizer : public CanonicalizerBase { - Canonicalizer(const GreedyRewriteConfig &config) : config(config) {} + Canonicalizer(const GreedyRewriteConfig &config, + ArrayRef disabledPatterns, + ArrayRef enabledPatterns) + : config(config) { + this->disabledPatterns = disabledPatterns; + this->enabledPatterns = enabledPatterns; + } Canonicalizer() { // Default constructed Canonicalizer takes its settings from command line @@ -61,6 +67,9 @@ /// Creates an instance of the Canonicalizer pass with the specified config. std::unique_ptr -mlir::createCanonicalizerPass(const GreedyRewriteConfig &config) { - return std::make_unique(config); +createCanonicalizerPass(const GreedyRewriteConfig &config, + ArrayRef disabledPatterns = llvm::None, + ArrayRef enabledPatterns = llvm::None) { + return std::make_unique(config, disabledPatterns, + enabledPatterns); }