diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -125,6 +125,36 @@ pattern. This will signal to the pattern driver that recursive application of this pattern may happen, and the pattern is equipped to safely handle it. +### Debug Names and Labels + +To aid in debugging, patterns may specify: a debug name (via `setDebugName`), +which should correspond to an identifier that uniquely identifies the specific +pattern; and a set of debug labels (via `addDebugLabels`), which correspond to +identifiers that uniquely identify groups of patterns. This information is used +by various utilities to aid in the debugging of pattern rewrites, e.g. in debug +logs, to provide pattern filtering, etc. A simple code example is shown below: + +```c++ +class MyPattern : public RewritePattern { +public: + /// Inherit constructors from RewritePattern. + using RewritePattern::RewritePattern; + + void initialize() { + setDebugName("MyPattern"); + addDebugLabels("MyRewritePass"); + } + + // ... +}; + +void populateMyPatterns(RewritePatternSet &patterns, MLIRContext *ctx) { + // Debug labels may also be attached to patterns during insertion. This allows + // for easily attaching common labels to groups of patterns. + patterns.addWithLabel("MyRewritePatterns", ctx); +} +``` + ### Initialization Several pieces of pattern state require explicit initialization by the pattern, @@ -311,3 +341,90 @@ Note: This driver is the one used by the [canonicalization](Canonicalization.md) [pass](Passes.md/#-canonicalize-canonicalize-operations) in MLIR. + +## Debugging + +### Pattern Filtering + +To simplify test case definition and reduction, the `FrozenRewritePatternSet` +class provides built-in support for filtering which patterns should be provided +to the pattern driver for application. 
Filtering behavior is specified by +providing a `disabledPatterns` and `enabledPatterns` list when constructing the +`FrozenRewritePatternSet`. The `disabledPatterns` list should contain a set of +debug names or labels for patterns that are disabled during pattern application, +i.e. which patterns should be filtered out. The `enabledPatterns` list should +contain a set of debug names or labels for patterns that are enabled during +pattern application; patterns that do not satisfy this constraint are filtered +out. Note that patterns specified by the `disabledPatterns` list will be +filtered out even if they match criteria in the `enabledPatterns` list. An +example is shown below: + +```c++ +void MyPass::initialize(MLIRContext *context) { + // No patterns are explicitly disabled. + SmallVector disabledPatterns; + // Enable only patterns with a debug name or label of `MyRewritePatterns`. + SmallVector enabledPatterns(1, "MyRewritePatterns"); + + RewritePatternSet rewritePatterns(context); + // ... + frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns, + enabledPatterns); +} +``` + +### Common Pass Utilities + +Passes that utilize rewrite patterns should aim to provide a common set of +options and toggles to simplify the debugging experience when switching between +different passes/projects/etc. To aid in this endeavor, MLIR provides a common +set of utilities that can be easily included when defining a custom pass. These +are defined in `mlir/Rewrite/PassUtil.td`; an example usage is shown below: + +```tablegen +def MyRewritePass : Pass<"..."> { + let summary = "..."; + let constructor = "createMyRewritePass()"; + + // Inherit the common pattern rewrite options from `RewritePassUtils`. + let options = RewritePassUtils.options; +} +``` + +#### Rewrite Pass Options + +This section documents common pass options that are useful for controlling the +behavior of rewrite pattern application. 
+ +##### Pattern Filtering + +Two common pattern filtering options are exposed, `disable-patterns` and +`enable-patterns`, matching the behavior of the `disabledPatterns` and +`enabledPatterns` lists described in the [Pattern Filtering](#pattern-filtering) +section above. A snippet of the tablegen definition of these options is shown +below: + +```tablegen +ListOption<"disabledPatterns", "disable-patterns", "std::string", + "Labels of patterns that should be filtered out during application", + "llvm::cl::MiscFlags::CommaSeparated">, +ListOption<"enabledPatterns", "enable-patterns", "std::string", + "Labels of patterns that should be used during application, all " + "other patterns are filtered out", + "llvm::cl::MiscFlags::CommaSeparated">, +``` + +These options may be used to provide filtering behavior when constructing any +`FrozenRewritePatternSet`s within the pass: + +```c++ +void MyRewritePass::initialize(MLIRContext *context) { + RewritePatternSet rewritePatterns(context); + // ... + + // When constructing the `FrozenRewritePatternSet`, we provide the filter + // list options. + frozenPatterns = FrozenRewritePatternSet(rewritePatterns, disabledPatterns, + enabledPatterns); +} +``` diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -133,13 +133,23 @@ return contextAndHasBoundedRecursion.getPointer(); } - /// Return readable pattern name. Should only be used for debugging purposes. - /// Can be empty. + /// Return a readable name for this pattern. This name should only be used for + /// debugging purposes, and may be empty. StringRef getDebugName() const { return debugName; } - /// Set readable pattern name. Should only be used for debugging purposes. + /// Set the human readable debug name used for this pattern. This name will + /// only be used for debugging purposes. 
void setDebugName(StringRef name) { debugName = name; } + /// Return the set of debug labels attached to this pattern. + ArrayRef getDebugLabels() const { return debugLabels; } + + /// Add the provided debug labels to this pattern. + void addDebugLabels(ArrayRef labels) { + debugLabels.append(labels.begin(), labels.end()); + } + void addDebugLabels(StringRef label) { debugLabels.push_back(label); } + protected: /// This class acts as a special tag that makes the desire to match "any" /// operation type explicit. This helps to avoid unnecessary usages of this @@ -211,8 +221,11 @@ /// an op with this pattern. SmallVector generatedOps; - /// Readable pattern name. Can be empty. + /// A readable name for this pattern. May be empty. StringRef debugName; + + /// The set of debug labels attached to this pattern. + SmallVector debugLabels; }; //===----------------------------------------------------------------------===// @@ -906,7 +919,26 @@ // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list{0, (addImpl(arg, args...), 0)...}; + (void)std::initializer_list{ + 0, (addImpl(/*debugLabels=*/llvm::None, arg, args...), 0)...}; + return *this; + } + /// An overload of the above `add` method that allows for attaching a set + /// of debug labels to the attached patterns. This is useful for labeling + /// groups of patterns that may be shared between multiple different + /// passes/users. + template > + RewritePatternSet &addWithLabel(ArrayRef debugLabels, + ConstructorArg &&arg, + ConstructorArgs &&... args) { + // The following expands a call to emplace_back for each of the pattern + // types 'Ts'. This magic is necessary due to a limitation in the places + // that a parameter pack can be expanded in c++11. + // FIXME: In c++17 this can be simplified by using 'fold expressions'. 
+ (void)std::initializer_list{ + 0, (addImpl(debugLabels, arg, args...), 0)...}; return *this; } @@ -970,7 +1002,8 @@ // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list{0, (addImpl(arg, args...), 0)...}; + (void)std::initializer_list{ + 0, (addImpl(/*debugLabels=*/llvm::None, arg, args...), 0)...}; return *this; } @@ -1024,13 +1057,17 @@ /// chaining insertions. template std::enable_if_t::value> - addImpl(Args &&... args) { - nativePatterns.emplace_back( - RewritePattern::create(std::forward(args)...)); + addImpl(ArrayRef debugLabels, Args &&... args) { + std::unique_ptr pattern = + RewritePattern::create(std::forward(args)...); + pattern->addDebugLabels(debugLabels); + nativePatterns.emplace_back(std::move(pattern)); } template std::enable_if_t::value> - addImpl(Args &&... args) { + addImpl(ArrayRef debugLabels, Args &&... args) { + // TODO: Add the provided labels to the PDL pattern when PDL supports + // labels. pdlPatterns.mergeIn(T(std::forward(args)...)); } diff --git a/mlir/include/mlir/Pass/PassOptions.h b/mlir/include/mlir/Pass/PassOptions.h --- a/mlir/include/mlir/Pass/PassOptions.h +++ b/mlir/include/mlir/Pass/PassOptions.h @@ -181,7 +181,13 @@ return *this; } - MutableArrayRef operator->() const { return &*this; } + /// Allow accessing the data held by this option. + MutableArrayRef operator*() { + return static_cast &>(*this); + } + ArrayRef operator*() const { + return static_cast &>(*this); + } private: /// Return the main option instance. @@ -189,6 +195,11 @@ /// Print the name and value of this option to the given stream. void print(raw_ostream &os) final { + // Don't print the list if empty. An empty option value can be treated as + // an element of the list in certain cases (e.g. ListOption). 
+ if ((**this).empty()) + return; + os << this->ArgStr << '='; auto printElementFn = [&](const DataType &value) { printValue(os, this->getParser(), value); diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternSet.h @@ -29,9 +29,7 @@ using OpSpecificNativePatternListT = DenseMap>; - /// Freeze the patterns held in `patterns`, and take ownership. FrozenRewritePatternSet(); - FrozenRewritePatternSet(RewritePatternSet &&patterns); FrozenRewritePatternSet(FrozenRewritePatternSet &&patterns) = default; FrozenRewritePatternSet(const FrozenRewritePatternSet &patterns) = default; FrozenRewritePatternSet & @@ -40,6 +38,16 @@ operator=(FrozenRewritePatternSet &&patterns) = default; ~FrozenRewritePatternSet(); + /// Freeze the patterns held in `patterns`, and take ownership. + /// `disabledPatternLabels` is a set of labels used to filter out input + /// patterns with a label in this set. `enabledPatternLabels` is a set of + /// labels used to filter out input patterns that do not have one of the + /// labels in this set. + FrozenRewritePatternSet( + RewritePatternSet &&patterns, + ArrayRef disabledPatternLabels = llvm::None, + ArrayRef enabledPatternLabels = llvm::None); + /// Return the op specific native patterns held by this list. const OpSpecificNativePatternListT &getOpSpecificNativePatterns() const { return impl->nativeOpSpecificPatternMap; diff --git a/mlir/include/mlir/Rewrite/PassUtil.td b/mlir/include/mlir/Rewrite/PassUtil.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Rewrite/PassUtil.td @@ -0,0 +1,36 @@ +//===-- PassUtil.td - Utilities for rewrite passes ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains several utilities for passes that utilize rewrite +// patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_PASSUTIL_TD_ +#define MLIR_REWRITE_PASSUTIL_TD_ + +include "mlir/Pass/PassBase.td" + +def RewritePassUtils { + // A set of options commonly used for pattern rewrites. + list