diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -125,6 +125,37 @@ pattern. This will signal to the pattern driver that recursive application of this pattern may happen, and the pattern is equipped to safely handle it. +### Initialization + +Several pieces of pattern state require explicit initialization by the pattern, +for example calling `setHasBoundedRewriteRecursion` if a pattern safely handles +recursive application. This pattern state can be initialized either in the +constructor of the pattern or via the utility `initialize` hook. Using the +`initialize` hook removes the need to redefine pattern constructors just to +inject additional pattern state initialization. The hook is detected from within `RewritePattern`, so it must be accessible there (e.g. declared `public`); otherwise it is silently skipped. An example is shown below: + +```c++ +class MyPattern : public RewritePattern { +public: + /// Inherit the constructors from RewritePattern. + using RewritePattern::RewritePattern; + + /// Initialize the pattern. + void initialize() { + // Signal that this pattern safely handles recursive application. + setHasBoundedRewriteRecursion(); + } + + // ... +}; +``` + +### Construction + +Constructing a RewritePattern should be performed by using the static +`RewritePattern::create` utility method. This method ensures that the pattern +is properly initialized and prepared for insertion into a `RewritePatternSet`. + ## Pattern Rewriter A `PatternRewriter` is a special class that allows for a pattern to communicate diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -255,10 +255,43 @@ return failure(); } + /// This method provides a convenient interface for creating and initializing + /// derived rewrite patterns of the given type `T`. + template + static std::unique_ptr create(Args &&... 
args) { + std::unique_ptr pattern = + std::make_unique(std::forward(args)...); + initializePattern(*pattern); + + // Set a default debug name if one wasn't provided. + if (pattern->getDebugName().empty()) + pattern->setDebugName(llvm::getTypeName()); + return pattern; + } + protected: /// Inherit the base constructors from `Pattern`. using Pattern::Pattern; +private: + /// Trait to check if T provides an `initialize` method. + template + using has_initialize = decltype(std::declval().initialize()); + template + using detect_has_initialize = llvm::is_detected; + + /// Initialize the derived pattern by calling its `initialize` method. + template + static std::enable_if_t::value> + initializePattern(T &pattern) { + pattern.initialize(); + } + /// Empty derived pattern initializer for patterns that do not have an + /// initialize method. + template + static std::enable_if_t::value> + initializePattern(T &) {} + /// An anchor for the virtual table. virtual void anchor(); }; @@ -992,13 +1025,8 @@ template std::enable_if_t::value> addImpl(Args &&... args) { - auto pattern = std::make_unique(std::forward(args)...); - - // Pattern can potentially set name in ctor. Preserve old name if present. 
- if (pattern->getDebugName().empty()) - pattern->setDebugName(llvm::getTypeName()); - - nativePatterns.emplace_back(std::move(pattern)); + nativePatterns.emplace_back( + RewritePattern::create(std::forward(args)...)); } template std::enable_if_t::value> diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -935,8 +935,9 @@ class VectorInsertStridedSliceOpSameRankRewritePattern : public OpRewritePattern { public: - VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx) - : OpRewritePattern(ctx) { + using OpRewritePattern::OpRewritePattern; + + void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is // bounded as the rank is strictly decreasing. setHasBoundedRewriteRecursion(); @@ -1330,8 +1331,9 @@ class VectorExtractStridedSliceOpConversion : public OpRewritePattern { public: - VectorExtractStridedSliceOpConversion(MLIRContext *ctx) - : OpRewritePattern(ctx) { + using OpRewritePattern::OpRewritePattern; + + void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion // is bounded as the rank is strictly decreasing. setHasBoundedRewriteRecursion(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -473,8 +473,9 @@ /// bounded recursion. struct TestBoundedRecursiveRewrite : public OpRewritePattern { - TestBoundedRecursiveRewrite(MLIRContext *ctx) - : OpRewritePattern(ctx) { + using OpRewritePattern::OpRewritePattern; + + void initialize() { // The conversion target handles bounding the recursion of this pattern. setHasBoundedRewriteRecursion(); }