diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -18,34 +18,52 @@ /// This class represents a frozen set of patterns that can be processed by a /// pattern applicator. This class is designed to enable caching pattern lists -/// such that they need not be continuously recomputed. +/// such that they need not be continuously recomputed. Note that all copies of +/// this class share the same compiled pattern list, allowing for a reduction in +/// the number of duplicated patterns that need to be created. class FrozenRewritePatternList { using NativePatternListT = std::vector>; public: /// Freeze the patterns held in `patterns`, and take ownership. + FrozenRewritePatternList(); FrozenRewritePatternList(OwningRewritePatternList &&patterns); - FrozenRewritePatternList(FrozenRewritePatternList &&patterns); + FrozenRewritePatternList(FrozenRewritePatternList &&patterns) = default; + FrozenRewritePatternList(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(const FrozenRewritePatternList &patterns) = default; + FrozenRewritePatternList & + operator=(FrozenRewritePatternList &&patterns) = default; ~FrozenRewritePatternList(); /// Return the native patterns held by this list. iterator_range> getNativePatterns() const { + const NativePatternListT &nativePatterns = impl->nativePatterns; return llvm::make_pointee_range(nativePatterns); } /// Return the compiled PDL bytecode held by this list. Returns null if /// there are no PDL patterns within the list. const detail::PDLByteCode *getPDLByteCode() const { - return pdlByteCode.get(); + return impl->pdlByteCode.get(); } private: - /// The set of. - std::vector> nativePatterns; + /// The internal implementation of the frozen pattern list. + struct Impl { + /// The set of native C++ rewrite patterns. + NativePatternListT nativePatterns; - /// The bytecode containing the compiled PDL patterns. - std::unique_ptr pdlByteCode; + /// The bytecode containing the compiled PDL patterns. + std::unique_ptr pdlByteCode; + }; + + /// A pointer to the internal pattern list. This uses a shared_ptr to avoid + /// the need to compile the same pattern list multiple times. For example, + /// during multi-threaded pass execution, all copies of a pass can share the + /// same pattern list. + std::shared_ptr impl; }; } // end namespace mlir diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -50,12 +50,16 @@ // FrozenRewritePatternList //===----------------------------------------------------------------------===// +FrozenRewritePatternList::FrozenRewritePatternList() + : impl(std::make_shared()) {} + FrozenRewritePatternList::FrozenRewritePatternList( OwningRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.getNativePatterns())) { - PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); + : impl(std::make_shared()) { + impl->nativePatterns = std::move(patterns.getNativePatterns()); // Generate the bytecode for the PDL patterns if any were provided. + PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); ModuleOp pdlModule = pdlPatterns.getModule(); if (!pdlModule) return; @@ -64,14 +68,9 @@ "failed to lower PDL pattern module to the PDL Interpreter"); // Generate the pdl bytecode. - pdlByteCode = std::make_unique( + impl->pdlByteCode = std::make_unique( pdlModule, pdlPatterns.takeConstraintFunctions(), pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); } -FrozenRewritePatternList::FrozenRewritePatternList( - FrozenRewritePatternList &&patterns) - : nativePatterns(std::move(patterns.nativePatterns)), - pdlByteCode(std::move(patterns.pdlByteCode)) {} - FrozenRewritePatternList::~FrozenRewritePatternList() {}