diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -108,7 +108,7 @@ ```mlir // Apply `myConstraint` to the entities defined by `input`, `attr`, and // `op`. - pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + pdl_interp.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) -> ^matchDest, ^failureDest ``` }]; @@ -316,7 +316,7 @@ Example: ```mlir - pdl_interp.check_type %type is 0 -> ^matchDest, ^failureDest + pdl_interp.check_type %type is i32 -> ^matchDest, ^failureDest ``` }]; @@ -338,7 +338,7 @@ Example: ```mlir - pdl_interp.create_attribute 10 : i64 + %attr = pdl_interp.create_attribute 10 : i64 ``` }]; @@ -369,7 +369,7 @@ Example: ```mlir - %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1) : !pdl.attribute + %ret = pdl_interp.create_native "myNativeFunc"[42, "gt"](%arg0, %arg1 : !pdl.value, !pdl.value) : !pdl.attribute ``` }]; @@ -771,7 +771,7 @@ Example: ```mlir - pdl_interp.switch_attribute %attr to [10, true] -> ^10Dest, ^trueDest, ^defaultDest + pdl_interp.switch_attribute %attr to [10, true](^10Dest, ^trueDest) -> ^defaultDest ``` }]; let arguments = (ins PDL_Attribute:$attribute, ArrayAttr:$caseValues); @@ -836,7 +836,7 @@ Example: ```mlir - pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"] -> ^fooDest, ^barDest, ^defaultDest + pdl_interp.switch_operation_name of %op to ["foo.op", "bar.op"](^fooDest, ^barDest) -> ^defaultDest ``` }]; @@ -873,7 +873,7 @@ Example: ```mlir - pdl_interp.switch_result_count of %op to [0, 2] -> ^0Dest, ^2Dest, ^defaultDest + pdl_interp.switch_result_count of %op to [0, 2](^0Dest, ^2Dest) -> ^defaultDest ``` }]; diff --git a/mlir/include/mlir/IR/BlockSupport.h 
b/mlir/include/mlir/IR/BlockSupport.h --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -58,6 +58,7 @@ SuccessorRange, BlockOperand *, Block *, Block *, Block *> { public: using RangeBaseT::RangeBaseT; + SuccessorRange(); SuccessorRange(Block *block); SuccessorRange(Operation *term); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -66,6 +66,9 @@ /// otherwise false. bool isRegistered() { return getAbstractOperation(); } + /// Remove the operation from its parent block, but don't delete it. + void removeFromParent(); + /// Remove this operation from its parent block and delete it. void erase(); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -311,7 +311,7 @@ void *getAsOpaquePointer() const { return static_cast(representation.getOpaqueValue()); } - static OperationName getFromOpaquePointer(void *pointer); + static OperationName getFromOpaquePointer(const void *pointer); private: RepresentationUnion representation; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -10,6 +10,7 @@ #define MLIR_PATTERNMATCHER_H #include "mlir/IR/Builders.h" +#include "mlir/IR/Module.h" namespace mlir { @@ -225,6 +226,189 @@ } }; +//===----------------------------------------------------------------------===// +// PDLPatternModule +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// PDLValue + +/// Storage type of byte-code interpreter values. These are passed to native +/// constraint and creation functions as arguments.
+class PDLValue { + /// The internal implementation type when the value is an Attribute, + /// Operation*, or Type. See `impl` below for more details. + using AttrOpTypeImplT = llvm::PointerUnion; + +public: + PDLValue(const PDLValue &other) : impl(other.impl) {} + PDLValue(std::nullptr_t = nullptr) : impl() {} + PDLValue(Attribute value) : impl(value) {} + PDLValue(Operation *value) : impl(value) {} + PDLValue(Type value) : impl(value) {} + PDLValue(Value value) : impl(value) {} + + /// Returns true if the type of the held value is `T`. + template + std::enable_if_t::value, bool> isa() const { + return impl.is(); + } + template + std::enable_if_t::value, bool> isa() const { + auto attrOpTypeImpl = impl.dyn_cast(); + return attrOpTypeImpl && attrOpTypeImpl.is(); + } + + /// Attempt to dynamically cast this value to type `T`, returns null if this + /// value is not an instance of `T`. + template + std::enable_if_t::value, T> dyn_cast() const { + return impl.dyn_cast(); + } + template + std::enable_if_t::value, T> dyn_cast() const { + auto attrOpTypeImpl = impl.dyn_cast(); + return attrOpTypeImpl ? attrOpTypeImpl.dyn_cast() : T(); + } + + /// Cast this value to type `T`, asserts if this value is not an instance of + /// `T`. + template + std::enable_if_t::value, T> cast() const { + return impl.get(); + } + template + std::enable_if_t::value, T> cast() const { + return impl.get().get(); + } + + /// Get an opaque pointer to the value. + void *getAsOpaquePointer() { return impl.getOpaqueValue(); } + + /// Print this value to the provided output stream. + void print(raw_ostream &os); + +private: + /// The internal opaque representation of a PDLValue. We use a nested + /// PointerUnion structure here because `Value` only has 1 low bit + /// available, whereas the remaining types all have 3.
+ llvm::PointerUnion impl; +}; + +inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { + value.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// PDLPatternModule + +/// A generic PDL pattern constraint function. This function applies a +/// constraint to a given set of opaque PDLValue entities. The second parameter +/// is a set of constant value parameters specified in Attribute form. Returns +/// success if the constraint successfully held, failure otherwise. +using PDLConstraintFunction = std::function, ArrayAttr, PatternRewriter &)>; +/// A generic PDL pattern constraint function. This function applies a +/// constraint to a given opaque PDLValue entity. The second parameter is a set +/// of constant value parameters specified in Attribute form. Returns success if +/// the constraint successfully held, failure otherwise. +using PDLSingleEntityConstraintFunction = + std::function; +/// A native PDL creation function. This function creates a new PDLValue given +/// a set of existing PDL values, a set of constant parameters specified in +/// Attribute form, and a PatternRewriter. Returns the newly created PDLValue. +using PDLCreateFunction = + std::function, ArrayAttr, PatternRewriter &)>; +/// A native PDL rewrite function. This function rewrites the given root +/// operation using the provided PatternRewriter. This method is only invoked +/// when the corresponding match was successful. +using PDLRewriteFunction = std::function, + ArrayAttr, PatternRewriter &)>; + +/// This class contains all of the necessary data for a set of PDL patterns, or +/// pattern rewrites specified in the form of the PDL dialect. The PDL module +/// contained by this pattern may contain any number of `pdl.pattern` +/// operations. +class PDLPatternModule { +public: + PDLPatternModule() = default; + + /// Construct a PDL pattern with the given module.
+ PDLPatternModule(OwningModuleRef pdlModule) + : pdlModule(std::move(pdlModule)) {} + + /// Merge the state in `other` into this pattern module. + void mergeIn(PDLPatternModule &&other); + + /// Return the internal PDL module of this pattern. + ModuleOp getModule() { return pdlModule.get(); } + + //===--------------------------------------------------------------------===// + // Function Registry + + /// Register a constraint function. + void registerConstraintFunction(StringRef name, + PDLConstraintFunction constraintFn); + /// Register a single entity constraint function. + template + std::enable_if_t, + ArrayAttr, PatternRewriter &>::value> + registerConstraintFunction(StringRef name, SingleEntityFn &&constraintFn) { + registerConstraintFunction(name, [=](ArrayRef values, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + assert(values.size() == 1 && "expected values to have a single entity"); + return constraintFn(values[0], constantParams, rewriter); + }); + } + + /// Register a creation function. + void registerCreateFunction(StringRef name, PDLCreateFunction createFn); + + /// Register a rewrite function. + void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); + + /// Return the set of the registered constraint functions. + const llvm::StringMap &getConstraintFunctions() const { + return constraintFunctions; + } + llvm::StringMap takeConstraintFunctions() { + return std::move(constraintFunctions); + } + /// Return the set of the registered create functions. + const llvm::StringMap &getCreateFunctions() const { + return createFunctions; + } + llvm::StringMap takeCreateFunctions() { + return std::move(createFunctions); + } + /// Return the set of the registered rewrite functions. + const llvm::StringMap &getRewriteFunctions() const { + return rewriteFunctions; + } + llvm::StringMap takeRewriteFunctions() { + return std::move(rewriteFunctions); + } + + /// Clear out the patterns and functions within this module.
+ void clear() { + pdlModule = nullptr; + constraintFunctions.clear(); + createFunctions.clear(); + rewriteFunctions.clear(); + } + +private: + /// The module containing the `pdl.pattern` operations. + OwningModuleRef pdlModule; + + /// The external functions referenced from within the pdl module. + llvm::StringMap constraintFunctions; + llvm::StringMap createFunctions; + llvm::StringMap rewriteFunctions; +}; + //===----------------------------------------------------------------------===// // PatternRewriter //===----------------------------------------------------------------------===// @@ -421,28 +605,28 @@ //===----------------------------------------------------------------------===// class OwningRewritePatternList { - using PatternListT = std::vector>; + using NativePatternListT = std::vector>; public: OwningRewritePatternList() = default; - /// Construct a OwningRewritePatternList populated with the pattern `t` of - /// type `T`. - template - OwningRewritePatternList(T &&t) { - patterns.emplace_back(std::make_unique(std::forward(t))); + /// Construct an OwningRewritePatternList populated with the given pattern. + OwningRewritePatternList(std::unique_ptr pattern) { + nativePatterns.emplace_back(std::move(pattern)); } + OwningRewritePatternList(PDLPatternModule &&pattern) + : pdlPatterns(std::move(pattern)) {} + + /// Return the native patterns held in this list. + NativePatternListT &getNativePatterns() { return nativePatterns; } - PatternListT::iterator begin() { return patterns.begin(); } - PatternListT::iterator end() { return patterns.end(); } - PatternListT::const_iterator begin() const { return patterns.begin(); } - PatternListT::const_iterator end() const { return patterns.end(); } - PatternListT::size_type size() const { return patterns.size(); } - void clear() { patterns.clear(); } + /// Return the PDL patterns held in this list. + PDLPatternModule &getPDLPatterns() { return pdlPatterns; } - /// Take ownership of the patterns held by this list.
- std::vector> takePatterns() { - return std::move(patterns); + /// Clear out all of the held patterns in this list. + void clear() { + nativePatterns.clear(); + pdlPatterns.clear(); } //===--------------------------------------------------------------------===// @@ -456,31 +640,53 @@ typename... ConstructorArgs, typename = std::enable_if_t> OwningRewritePatternList &insert(ConstructorArg &&arg, - ConstructorArgs &&... args) { + ConstructorArgs &&...args) { // The following expands a call to emplace_back for each of the pattern // types 'Ts'. This magic is necessary due to a limitation in the places // that a parameter pack can be expanded in c++11. // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list{ - 0, (patterns.emplace_back(std::make_unique(arg, args...)), 0)...}; + (void)std::initializer_list{0, (insertImpl(arg, args...), 0)...}; return *this; } /// Add an instance of each of the pattern types 'Ts'. Return a reference to /// `this` for chaining insertions. template OwningRewritePatternList &insert() { - (void)std::initializer_list{ - 0, (patterns.emplace_back(std::make_unique()), 0)...}; + (void)std::initializer_list{0, (insertImpl(), 0)...}; return *this; } - /// Add the given pattern to the pattern list. - void insert(std::unique_ptr pattern) { - patterns.emplace_back(std::move(pattern)); + /// Add the given native pattern to the pattern list. Return a reference to + /// `this` for chaining insertions. + OwningRewritePatternList &insert(std::unique_ptr pattern) { + nativePatterns.emplace_back(std::move(pattern)); + return *this; + } + + /// Add the given pdl pattern to the pattern list. Return a reference to + /// `this` for chaining insertions. + OwningRewritePatternList &insert(PDLPatternModule &&pattern) { + pdlPatterns.mergeIn(std::move(pattern)); + return *this; } private: - PatternListT patterns; + /// Add an instance of the pattern type 'T'. Native patterns are appended to + /// `nativePatterns`; PDL pattern modules are merged into `pdlPatterns`.
+ template + std::enable_if_t::value> + insertImpl(Args &&...args) { + nativePatterns.emplace_back( + std::make_unique(std::forward(args)...)); + } + template + std::enable_if_t::value> + insertImpl(Args &&...args) { + pdlPatterns.mergeIn(T(std::forward(args)...)); + } + + NativePatternListT nativePatterns; + PDLPatternModule pdlPatterns; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -104,6 +104,12 @@ return UniquerT::template get(loc.getContext(), args...); } + /// Get an instance of the concrete type from a void pointer. + static ConcreteT getFromOpaquePointer(const void *ptr) { + return ptr ? BaseT::getFromOpaquePointer(ptr).template cast() + : nullptr; + } + protected: /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. diff --git a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h --- a/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h +++ b/mlir/include/mlir/Rewrite/FrozenRewritePatternList.h @@ -12,25 +12,40 @@ #include "mlir/IR/PatternMatch.h" namespace mlir { +namespace detail { +class PDLByteCode; +} // end namespace detail + /// This class represents a frozen set of patterns that can be processed by a /// pattern applicator. This class is designed to enable caching pattern lists /// such that they need not be continuously recomputed. class FrozenRewritePatternList { - using PatternListT = std::vector>; + using NativePatternListT = std::vector>; public: /// Freeze the patterns held in `patterns`, and take ownership. FrozenRewritePatternList(OwningRewritePatternList &&patterns); + FrozenRewritePatternList(FrozenRewritePatternList &&patterns); + ~FrozenRewritePatternList(); + + /// Return the native patterns held by this list.
+ iterator_range> + getNativePatterns() const { + return llvm::make_pointee_range(nativePatterns); + } - /// Return the patterns held by this list. - iterator_range> - getPatterns() const { - return llvm::make_pointee_range(patterns); + /// Return the compiled PDL bytecode held by this list. Returns null if + /// there are no PDL patterns within the list. + const detail::PDLByteCode *getPDLByteCode() const { + return pdlByteCode.get(); } private: - /// The patterns held by this list. - std::vector> patterns; + /// The native patterns held by this list. + std::vector> nativePatterns; + + /// The bytecode containing the compiled PDL patterns. + std::unique_ptr pdlByteCode; }; } // end namespace mlir diff --git a/mlir/include/mlir/Rewrite/PatternApplicator.h b/mlir/include/mlir/Rewrite/PatternApplicator.h --- a/mlir/include/mlir/Rewrite/PatternApplicator.h +++ b/mlir/include/mlir/Rewrite/PatternApplicator.h @@ -19,6 +19,10 @@ namespace mlir { class PatternRewriter; +namespace detail { +class PDLByteCodeMutableState; +} // end namespace detail + /// This class manages the application of a group of rewrite patterns, with a /// user-provided cost model. class PatternApplicator { @@ -29,8 +33,8 @@ /// `impossibleToMatch`. using CostModel = function_ref; - explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList) - : frozenPatternList(frozenPatternList) {} + explicit PatternApplicator(const FrozenRewritePatternList &frozenPatternList); + ~PatternApplicator(); /// Attempt to match and rewrite the given op with any pattern, allowing a /// predicate to decide if a pattern can be applied or not, and hooks for if @@ -60,16 +64,6 @@ void walkAllPatterns(function_ref walk); private: - /// Attempt to match and rewrite the given op with the given pattern, allowing - /// a predicate to decide if a pattern can be applied or not, and hooks for if - /// the pattern match was a success or failure.
- LogicalResult - matchAndRewrite(Operation *op, const RewritePattern &pattern, - PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess); - /// The list that owns the patterns used within this applicator. const FrozenRewritePatternList &frozenPatternList; /// The set of patterns to match for each operation, stable sorted by benefit. @@ -77,6 +71,8 @@ /// The set of patterns that may match against any operation type, stable /// sorted by benefit. SmallVector anyOpPatterns; + /// The mutable state used during execution of the PDL bytecode. + std::unique_ptr mutableByteCodeState; }; } // end namespace mlir diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -285,13 +285,15 @@ // SuccessorRange //===----------------------------------------------------------------------===// -SuccessorRange::SuccessorRange(Block *block) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {} + +SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() { if (Operation *term = block->getTerminator()) if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } -SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange(nullptr, 0) { +SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() { if ((count = term->getNumSuccessors())) base = term->getBlockOperands().data(); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -61,8 +61,9 @@ return representation.dyn_cast(); } -OperationName OperationName::getFromOpaquePointer(void *pointer) { - return OperationName(RepresentationUnion::getFromOpaqueValue(pointer)); +OperationName OperationName::getFromOpaquePointer(const void *pointer) { + return OperationName( + RepresentationUnion::getFromOpaqueValue(const_cast(pointer))); }
//===----------------------------------------------------------------------===// @@ -473,6 +474,12 @@ first->block = curParent; } +/// Remove the operation from its parent block, but don't delete it. +void Operation::removeFromParent() { + if (Block *parent = getBlock()) + parent->getOperations().remove(this); +} + /// Remove this operation (and its descendants) from its Block and delete /// all of them. void Operation::erase() { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -69,6 +69,84 @@ /// Out-of-line vtable anchor. void RewritePattern::anchor() {} +//===----------------------------------------------------------------------===// +// PDLValue +//===----------------------------------------------------------------------===// + +void PDLValue::print(raw_ostream &os) { + if (!impl) { + os << ""; + return; + } + if (Value val = impl.dyn_cast()) { + os << val; + return; + } + AttrOpTypeImplT aotImpl = impl.get(); + if (Attribute attr = aotImpl.dyn_cast()) + os << attr; + else if (Operation *op = aotImpl.dyn_cast()) + os << *op; + else + os << aotImpl.get(); +} + +//===----------------------------------------------------------------------===// +// PDLPatternModule +//===----------------------------------------------------------------------===// + +void PDLPatternModule::mergeIn(PDLPatternModule &&other) { + // Ignore the other module if it has no patterns. + if (!other.pdlModule) + return; + // Steal the other state if we have no patterns. + if (!pdlModule) { + constraintFunctions = std::move(other.constraintFunctions); + createFunctions = std::move(other.createFunctions); + rewriteFunctions = std::move(other.rewriteFunctions); + pdlModule = std::move(other.pdlModule); + return; + } + // Steal the functions of the other module.
+ for (auto &it : other.constraintFunctions) + registerConstraintFunction(it.first(), std::move(it.second)); + for (auto &it : other.createFunctions) + registerCreateFunction(it.first(), std::move(it.second)); + for (auto &it : other.rewriteFunctions) + registerRewriteFunction(it.first(), std::move(it.second)); + + // Merge the pattern operations from the other module into this one. + Block *block = pdlModule->getBody(); + block->getTerminator()->erase(); + block->getOperations().splice(block->end(), + other.pdlModule->getBody()->getOperations()); +} + +//===----------------------------------------------------------------------===// +// Function Registry + +void PDLPatternModule::registerConstraintFunction( + StringRef name, PDLConstraintFunction constraintFn) { + auto it = constraintFunctions.try_emplace(name, std::move(constraintFn)); + (void)it; + assert(it.second && + "constraint with the given name has already been registered"); +} +void PDLPatternModule::registerCreateFunction(StringRef name, + PDLCreateFunction createFn) { + auto it = createFunctions.try_emplace(name, std::move(createFn)); + (void)it; + assert(it.second && "native create function with the given name has " + "already been registered"); +} +void PDLPatternModule::registerRewriteFunction(StringRef name, + PDLRewriteFunction rewriteFn) { + auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); + (void)it; + assert(it.second && "native rewrite function with the given name has " + "already been registered"); +} + //===----------------------------------------------------------------------===// // PatternRewriter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Rewrite/ByteCode.h @@ -0,0 +1,173 @@ +//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares a byte-code and interpreter for pattern rewrites in MLIR. +// The byte-code is constructed from the PDL Interpreter dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_REWRITE_BYTECODE_H_ +#define MLIR_REWRITE_BYTECODE_H_ + +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace pdl_interp { +class RecordMatchOp; +} // end namespace pdl_interp + +namespace detail { +class PDLByteCode; + +/// Use generic bytecode types. ByteCodeField refers to the actual bytecode +/// entries (currently uint16_t). ByteCodeAddr refers to the type of the +/// indices into the bytecode. Correctness is checked with static asserts. +using ByteCodeField = uint16_t; +using ByteCodeAddr = uint32_t; + +//===----------------------------------------------------------------------===// +// PDLByteCodePattern +//===----------------------------------------------------------------------===// + +/// All of the data pertaining to a specific pattern within the bytecode. +class PDLByteCodePattern : public Pattern { +public: + static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, + ByteCodeAddr rewriterAddr); + + /// Return the bytecode address of the rewriter for this pattern. + ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } + +private: + template + PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) + : Pattern(std::forward(patternArgs)...), + rewriterAddr(rewriterAddr) {} + + /// The address of the rewriter for this pattern.
+ ByteCodeAddr rewriterAddr; +}; + +//===----------------------------------------------------------------------===// +// PDLByteCodeMutableState +//===----------------------------------------------------------------------===// + +/// This class contains the mutable state of a bytecode instance. This allows +/// for a bytecode instance to be cached and reused across various different +/// threads/drivers. +class PDLByteCodeMutableState { +public: + /// Initialize the state from a bytecode instance. + void initialize(PDLByteCode &bytecode); + + /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds + /// to the position of the pattern within the range returned by + /// `PDLByteCode::getPatterns`. + void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); + +private: + /// Allow access to data fields. + friend class PDLByteCode; + + /// The mutable block of memory used during the matching and rewriting phases + /// of the bytecode. + std::vector memory; + + /// The up-to-date benefits of the patterns held by the bytecode. The order + /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. + std::vector currentPatternBenefits; +}; + +//===----------------------------------------------------------------------===// +// PDLByteCode +//===----------------------------------------------------------------------===// + +/// The bytecode class is also the interpreter. Contains the bytecode itself, +/// the static info, and addresses of the rewriter functions. The interpreter +/// memory buffer is held separately in `PDLByteCodeMutableState`. +class PDLByteCode { +public: + /// Each successful match returns a MatchResult, which contains information + /// necessary to execute the rewriter and indicates the originating pattern.
+ struct MatchResult { + MatchResult(Location loc, const PDLByteCodePattern &pattern, + PatternBenefit benefit) + : location(loc), pattern(&pattern), benefit(benefit) {} + + /// The fused location of operations to be replaced. + Location location; + /// Memory values defined in the matcher that are passed to the rewriter. + SmallVector values; + /// The originating pattern that was matched. This is always non-null, but + /// represented with a pointer to allow for assignment. + const PDLByteCodePattern *pattern; + /// The current benefit of the pattern that was matched. + PatternBenefit benefit; + }; + + /// Create a ByteCode instance from the given module containing operations in + /// the PDL interpreter dialect. + PDLByteCode(ModuleOp module, + llvm::StringMap constraintFns, + llvm::StringMap createFns, + llvm::StringMap rewriteFns); + + /// Return the patterns held by the bytecode. + ArrayRef getPatterns() const { return patterns; } + + /// Initialize the given state such that it can be used to execute the current + /// bytecode. + void initializeMutableState(PDLByteCodeMutableState &state) const; + + /// Run the pattern matcher on the given root operation, collecting the + /// matched patterns in `matches`. + void match(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl &matches, + PDLByteCodeMutableState &state) const; + + /// Run the rewriter of the given pattern that was previously matched in + /// `match`. + void rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const; + +private: + /// Execute the given byte code starting at the provided instruction `inst`. + /// `matches` is an optional field provided when this function is executed in + /// a matching context. + void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, + PDLByteCodeMutableState &state, + SmallVectorImpl *matches) const; + + /// A vector containing pointers to uniqued data.
The storage is intentionally + /// opaque such that we can store a wide range of data types. The types of + /// data stored here include: + /// * Attribute, Identifier, OperationName, Type + std::vector uniquedData; + + /// A vector containing the generated bytecode for the matcher. + SmallVector matcherByteCode; + + /// A vector containing the generated bytecode for all of the rewriters. + SmallVector rewriterByteCode; + + /// The set of patterns contained within the bytecode. + SmallVector patterns; + + /// A set of user defined functions invoked via PDL. + std::vector constraintFunctions; + std::vector createFunctions; + std::vector rewriteFunctions; + + /// The maximum memory index used by a value. + ByteCodeField maxValueMemoryIndex = 0; +}; + +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_REWRITE_BYTECODE_H_ diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -0,0 +1,1245 @@ +//===- ByteCode.cpp - Pattern ByteCode Interpreter --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements MLIR to byte-code generation and the interpreter.
+// +//===----------------------------------------------------------------------===// + +#include "ByteCode.h" +#include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/RegionGraphTraits.h" +#include "llvm/ADT/IntervalMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "pdl-bytecode" + +using namespace mlir; +using namespace mlir::detail; + +//===----------------------------------------------------------------------===// +// PDLByteCodePattern +//===----------------------------------------------------------------------===// + +PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp, + ByteCodeAddr rewriterAddr) { + SmallVector generatedOps; + if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr()) + generatedOps = + llvm::to_vector<8>(generatedOpsAttr.getAsValueRange()); + + PatternBenefit benefit = matchOp.benefit(); + MLIRContext *ctx = matchOp.getContext(); + + // Check to see if this pattern matches a specific operation type. + if (Optional rootKind = matchOp.rootKind()) + return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit, + ctx); + return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx, + MatchAnyOpTypeTag()); +} + +//===----------------------------------------------------------------------===// +// PDLByteCodeMutableState +//===----------------------------------------------------------------------===// + +/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds +/// to the position of the pattern within the range returned by +/// `PDLByteCode::getPatterns`.
+void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
+                                                   PatternBenefit benefit) {
+  currentPatternBenefits[patternIndex] = benefit;
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode OpCodes
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// The opcodes of the bytecode. Every instruction emitted by the generator
+/// starts with one of these fields, followed by the instruction's operands.
+enum OpCode : ByteCodeField {
+  /// Apply an externally registered constraint.
+  ApplyConstraint,
+  /// Apply an externally registered rewrite.
+  ApplyRewrite,
+  /// Check if two generic values are equal.
+  AreEqual,
+  /// Unconditional branch.
+  Branch,
+  /// Compare the operand count of an operation with a constant.
+  CheckOperandCount,
+  /// Compare the name of an operation with a constant.
+  CheckOperationName,
+  /// Compare the result count of an operation with a constant.
+  CheckResultCount,
+  /// Invoke a native creation method.
+  CreateNative,
+  /// Create an operation.
+  CreateOperation,
+  /// Erase an operation.
+  EraseOp,
+  /// Terminate a matcher or rewrite sequence.
+  Finalize,
+  /// Get a specific attribute of an operation.
+  GetAttribute,
+  /// Get the type of an attribute.
+  GetAttributeType,
+  /// Get the defining operation of a value.
+  GetDefiningOp,
+  /// Get a specific operand of an operation. The numbered forms encode indices
+  /// 0-3 directly in the opcode; `GetOperand` carries the index inline.
+  GetOperand0,
+  GetOperand1,
+  GetOperand2,
+  GetOperand3,
+  GetOperand,
+  /// Get a specific result of an operation. The numbered forms encode indices
+  /// 0-3 directly in the opcode; `GetResult` carries the index inline.
+  GetResult0,
+  GetResult1,
+  GetResult2,
+  GetResult3,
+  GetResult,
+  /// Get the type of a value.
+  GetValueType,
+  /// Check if a generic value is not null.
+  IsNotNull,
+  /// Record a successful pattern match.
+  RecordMatch,
+  /// Replace an operation.
+  ReplaceOp,
+  /// Compare an attribute with a set of constants.
+  SwitchAttribute,
+  /// Compare the operand count of an operation with a set of constants.
+  SwitchOperandCount,
+  /// Compare the name of an operation with a set of constants.
+  SwitchOperationName,
+  /// Compare the result count of an operation with a set of constants.
+  SwitchResultCount,
+  /// Compare a type with a set of constants.
+  SwitchType,
+};
+
+/// The kind of an entry in a PDLValue list. It is written into the bytecode
+/// ahead of each value by `ByteCodeWriter::appendPDLValueList`.
+enum class PDLValueKind { Attribute, Operation, Type, Value };
+} // end anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// ByteCode Generation
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Generator
+
+namespace {
+// Forward declared with the same class-key as its definition below.
+struct ByteCodeWriter;
+
+/// This class represents the main generator for the pattern bytecode.
+class Generator {
+public:
+  /// The referenced containers are the storage of the owning `PDLByteCode`
+  /// and are populated by `generate`. Each externally registered function is
+  /// assigned the index of its position in the iteration order of its map.
+  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
+            SmallVectorImpl<ByteCodeField> &matcherByteCode,
+            SmallVectorImpl<ByteCodeField> &rewriterByteCode,
+            SmallVectorImpl<PDLByteCodePattern> &patterns,
+            ByteCodeField &maxValueMemoryIndex,
+            llvm::StringMap<PDLConstraintFunction> &constraintFns,
+            llvm::StringMap<PDLCreateFunction> &createFns,
+            llvm::StringMap<PDLRewriteFunction> &rewriteFns)
+      : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
+        rewriterByteCode(rewriterByteCode), patterns(patterns),
+        maxValueMemoryIndex(maxValueMemoryIndex) {
+    for (auto it : llvm::enumerate(constraintFns))
+      constraintToMemIndex.try_emplace(it.value().first(), it.index());
+    for (auto it : llvm::enumerate(createFns))
+      nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
+    for (auto it : llvm::enumerate(rewriteFns))
+      externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
+  }
+
+  /// Generate the bytecode for the given PDL interpreter module.
+  void generate(ModuleOp module);
+
+  /// Return the memory index to use for the given value.
+  ByteCodeField &getMemIndex(Value value) {
+    assert(valueToMemIndex.count(value) &&
+           "expected memory index to be assigned");
+    return valueToMemIndex[value];
+  }
+
+  /// Return an index to use when referring to the given data that is uniqued in
+  /// the MLIR context. Uniqued data is placed after all value memory slots,
+  /// i.e. at `maxValueMemoryIndex + <position in uniquedData>`.
+  template <typename T>
+  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
+  getMemIndex(T val) {
+    const void *opaqueVal = val.getAsOpaquePointer();
+
+    // Get or insert a reference to this value.
+    auto it = uniquedDataToMemIndex.try_emplace(
+        opaqueVal, maxValueMemoryIndex + uniquedData.size());
+    if (it.second)
+      uniquedData.push_back(opaqueVal);
+    return it.first->second;
+  }
+
+private:
+  /// Allocate memory indices for the results of operations within the matcher
+  /// and rewriters.
+  void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
+
+  /// Generate the bytecode for the given operation.
+  void generate(Operation *op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
+  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
+
+  /// Mapping from value to its corresponding memory index.
+  DenseMap<Value, ByteCodeField> valueToMemIndex;
+
+  /// Mapping from the name of an externally registered rewrite to its index in
+  /// the bytecode registry.
+  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
+
+  /// Mapping from the name of an externally registered constraint to its index
+  /// in the bytecode registry.
+  llvm::StringMap<ByteCodeField> constraintToMemIndex;
+
+  /// Mapping from the name of an externally registered creation method to its
+  /// index in the bytecode registry.
+  llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
+
+  /// Mapping from rewriter function name to the bytecode address of the
+  /// rewriter function in byte.
+  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
+
+  /// Mapping from a uniqued storage object to its memory index within
+  /// `uniquedData`.
+  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
+
+  /// The current MLIR context.
+  MLIRContext *ctx;
+
+  /// Data of the ByteCode class to be populated.
+  std::vector<const void *> &uniquedData;
+  SmallVectorImpl<ByteCodeField> &matcherByteCode;
+  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
+  SmallVectorImpl<PDLByteCodePattern> &patterns;
+  ByteCodeField &maxValueMemoryIndex;
+};
+
+/// This class provides utilities for writing a bytecode stream.
+struct ByteCodeWriter {
+  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
+      : bytecode(bytecode), generator(generator) {}
+
+  /// Append a field to the bytecode.
+  void append(ByteCodeField field) { bytecode.push_back(field); }
+
+  /// Append an address to the bytecode. An address occupies two fields.
+  void append(ByteCodeAddr field) {
+    static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
+                  "unexpected ByteCode address size");
+
+    ByteCodeField fieldParts[2];
+    std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
+    bytecode.append({fieldParts[0], fieldParts[1]});
+  }
+
+  /// Append a successor range to the bytecode, the exact address will need to
+  /// be resolved later.
+  void append(SuccessorRange successors) {
+    // Add back references to any successors so that the address can be
+    // resolved later. A zero placeholder address is written for now.
+    for (Block *successor : successors) {
+      unresolvedSuccessorRefs[successor].push_back(bytecode.size());
+      append(ByteCodeAddr(0));
+    }
+  }
+
+  /// Append a range of values that will be read as generic PDLValues. The
+  /// layout is the number of values, then a (kind, memory index) pair per
+  /// value.
+  void appendPDLValueList(OperandRange values) {
+    bytecode.push_back(values.size());
+    for (Value value : values) {
+      // Append the type of the value in addition to the value itself.
+      PDLValueKind kind =
+          TypeSwitch<Type, PDLValueKind>(value.getType())
+              .Case<pdl::AttributeType>(
+                  [](Type) { return PDLValueKind::Attribute; })
+              .Case<pdl::OperationType>(
+                  [](Type) { return PDLValueKind::Operation; })
+              .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
+              .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
+      bytecode.push_back(static_cast<ByteCodeField>(kind));
+      append(value);
+    }
+  }
+
+  /// Detection trait: well-formed when `T` provides `getAsOpaquePointer()`,
+  /// i.e. when values of `T` can be stored in a memory slot.
+  template <typename T, typename... Args>
+  using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
+
+  /// Append a value that will be stored in a memory slot and not inline within
+  /// the bytecode.
+  template <typename T>
+  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
+                   std::is_pointer<T>::value>
+  append(T value) {
+    bytecode.push_back(generator.getMemIndex(value));
+  }
+
+  /// Append a range of values, prefixed by the number of elements.
+  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
+  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
+  append(T range) {
+    bytecode.push_back(llvm::size(range));
+    for (auto it : range)
+      append(it);
+  }
+
+  /// Append a variadic number of fields to the bytecode.
+  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
+  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
+    append(field);
+    append(field2, fields...);
+  }
+
+  /// Successor references in the bytecode that have yet to be resolved.
+  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
+
+  /// The underlying bytecode buffer.
+  SmallVectorImpl<ByteCodeField> &bytecode;
+
+  /// The main generator producing the PDL bytecode.
+  Generator &generator;
+};
+} // end anonymous namespace
+
+void Generator::generate(ModuleOp module) {
+  FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
+      pdl_interp::PDLInterpDialect::getMatcherFunctionName());
+  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
+      pdl_interp::PDLInterpDialect::getRewriterModuleName());
+  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
+
+  // Allocate memory indices for the results of operations within the matcher
+  // and rewriters.
+  allocateMemoryIndices(matcherFunc, rewriterModule);
+
+  // Generate code for the rewriter functions. This happens before the matcher
+  // so that `RecordMatch` can refer to the rewriter addresses.
+  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
+    for (Operation &op : rewriterFunc.getOps())
+      generate(&op, rewriterByteCodeWriter);
+  }
+  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
+         "unexpected branches in rewriter function");
+
+  // Generate code for the matcher function.
+  DenseMap<Block *, ByteCodeAddr> blockToAddr;
+  llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
+  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
+  for (Block *block : rpot) {
+    // Keep track of where this block begins within the matcher function.
+    blockToAddr.try_emplace(block, matcherByteCode.size());
+    for (Operation &op : *block)
+      generate(&op, matcherByteCodeWriter);
+  }
+
+  // Resolve successor references in the matcher by patching the placeholder
+  // addresses with the start address of each destination block.
+  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
+    ByteCodeAddr addr = blockToAddr[it.first];
+    for (unsigned offsetToFix : it.second)
+      std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
+  }
+}
+
+void Generator::allocateMemoryIndices(FuncOp matcherFunc,
+                                      ModuleOp rewriterModule) {
+  // Rewriters use simplistic allocation scheme that simply assigns an index to
+  // each result.
+  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
+    ByteCodeField index = 0;
+    for (BlockArgument arg : rewriterFunc.getArguments())
+      valueToMemIndex.try_emplace(arg, index++);
+    rewriterFunc.getBody().walk([&](Operation *op) {
+      for (Value result : op->getResults())
+        valueToMemIndex.try_emplace(result, index++);
+    });
+    if (index > maxValueMemoryIndex)
+      maxValueMemoryIndex = index;
+  }
+
+  // The matcher function uses a more sophisticated numbering that tries to
+  // minimize the number of memory indices assigned. This is done by determining
+  // a live range of the values within the matcher, then the allocation is just
+  // finding the minimal number of overlapping live ranges. This is essentially
+  // a simplified form of register allocation where we don't necessarily have a
+  // limited number of registers, but we still want to minimize the number used.
+  DenseMap<Operation *, ByteCodeField> opToIndex;
+  matcherFunc.getBody().walk([&](Operation *op) {
+    opToIndex.insert(std::make_pair(op, opToIndex.size()));
+  });
+
+  // Liveness info for each of the defs within the matcher.
+  using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
+  LivenessSet::Allocator allocator;
+  DenseMap<Value, LivenessSet> valueDefRanges;
+
+  // Assign the root operation being matched to slot 0.
+  BlockArgument rootOpArg = matcherFunc.getArgument(0);
+  valueToMemIndex[rootOpArg] = 0;
+
+  // Walk each of the blocks, computing the def interval that the value is used.
+  Liveness matcherLiveness(matcherFunc);
+  for (Block &block : matcherFunc.getBody()) {
+    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
+    assert(info && "expected liveness info for block");
+    auto processValue = [&](Value value, Operation *firstUseOrDef) {
+      // We don't need to process the root op argument, this value is always
+      // assigned to the first memory slot.
+      if (value == rootOpArg)
+        return;
+
+      // Set indices for the range of this block that the value is used.
+      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
+      defRangeIt->second.insert(
+          opToIndex[firstUseOrDef],
+          opToIndex[info->getEndOperation(value, firstUseOrDef)],
+          /*dummyValue*/ 0);
+    };
+
+    // Process the live-ins of this block.
+    for (Value liveIn : info->in())
+      processValue(liveIn, &block.front());
+
+    // Process any new defs within this block.
+    for (Operation &op : block)
+      for (Value result : op.getResults())
+        processValue(result, &op);
+  }
+
+  // Greedily allocate memory slots using the computed def live ranges. Slot 0
+  // is reserved for the root, so `allocatedIndices[i]` describes slot `i + 1`.
+  std::vector<LivenessSet> allocatedIndices;
+  for (auto &defIt : valueDefRanges) {
+    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
+    LivenessSet &defSet = defIt.second;
+
+    // Try to allocate to the first existing index whose live range does not
+    // overlap this def.
+    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
+      LivenessSet &existingIndex = existingIndexIt.value();
+      llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
+          defSet, existingIndex);
+      if (overlaps.valid())
+        continue;
+      // Union the range of the def within the existing index.
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = existingIndexIt.index() + 1;
+      // Stop at the first fit. Continuing would also union this def into every
+      // later compatible slot, needlessly growing their live ranges and
+      // increasing the total number of slots required.
+      break;
+    }
+
+    // If no existing index could be used, add a new one.
+    if (memIndex == 0) {
+      allocatedIndices.emplace_back(allocator);
+      for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
+        allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
+      memIndex = allocatedIndices.size();
+    }
+  }
+
+  // Update the max number of indices (+1 accounts for the root in slot 0).
+  ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
+  if (numMatcherIndices > maxValueMemoryIndex)
+    maxValueMemoryIndex = numMatcherIndices;
+}
+
+void Generator::generate(Operation *op, ByteCodeWriter &writer) {
+  TypeSwitch<Operation *>(op)
+      .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
+            pdl_interp::AreEqualOp, pdl_interp::BranchOp,
+            pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
+            pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
+            pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
+            pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
+            pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
+            pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
+            pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
+            pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
+            pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
+            pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
+            pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
+            pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
+            pdl_interp::SwitchOperationNameOp,
+            pdl_interp::SwitchResultCountOp>(
+          [&](auto interpOp) { this->generate(interpOp, writer); })
+      .Default([](Operation *) {
+        llvm_unreachable("unknown `pdl_interp` operation");
+      });
+}
+
+void Generator::generate(pdl_interp::ApplyConstraintOp op,
+                         ByteCodeWriter &writer) {
+  assert(constraintToMemIndex.count(op.name()) &&
+         "expected index for constraint function");
+  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
+                op.constParamsAttr());
+  writer.appendPDLValueList(op.args());
+  writer.append(op.getSuccessors());
+}
+void Generator::generate(pdl_interp::ApplyRewriteOp op,
+                         ByteCodeWriter &writer) {
+  assert(externalRewriterToMemIndex.count(op.name()) &&
+         "expected index for rewrite function");
+  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
+                op.constParamsAttr(), op.root());
+  writer.appendPDLValueList(op.args());
+}
+void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::Branch, SuccessorRange(op));
+}
+void Generator::generate(pdl_interp::CheckAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckOperandCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckOperationNameOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckOperationName, op.operation(),
+                OperationName(op.name(), ctx), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckResultCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::CreateAttributeOp op,
+                         ByteCodeWriter &writer) {
+  // Simply repoint the memory index of the result to the constant.
+  getMemIndex(op.attribute()) = getMemIndex(op.value());
+}
+void Generator::generate(pdl_interp::CreateNativeOp op,
+                         ByteCodeWriter &writer) {
+  assert(nativeCreateToMemIndex.count(op.name()) &&
+         "expected index for creation function");
+  writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
+                op.result(), op.constParamsAttr());
+  writer.appendPDLValueList(op.args());
+}
+void Generator::generate(pdl_interp::CreateOperationOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::CreateOperation, op.operation(),
+                OperationName(op.name(), ctx), op.operands());
+
+  // Add the attributes as a count followed by (name, value) pairs.
+  OperandRange attributes = op.attributes();
+  writer.append(static_cast<ByteCodeField>(attributes.size()));
+  for (auto it : llvm::zip(op.attributeNames(), attributes)) {
+    writer.append(
+        Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
+        std::get<1>(it));
+  }
+  writer.append(op.types());
+}
+void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
+  // Simply repoint the memory index of the result to the constant.
+  getMemIndex(op.result()) = getMemIndex(op.value());
+}
+void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::EraseOp, op.operation());
+}
+void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::Finalize);
+}
+void Generator::generate(pdl_interp::GetAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
+                Identifier::get(op.name(), ctx));
+}
+void Generator::generate(pdl_interp::GetAttributeTypeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetAttributeType, op.result(), op.value());
+}
+void Generator::generate(pdl_interp::GetDefiningOpOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
+  // Small indices are folded into the opcode; larger ones are written inline.
+  uint32_t index = op.index();
+  if (index < 4)
+    writer.append(static_cast<ByteCodeField>(OpCode::GetOperand0 + index));
+  else
+    writer.append(OpCode::GetOperand, index);
+  writer.append(op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
+  // Small indices are folded into the opcode; larger ones are written inline.
+  uint32_t index = op.index();
+  if (index < 4)
+    writer.append(static_cast<ByteCodeField>(OpCode::GetResult0 + index));
+  else
+    writer.append(OpCode::GetResult, index);
+  writer.append(op.operation(), op.value());
+}
+void Generator::generate(pdl_interp::GetValueTypeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::GetValueType, op.result(), op.value());
+}
+void Generator::generate(pdl_interp::InferredTypeOp op,
+                         ByteCodeWriter &writer) {
+  // InferType maps to a null type as a marker for inferring a result type.
+  getMemIndex(op.type()) = getMemIndex(Type());
+}
+void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
+}
+void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
+  assert(rewriterToAddr.count(op.rewriter().getLeafReference()) &&
+         "expected bytecode address for rewriter function");
+  ByteCodeField patternIndex = patterns.size();
+  patterns.emplace_back(PDLByteCodePattern::create(
+      op, rewriterToAddr[op.rewriter().getLeafReference()]));
+  writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op),
+                op.matchedOps(), op.inputs());
+}
+void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
+}
+void Generator::generate(pdl_interp::SwitchAttributeOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchOperandCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchOperationNameOp op,
+                         ByteCodeWriter &writer) {
+  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
+    return OperationName(attr.cast<StringAttr>().getValue(), ctx);
+  });
+  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchResultCountOp op,
+                         ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
+  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
+                op.getSuccessors());
+}
+
+//===----------------------------------------------------------------------===//
+// PDLByteCode
+//===----------------------------------------------------------------------===//
+
+/// Generate the matcher and rewriter bytecode for `module`, then take
+/// ownership of the externally registered functions it may invoke.
+PDLByteCode::PDLByteCode(ModuleOp module,
+                         llvm::StringMap<PDLConstraintFunction> constraintFns,
+                         llvm::StringMap<PDLCreateFunction> createFns,
+                         llvm::StringMap<PDLRewriteFunction> rewriteFns) {
+  Generator generator(module.getContext(), uniquedData, matcherByteCode,
+                      rewriterByteCode, patterns, maxValueMemoryIndex,
+                      constraintFns, createFns, rewriteFns);
+  generator.generate(module);
+
+  // Move each registered function into its registry. The maps are iterated in
+  // the same order the generator used when numbering them, so the bytecode
+  // indices line up with the vector positions.
+  auto takeFunctions = [](auto &fnMap, auto &registry) {
+    for (auto &entry : fnMap)
+      registry.push_back(std::move(entry.second));
+  };
+  takeFunctions(constraintFns, constraintFunctions);
+  takeFunctions(createFns, createFunctions);
+  takeFunctions(rewriteFns, rewriteFunctions);
+}
+
+/// Prepare `state` for executing this bytecode: size the value memory and
+/// seed each pattern's current benefit with its default benefit.
+void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
+  state.memory.resize(maxValueMemoryIndex, nullptr);
+
+  auto &benefits = state.currentPatternBenefits;
+  benefits.reserve(patterns.size());
+  for (const PDLByteCodePattern &pat : patterns)
+    benefits.push_back(pat.getBenefit());
+}
+
+//===----------------------------------------------------------------------===//
+// ByteCode Execution
+
+namespace {
+/// This class provides support for executing a bytecode stream.
+class ByteCodeExecutor { +public: + ByteCodeExecutor(const ByteCodeField *curCodeIt, + MutableArrayRef memory, + ArrayRef uniquedMemory, + ArrayRef code, + ArrayRef currentPatternBenefits, + ArrayRef patterns, + ArrayRef constraintFunctions, + ArrayRef createFunctions, + ArrayRef rewriteFunctions) + : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), + code(code), currentPatternBenefits(currentPatternBenefits), + patterns(patterns), constraintFunctions(constraintFunctions), + createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {} + + /// Start executing the code at the current bytecode index. `matches` is an + /// optional field provided when this function is executed in a matching + /// context. + void execute(PatternRewriter &rewriter, + SmallVectorImpl *matches = nullptr, + Optional mainRewriteLoc = {}); + +private: + /// Read a value from the bytecode buffer, optionally skipping a certain + /// number of prefix values. These methods always update the buffer to point + /// to the next field after the read data. + template T read(size_t skipN = 0) { + curCodeIt += skipN; + return readImpl(); + } + ByteCodeField read(size_t skipN = 0) { return read(skipN); } + + /// Read a list of values from the bytecode buffer. + template + void readList(SmallVectorImpl &list) { + list.clear(); + for (unsigned i = 0, e = read(); i != e; ++i) + list.push_back(read()); + } + + /// Jump to a specific successor based on a predicate value. + void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } + /// Jump to a specific successor based on a destination index. + void selectJump(size_t destIndex) { + curCodeIt = &code[read(destIndex * 2)]; + } + + /// Handle a switch operation with the provided value and cases. 
+ template + void handleSwitch(const T &value, RangeT &&cases) { + LLVM_DEBUG({ + llvm::dbgs() << " * Value: " << value << "\n" + << " * Cases: "; + llvm::interleaveComma(cases, llvm::dbgs()); + llvm::dbgs() << "\n\n"; + }); + + // Check to see if the attribute value is within the case list. Jump to + // the correct successor index based on the result. + auto it = llvm::find(cases, value); + selectJump(it == cases.end() ? size_t(0) : ((it - cases.begin()) + 1)); + } + + /// Internal implementation of reading various data types from the bytecode + /// stream. + template const void *readFromMemory() { + size_t index = *curCodeIt++; + + // If this type is an SSA value, it can only be stored in non-const memory. + if (llvm::is_one_of::value || index < memory.size()) + return memory[index]; + + // Otherwise, if this index is not inbounds it is uniqued. + return uniquedMemory[index - memory.size()]; + } + template + std::enable_if_t::value, T> readImpl() { + return reinterpret_cast(const_cast(readFromMemory())); + } + template + std::enable_if_t::value && !std::is_same::value, + T> + readImpl() { + return T(T::getFromOpaquePointer(readFromMemory())); + } + template + std::enable_if_t::value, T> readImpl() { + switch (static_cast(read())) { + case PDLValueKind::Attribute: + return read(); + case PDLValueKind::Operation: + return read(); + case PDLValueKind::Type: + return read(); + case PDLValueKind::Value: + return read(); + } + } + template + std::enable_if_t::value, T> readImpl() { + static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2, + "unexpected ByteCode address size"); + ByteCodeAddr result; + std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr)); + curCodeIt += 2; + return result; + } + template + std::enable_if_t::value, T> readImpl() { + return *curCodeIt++; + } + + /// The underlying bytecode buffer. + const ByteCodeField *curCodeIt; + + /// The current execution memory. 
+ MutableArrayRef memory; + + /// References to ByteCode data necessary for execution. + ArrayRef uniquedMemory; + ArrayRef code; + ArrayRef currentPatternBenefits; + ArrayRef patterns; + ArrayRef constraintFunctions; + ArrayRef createFunctions; + ArrayRef rewriteFunctions; +}; +} // end anonymous namespace + +void ByteCodeExecutor::execute( + PatternRewriter &rewriter, + SmallVectorImpl *matches, + Optional mainRewriteLoc) { + while (true) { + OpCode opCode = static_cast(read()); + switch (opCode) { + case ApplyConstraint: { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); + const PDLConstraintFunction &constraintFn = constraintFunctions[read()]; + ArrayAttr constParams = read(); + SmallVector args; + readList(args); + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; + }); + + // Invoke the constraint and jump to the proper destination. + selectJump(succeeded(constraintFn(args, constParams, rewriter))); + break; + } + case ApplyRewrite: { + LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); + const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; + ArrayAttr constParams = read(); + Operation *root = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Root: " << *root << "\n" + << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n"; + }); + rewriteFn(root, args, constParams, rewriter); + break; + } + case AreEqual: { + LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + const void *lhs = read(); + const void *rhs = read(); + + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(lhs == rhs); + break; + } + case Branch: { + LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n"); + curCodeIt = &code[read()]; + break; + } + case CheckOperandCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing 
CheckOperandCount:\n"); + Operation *op = read(); + uint32_t expectedCount = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" + << " * Expected: " << expectedCount << "\n\n"); + selectJump(op->getNumOperands() == expectedCount); + break; + } + case CheckOperationName: { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); + Operation *op = read(); + OperationName expectedName = read(); + + LLVM_DEBUG(llvm::dbgs() + << " * Found: \"" << op->getName() << "\"\n" + << " * Expected: \"" << expectedName << "\"\n\n"); + selectJump(op->getName() == expectedName); + break; + } + case CheckResultCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); + Operation *op = read(); + uint32_t expectedCount = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" + << " * Expected: " << expectedCount << "\n\n"); + selectJump(op->getNumResults() == expectedCount); + break; + } + case CreateNative: { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n"); + const PDLCreateFunction &createFn = createFunctions[read()]; + ByteCodeField resultIndex = read(); + ArrayAttr constParams = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Arguments: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; + }); + + PDLValue result = createFn(args, constParams, rewriter); + memory[resultIndex] = result.getAsOpaquePointer(); + + LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n"); + break; + } + case CreateOperation: { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); + assert(mainRewriteLoc && "expected rewrite loc to be provided when " + "executing the rewriter bytecode"); + + unsigned memIndex = read(); + OperationState state(*mainRewriteLoc, read()); + readList(state.operands); + for (unsigned i = 0, e = read(); i != e; ++i) { + Identifier name = read(); + if (Attribute attr = read()) + 
state.addAttribute(name, attr); + } + + bool hasInferredTypes = false; + for (unsigned i = 0, e = read(); i != e; ++i) { + Type resultType = read(); + hasInferredTypes |= !resultType; + state.types.push_back(resultType); + } + + // Handle the case where the operation has inferred types. + if (hasInferredTypes) { + InferTypeOpInterface::Concept *concept = + state.name.getAbstractOperation() + ->getInterface(); + + // TODO: Handle failure. + SmallVector inferredTypes; + if (failed(concept->inferReturnTypes( + state.getContext(), state.location, state.operands, + state.attributes.getDictionary(state.getContext()), + state.regions, inferredTypes))) + return; + + for (unsigned i = 0, e = state.types.size(); i != e; ++i) + if (!state.types[i]) + state.types[i] = inferredTypes[i]; + } + Operation *resultOp = rewriter.createOperation(state); + memory[memIndex] = resultOp; + + LLVM_DEBUG({ + llvm::dbgs() << " * Attributes: " + << state.attributes.getDictionary(state.getContext()) + << "\n * Operands: "; + llvm::interleaveComma(state.operands, llvm::dbgs()); + llvm::dbgs() << "\n * Result Types: "; + llvm::interleaveComma(state.types, llvm::dbgs()); + llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n"; + }); + break; + } + case EraseOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); + Operation *op = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n"); + rewriter.eraseOp(op); + break; + } + case Finalize: { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + return; + } + case GetAttribute: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); + unsigned memIndex = read(); + Operation *op = read(); + Identifier attrName = read(); + Attribute attr = op->getAttr(attrName); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Attribute: " << attrName << "\n" + << " * Result: " << attr << "\n\n"); + memory[memIndex] = attr.getAsOpaquePointer(); + break; + } + case GetAttributeType: { + LLVM_DEBUG(llvm::dbgs() << 
"Executing GetAttributeType:\n"); + unsigned memIndex = read(); + Attribute attr = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" + << " * Result: " << attr.getType() << "\n\n"); + memory[memIndex] = attr.getType().getAsOpaquePointer(); + break; + } + case GetDefiningOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); + unsigned memIndex = read(); + Value value = read(); + Operation *op = value ? value.getDefiningOp() : nullptr; + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << *op << "\n\n"); + memory[memIndex] = op; + break; + } + case GetOperand0: + case GetOperand1: + case GetOperand2: + case GetOperand3: + case GetOperand: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand:\n"); + unsigned index = + opCode == GetOperand ? read() : (opCode - GetOperand0); + Operation *op = read(); + unsigned memIndex = read(); + Value operand = + index < op->getNumOperands() ? op->getOperand(index) : Value(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << operand << "\n\n"); + memory[memIndex] = operand.getAsOpaquePointer(); + break; + } + case GetResult0: + case GetResult1: + case GetResult2: + case GetResult3: + case GetResult: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetResult:\n"); + unsigned index = + opCode == GetResult ? read() : (opCode - GetResult0); + Operation *op = read(); + unsigned memIndex = read(); + OpResult result = + index < op->getNumResults() ? 
op->getResult(index) : OpResult(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Index: " << index << "\n" + << " * Result: " << result << "\n\n"); + memory[memIndex] = result.getAsOpaquePointer(); + break; + } + case GetValueType: { + LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); + unsigned memIndex = read(); + Value value = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" + << " * Result: " << value.getType() << "\n\n"); + memory[memIndex] = value.getType().getAsOpaquePointer(); + break; + } + case IsNotNull: { + LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); + const void *value = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n"); + selectJump(value != nullptr); + break; + } + case RecordMatch: { + LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); + assert(matches && + "expected matches to be provided when executing the matcher"); + unsigned patternIndex = read(); + PatternBenefit benefit = currentPatternBenefits[patternIndex]; + const ByteCodeField *dest = &code[read()]; + + // If the benefit of the pattern is impossible, skip the processing of the + // rest of the pattern. 
+ if (benefit.isImpossibleToMatch()) { + LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n"); + curCodeIt = dest; + break; + } + + Location matchLoc = rewriter.getUnknownLoc(); + for (unsigned i = 0, e = read(); i != e; ++i) { + matchLoc = FusedLoc::get({matchLoc, read()->getLoc()}, + matchLoc.getContext()); + } + + LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" + << " * Location: " << matchLoc << "\n\n"); + matches->emplace_back(matchLoc, patterns[patternIndex], benefit); + readList(matches->back().values); + curCodeIt = dest; + break; + } + case ReplaceOp: { + LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); + Operation *op = read(); + SmallVector args; + readList(args); + + LLVM_DEBUG({ + llvm::dbgs() << " * Operation: " << *op << "\n" + << " * Values: "; + llvm::interleaveComma(args, llvm::dbgs()); + llvm::dbgs() << "\n\n"; + }); + rewriter.replaceOp(op, args); + break; + } + case SwitchAttribute: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); + Attribute value = read(); + ArrayAttr cases = read(); + handleSwitch(value, cases); + break; + } + case SwitchOperandCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); + Operation *op = read(); + auto cases = read().getValues(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumOperands(), cases); + break; + } + case SwitchOperationName: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); + OperationName value = read()->getName(); + size_t caseCount = read(); + + // The operation names are stored in-line, so to print them out for + // debugging purposes we need to read the array before executing the + // switch so that we can display all of the possible values. 
+ LLVM_DEBUG({ + const ByteCodeField *prevCodeIt = curCodeIt; + llvm::dbgs() << " * Value: " << value << "\n" + << " * Cases: "; + llvm::interleaveComma( + llvm::map_range(llvm::seq(0, caseCount), + [&](size_t i) { return read(); }), + llvm::dbgs()); + llvm::dbgs() << "\n\n"; + curCodeIt = prevCodeIt; + }); + + // Try to find the switch value within any of the cases. + size_t jumpDest = 0; + for (size_t i = 0; i != caseCount; ++i) { + if (read() == value) { + curCodeIt += (caseCount - i - 1); + jumpDest = i + 1; + break; + } + } + selectJump(jumpDest); + break; + } + case SwitchResultCount: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); + Operation *op = read(); + auto cases = read().getValues(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + handleSwitch(op->getNumResults(), cases); + break; + } + case SwitchType: { + LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); + Type value = read(); + auto cases = read().getAsValueRange(); + handleSwitch(value, cases); + break; + } + } + } +} + +/// Run the pattern matcher on the given root operation, collecting the matched +/// patterns in `matches`. +void PDLByteCode::match(Operation *op, PatternRewriter &rewriter, + SmallVectorImpl &matches, + PDLByteCodeMutableState &state) const { + // The first memory slot is always the root operation. + state.memory[0] = op; + + // The matcher function always starts at code address 0. + ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData, + matcherByteCode, state.currentPatternBenefits, + patterns, constraintFunctions, createFunctions, + rewriteFunctions); + executor.execute(rewriter, &matches); + + // Order the found matches by benefit. + std::stable_sort(matches.begin(), matches.end(), + [](const MatchResult &lhs, const MatchResult &rhs) { + return lhs.benefit > rhs.benefit; + }); +} + +/// Run the rewriter of the given pattern on the root operation `op`. 
+void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match, + PDLByteCodeMutableState &state) const { + // The arguments of the rewrite function are stored at the start of the + // memory buffer. + llvm::copy(match.values, state.memory.begin()); + + ByteCodeExecutor executor( + &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, + uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, + constraintFunctions, createFunctions, rewriteFunctions); + executor.execute(rewriter, /*matches=*/nullptr, match.location); +} diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -7,13 +7,66 @@ //===----------------------------------------------------------------------===// #include "mlir/Rewrite/FrozenRewritePatternList.h" +#include "ByteCode.h" +#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" using namespace mlir; +static LogicalResult convertPDLToPDLInterp(ModuleOp pdlModule) { + // Skip the conversion if the module doesn't contain pdl. + if (llvm::empty(pdlModule.getOps())) + return success(); + + // Simplify the provided PDL module. Note that we can't use the canonicalizer + // here because it would create a cyclic dependency. + auto simplifyFn = [](Operation *op) { + // TODO: Add folding here if ever necessary. + if (isOpTriviallyDead(op)) + op->erase(); + }; + pdlModule.getBody()->walk(simplifyFn); + + /// Lower the PDL pattern module to the interpreter dialect. + PassManager pdlPipeline(pdlModule.getContext(), /*verifyPasses=*/false); + pdlPipeline.addPass(createPDLToPDLInterpPass()); + if (failed(pdlPipeline.run(pdlModule))) + return failure(); + + // Simplify again after running the lowering pipeline.
+ pdlModule.getBody()->walk(simplifyFn); + return success(); +} + //===----------------------------------------------------------------------===// // FrozenRewritePatternList //===----------------------------------------------------------------------===// FrozenRewritePatternList::FrozenRewritePatternList( OwningRewritePatternList &&patterns) - : patterns(patterns.takePatterns()) {} + : nativePatterns(std::move(patterns.getNativePatterns())) { + PDLPatternModule &pdlPatterns = patterns.getPDLPatterns(); + + // Generate the bytecode for the PDL patterns if any were provided. + ModuleOp pdlModule = pdlPatterns.getModule(); + if (!pdlModule) + return; + if (failed(convertPDLToPDLInterp(pdlModule))) + llvm::report_fatal_error( + "failed to lower PDL pattern module to the PDL Interpreter"); + + // Generate the pdl bytecode. + pdlByteCode = std::make_unique( + pdlModule, pdlPatterns.takeConstraintFunctions(), + pdlPatterns.takeCreateFunctions(), pdlPatterns.takeRewriteFunctions()); +} + +FrozenRewritePatternList::FrozenRewritePatternList( + FrozenRewritePatternList &&patterns) + : nativePatterns(std::move(patterns.nativePatterns)), + pdlByteCode(std::move(patterns.pdlByteCode)) {} + +FrozenRewritePatternList::~FrozenRewritePatternList() {} diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -12,17 +12,36 @@ //===----------------------------------------------------------------------===// #include "mlir/Rewrite/PatternApplicator.h" +#include "ByteCode.h" #include "llvm/Support/Debug.h" using namespace mlir; +using namespace mlir::detail; + +PatternApplicator::PatternApplicator( + const FrozenRewritePatternList &frozenPatternList) + : frozenPatternList(frozenPatternList) { + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + mutableByteCodeState = std::make_unique(); +
bytecode->initializeMutableState(*mutableByteCodeState); + } +} +PatternApplicator::~PatternApplicator() {} #define DEBUG_TYPE "pattern-match" void PatternApplicator::applyCostModel(CostModel model) { + // Apply the cost model to the bytecode patterns first, and then the native + // patterns. + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + for (auto it : llvm::enumerate(bytecode->getPatterns())) + mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); + } + // Separate patterns by root kind to simplify lookup later on. patterns.clear(); anyOpPatterns.clear(); - for (const auto &pat : frozenPatternList.getPatterns()) { + for (const auto &pat : frozenPatternList.getNativePatterns()) { // If the pattern is always impossible to match, just ignore it. if (pat.getBenefit().isImpossibleToMatch()) { LLVM_DEBUG({ @@ -81,8 +100,12 @@ void PatternApplicator::walkAllPatterns( function_ref walk) { - for (auto &it : frozenPatternList.getPatterns()) + for (const Pattern &it : frozenPatternList.getNativePatterns()) walk(it); + if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { + for (const Pattern &it : bytecode->getPatterns()) + walk(it); + } } LogicalResult PatternApplicator::matchAndRewrite( @@ -90,6 +113,14 @@ function_ref canApply, function_ref onFailure, function_ref onSuccess) { + // Before checking native patterns, first match against the bytecode. This + // won't automatically perform any rewrites so there is no need to worry about + // conflicts. + SmallVector pdlMatches; + const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); + if (bytecode) + bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); + // Check to see if there are patterns matching this specific operation type.
MutableArrayRef opPatterns; auto patternIt = patterns.find(op->getName()); @@ -98,51 +129,62 @@ // Process the patterns for that match the specific operation type, and any // operation type in an interleaved fashion. - // FIXME: It'd be nice to just write an llvm::make_merge_range utility - // and pass in a comparison function. That would make this code trivial. auto opIt = opPatterns.begin(), opE = opPatterns.end(); auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); - while (opIt != opE && anyIt != anyE) { - // Try to match the pattern providing the most benefit. - const RewritePattern *pattern; - if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) - pattern = *(opIt++); - else - pattern = *(anyIt++); + auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end(); + while (true) { + // Find the next pattern with the highest benefit. Only the iterator of + // the selected pattern is advanced, so a candidate that loses the + // comparison is reconsidered on a later iteration instead of dropped. + const Pattern *bestPattern = nullptr; + const PDLByteCode::MatchResult *pdlMatch = nullptr; + decltype(opIt) *bestNativeIt = nullptr; + /// Operation specific patterns. + if (opIt != opE) { + bestPattern = *opIt; + bestNativeIt = &opIt; + } + /// Operation agnostic patterns. + if (anyIt != anyE && + (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit())) { + bestPattern = *anyIt; + bestNativeIt = &anyIt; + } + /// PDL patterns. + if (pdlIt != pdlE && + (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) { + pdlMatch = pdlIt; + bestPattern = pdlIt->pattern; + } + if (!bestPattern) + break; + // Consume the selected candidate so that it is never attempted twice. + if (pdlMatch) + ++pdlIt; + else + ++*bestNativeIt; - // Otherwise, try to match the generic pattern. - if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) - return success(); - } - // If we break from the loop, then only one of the ranges can still have - // elements. Loop over both without checking given that we don't need to - // interleave anymore. - for (const RewritePattern *pattern : llvm::concat( - llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { - if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, - onSuccess))) + // Check that the pattern can be applied.
+ if (canApply && !canApply(*bestPattern)) + continue; + + // Try to match and rewrite this pattern. The patterns are sorted by + // benefit, so if we match we can immediately rewrite. For PDL patterns, the + // match has already been performed, we just need to rewrite. + rewriter.setInsertionPoint(op); + LogicalResult result = success(); + if (pdlMatch) { + bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); + } else { + result = static_cast(bestPattern) + ->matchAndRewrite(op, rewriter); + } + if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern)))) return success(); - } - return failure(); -} -LogicalResult PatternApplicator::matchAndRewrite( - Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, - function_ref canApply, - function_ref onFailure, - function_ref onSuccess) { - // Check that the pattern can be applied. - if (canApply && !canApply(pattern)) - return failure(); - - // Try to match and rewrite this pattern. The patterns are sorted by - // benefit, so if we match we can immediately rewrite. - rewriter.setInsertionPoint(op); - if (succeeded(pattern.matchAndRewrite(op, rewriter))) - return success(!onSuccess || succeeded(onSuccess(pattern))); - - if (onFailure) - onFailure(pattern); + // Perform any necessary cleanups. + if (onFailure) + onFailure(*bestPattern); + } return failure(); } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -0,0 +1,785 @@ +// RUN: mlir-opt %s -test-pdl-bytecode-pass -split-input-file | FileCheck %s + +// Note: Tests here are written using the PDL Interpreter dialect to avoid +// unnecessarily exercising unrelated aspects of the pattern compilation +// pipeline. These tests are written such that we can focus solely on the +// lowering/execution of the bytecode itself.
+ +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.apply_constraint "multi_entity_constraint"(%root, %root : !pdl.operation, !pdl.operation) -> ^pat, ^end + + ^pat: + pdl_interp.apply_constraint "single_entity_constraint"(%root : !pdl.operation) -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_constraint_1 +// CHECK: "test.success" +module @ir attributes { test.apply_constraint_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::ApplyRewriteOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %operand = pdl_interp.get_operand 0 of %root + pdl_interp.apply_rewrite "rewriter"[42](%operand : !pdl.value) on %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.apply_rewrite_1 +// CHECK: %[[INPUT:.*]] = "test.op_input" +// CHECK-NOT: "test.op" +// CHECK: "test.success"(%[[INPUT]]) {constantParams = [42]} +module @ir attributes { test.apply_rewrite_1 } { + %input = "test.op_input"() : () -> i32 + "test.op"(%input) : (i32) -> () +} 
+// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::AreEqualOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %test_attr = pdl_interp.create_attribute unit + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.are_equal %test_attr, %attr : !pdl.attribute -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.are_equal_1 +// CHECK: "test.success" +module @ir attributes { test.are_equal_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::BranchOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end + + ^pat1: + pdl_interp.branch ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.branch_1 +// CHECK: "test.success" +module @ir attributes { test.branch_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckAttributeOp +//===----------------------------------------------------------------------===// + +module @patterns 
{ + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.check_attribute %attr is unit -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_attribute_1 +// CHECK: "test.success" +module @ir attributes { test.check_attribute_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperandCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operand_count of %root is 1 -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_operand_count_1 +// CHECK: "test.op"() : () -> i32 +// CHECK: "test.success" +module @ir attributes { test.check_operand_count_1 } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckOperationNameOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : 
benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_operation_name_1 +// CHECK: "test.success" +module @ir attributes { test.check_operation_name_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckResultCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is 1 -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_result_count_1 +// CHECK: "test.success"() : () -> () +module @ir attributes { test.check_result_count_1 } { + "test.op"() : () -> i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CheckTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.check_type %type is i32 -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = 
pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.check_type_1 +// CHECK: "test.success" +module @ir attributes { test.check_type_1 } { + "test.op"() { test_attr = 10 : i32 } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateAttributeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateNativeOp +//===----------------------------------------------------------------------===// + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_native "creator"(%root : !pdl.operation) : !pdl.operation + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_native_1 +// CHECK: "test.success" +module @ir attributes { test.create_native_1 } { + "test.op"() : () -> () +} + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateOperationOp +//===----------------------------------------------------------------------===// + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::CreateTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + 
%test_type = pdl_interp.create_type i32 + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.are_equal %type, %test_type : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.create_type_1 +// CHECK: "test.success" +module @ir attributes { test.create_type_1 } { + "test.op"() { test_attr = 0 : i32 } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::EraseOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::FinalizeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetAttributeTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. 
+ +//===----------------------------------------------------------------------===// +// pdl_interp::GetDefiningOpOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operand_count of %root is 5 -> ^pat1, ^end + + ^pat1: + %operand0 = pdl_interp.get_operand 0 of %root + %operand4 = pdl_interp.get_operand 4 of %root + %defOp0 = pdl_interp.get_defining_op of %operand0 + %defOp4 = pdl_interp.get_defining_op of %operand4 + pdl_interp.are_equal %defOp0, %defOp4 : !pdl.operation -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_defining_op_1 +// CHECK: %[[OPERAND0:.*]] = "test.op" +// CHECK: %[[OPERAND1:.*]] = "test.op" +// CHECK: "test.success" +// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]]) +module @ir attributes { test.get_defining_op_1 } { + %operand = "test.op"() : () -> i32 + %other_operand = "test.op"() : () -> i32 + "test.op"(%operand, %operand, %operand, %operand, %operand) : (i32, i32, i32, i32, i32) -> () + "test.op"(%operand, %operand, %operand, %operand, %other_operand) : (i32, i32, i32, i32, i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetOperandOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. 
+ +//===----------------------------------------------------------------------===// +// pdl_interp::GetResultOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is 5 -> ^pat1, ^end + + ^pat1: + %result0 = pdl_interp.get_result 0 of %root + %result4 = pdl_interp.get_result 4 of %root + %result0_type = pdl_interp.get_value_type of %result0 + %result4_type = pdl_interp.get_value_type of %result4 + pdl_interp.are_equal %result0_type, %result4_type : !pdl.type -> ^pat2, ^end + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_result_1 +// CHECK: "test.success" +// CHECK: "test.op"() : () -> (i32, i32, i32, i32, i64) +module @ir attributes { test.get_result_1 } { + %a:5 = "test.op"() : () -> (i32, i32, i32, i32, i32) + %b:5 = "test.op"() : () -> (i32, i32, i32, i32, i64) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetValueTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::InferredTypeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. 
+ +//===----------------------------------------------------------------------===// +// pdl_interp::IsNotNullOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::RecordMatchOp +//===----------------------------------------------------------------------===// + +// Check that the highest benefit pattern is selected. +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@failure(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(2), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @failure(%root : !pdl.operation) { + pdl_interp.erase %root + pdl_interp.finalize + } + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.record_match_1 +// CHECK: "test.success" +module @ir attributes { test.record_match_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::ReplaceOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.op" -> ^pat, ^end + + ^pat: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %operand = pdl_interp.get_operand 0 of %root + pdl_interp.replace %root with (%operand) + pdl_interp.finalize + }
+ } +} + +// CHECK-LABEL: test.replace_op_1 +// CHECK: %[[INPUT:.*]] = "test.op_input" +// CHECK-NOT: "test.op" +// CHECK: "test.op_consumer"(%[[INPUT]]) +module @ir attributes { test.replace_op_1 } { + %input = "test.op_input"() : () -> i32 + %result = "test.op"(%input) : (i32) -> i32 + "test.op_consumer"(%result) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchAttributeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.switch_attribute %attr to [0, unit](^end, ^pat) -> ^end + + ^pat: + %attr_2 = pdl_interp.get_attribute "test_attr_2" of %root + pdl_interp.switch_attribute %attr_2 to [0, unit](^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_attribute_1 +// CHECK: "test.success" +module @ir attributes { test.switch_attribute_1 } { + "test.op"() { test_attr } : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperandCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_operand_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end + + ^pat: + pdl_interp.switch_operand_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: +
    pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_operand_1 +// CHECK: "test.success" +module @ir attributes { test.switch_operand_1 } { + %input = "test.op_input"() : () -> i32 + "test.op"(%input) : (i32) -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchOperationNameOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_operation_name of %root to ["foo.op", "test.op"](^end, ^pat1) -> ^end + + ^pat1: + pdl_interp.switch_operation_name of %root to ["foo.op", "bar.op"](^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_operation_name_1 +// CHECK: "test.success" +module @ir attributes { test.switch_operation_name_1 } { + "test.op"() : () -> () +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchResultCountOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.switch_result_count of %root to dense<[0, 1]> : vector<2xi32>(^end, ^pat) -> ^end + + ^pat: + pdl_interp.switch_result_count of %root to dense<[0, 2]> : vector<2xi32>(^end, ^end) -> ^pat2 + + ^pat2: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize +
  } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_result_1 +// CHECK: "test.success" +module @ir attributes { test.switch_result_1 } { + "test.op"() : () -> i32 +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::SwitchTypeOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %attr = pdl_interp.get_attribute "test_attr" of %root + pdl_interp.is_not_null %attr : !pdl.attribute -> ^pat1, ^end + + ^pat1: + %type = pdl_interp.get_attribute_type of %attr + pdl_interp.switch_type %type to [i32, i64](^pat2, ^end) -> ^end + + ^pat2: + pdl_interp.switch_type %type to [i16, i64](^end, ^end) -> ^pat3 + + ^pat3: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success"() -> () + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.switch_type_1 +// CHECK: "test.success" +module @ir attributes { test.switch_type_1 } { + "test.op"() { test_attr = 10 : i32 } : () -> () +} diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -0,0 +1,84 @@ +//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +/// Custom constraint invoked from PDL; succeeds if the op is a "test.op". +static LogicalResult customSingleEntityConstraint(PDLValue value, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + Operation *rootOp = value.cast<Operation *>(); + return success(rootOp->getName().getStringRef() == "test.op"); +} +/// Same as customSingleEntityConstraint, applied to the second value. +static LogicalResult customMultiEntityConstraint(ArrayRef<PDLValue> values, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + return customSingleEntityConstraint(values[1], constantParams, rewriter); +} + +/// Custom creator invoked from PDL; creates a "test.success" operation. +static PDLValue customCreate(ArrayRef<PDLValue> args, ArrayAttr constantParams, + PatternRewriter &rewriter) { + return rewriter.createOperation( + OperationState(args[0].cast<Operation *>()->getLoc(), "test.success")); +} + +/// Custom rewriter invoked from PDL; creates "test.success", erases root. +static void customRewriter(Operation *root, ArrayRef<PDLValue> args, + ArrayAttr constantParams, + PatternRewriter &rewriter) { + OperationState successOpState(root->getLoc(), "test.success"); + successOpState.addOperands(args[0].cast<Value>()); + successOpState.addAttribute("constantParams", constantParams); + rewriter.createOperation(successOpState); + rewriter.eraseOp(root); +} + +namespace { +struct TestPDLByteCodePass + : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { + void runOnOperation() final { + ModuleOp module = getOperation(); + + // The test cases are encompassed via two modules, one containing the + // patterns and one containing the operations to rewrite. + ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns"); + ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir"); + if (!patternModule || !irModule) + return; + + // Detach the pattern module and wrap it in a PDLPatternModule.
+ patternModule.getOperation()->removeFromParent(); + PDLPatternModule pdlPattern(patternModule); + pdlPattern.registerConstraintFunction("multi_entity_constraint", + customMultiEntityConstraint); + pdlPattern.registerConstraintFunction("single_entity_constraint", + customSingleEntityConstraint); + pdlPattern.registerCreateFunction("creator", customCreate); + pdlPattern.registerRewriteFunction("rewriter", customRewriter); + + OwningRewritePatternList patternList(std::move(pdlPattern)); + + // Invoke the pattern driver with the provided patterns. + (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), + std::move(patternList)); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTestPDLByteCodePass() { + PassRegistration<TestPDLByteCodePass>("test-pdl-bytecode-pass", + "Test PDL ByteCode functionality"); +} +} // namespace mlir diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -220,18 +220,21 @@ FuncOp funcOp, StringRef startMarker, SmallVectorImpl<OwningRewritePatternList> &patternsVector) { MLIRContext *ctx = funcOp.getContext(); - patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>( + patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>( ctx, LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}), LinalgMarker(Identifier::get(startMarker, ctx), Identifier::get("L1", ctx)))); - patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>( - ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgMarker(Identifier::get("L1", ctx), Identifier::get("VEC", ctx)))); + patternsVector.emplace_back( + std::make_unique<LinalgPromotionPattern<MatmulOp>>( + ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), + LinalgMarker(Identifier::get("L1", ctx), + Identifier::get("VEC", ctx)))); - - patternsVector.emplace_back(LinalgVectorizationPattern<MatmulOp>( - ctx, LinalgMarker(Identifier::get("VEC", ctx)))); +
  patternsVector.emplace_back( + std::make_unique<LinalgVectorizationPattern<MatmulOp>>( + ctx, LinalgMarker(Identifier::get("VEC", ctx)))); patternsVector.back() .insert<LinalgVectorizationPattern<FillOp>, LinalgVectorizationPattern<CopyOp>>(ctx); @@ -421,7 +424,7 @@ fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx), stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { - stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>( + stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>( ctx, LinalgTilingOptions() .setTileSizes({768, 264, 768}) diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -70,6 +70,7 @@ void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); void registerTestOpaqueLoc(); +void registerTestPDLByteCodePass(); void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestPrintDefUsePass(); void registerTestPrintNestingPass(); @@ -126,6 +127,7 @@ registerTestMemRefDependenceCheck(); registerTestMemRefStrideCalculation(); registerTestOpaqueLoc(); + registerTestPDLByteCodePass(); registerTestPreparationPassWithAllowedMemrefResults(); registerTestPrintDefUsePass(); registerTestPrintNestingPass();