diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -963,7 +963,7 @@ let builders = [ OpBuilder<(ins "Value":$operation, "ArrayRef":$counts, - "Block *":$defaultDest, "BlockRange":$dests), [{ + "Block *":$defaultDest, "BlockRange":$dests), [{ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; @@ -1033,7 +1033,7 @@ let builders = [ OpBuilder<(ins "Value":$operation, "ArrayRef":$counts, - "Block *":$defaultDest, "BlockRange":$dests), [{ + "Block *":$defaultDest, "BlockRange":$dests), [{ build($_builder, $_state, operation, $_builder.getI32VectorAttr(counts), defaultDest, dests); }]>]; diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -238,63 +238,96 @@ /// Storage type of byte-code interpreter values. These are passed to constraint /// functions as arguments. class PDLValue { - /// The internal implementation type when the value is an Attribute, - /// Operation*, or Type. See `impl` below for more details. - using AttrOpTypeImplT = llvm::PointerUnion; - public: - PDLValue(const PDLValue &other) : impl(other.impl) {} - PDLValue(std::nullptr_t = nullptr) : impl() {} - PDLValue(Attribute value) : impl(value) {} - PDLValue(Operation *value) : impl(value) {} - PDLValue(Type value) : impl(value) {} - PDLValue(Value value) : impl(value) {} + /// The underlying kind of a PDL value. + enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; + + /// Construct a new PDL value. 
+ PDLValue(const PDLValue &other) = default; + PDLValue(std::nullptr_t = nullptr) : value(nullptr), kind(Kind::Attribute) {} + PDLValue(Attribute value) + : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} + PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} + PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} + PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} + PDLValue(Value value) + : value(value.getAsOpaquePointer()), kind(Kind::Value) {} + PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} /// Returns true if the type of the held value is `T`. template - std::enable_if_t::value, bool> isa() const { - return impl.is(); - } - template - std::enable_if_t::value, bool> isa() const { - auto attrOpTypeImpl = impl.dyn_cast(); - return attrOpTypeImpl && attrOpTypeImpl.is(); + bool isa() const { + assert(value && "isa<> used on a null value"); + return kind == getKindOf(); } /// Attempt to dynamically cast this value to type `T`, returns null if this /// value is not an instance of `T`. - template - std::enable_if_t::value, T> dyn_cast() const { - return impl.dyn_cast(); - } - template - std::enable_if_t::value, T> dyn_cast() const { - auto attrOpTypeImpl = impl.dyn_cast(); - return attrOpTypeImpl && attrOpTypeImpl.dyn_cast(); + template ::value, T, Optional>> + ResultT dyn_cast() const { + return isa() ? castImpl() : ResultT(); } /// Cast this value to type `T`, asserts if this value is not an instance of /// `T`. template - std::enable_if_t::value, T> cast() const { - return impl.get(); - } - template - std::enable_if_t::value, T> cast() const { - return impl.get().get(); + T cast() const { + assert(isa() && "expected value to be of type `T`"); + return castImpl(); } /// Get an opaque pointer to the value. - void *getAsOpaquePointer() { return impl.getOpaqueValue(); } + const void *getAsOpaquePointer() const { return value; } + + /// Return if this value is null or not. 
+ explicit operator bool() const { return value; } + + /// Return the kind of this value. + Kind getKind() const { return kind; } /// Print this value to the provided output stream. - void print(raw_ostream &os); + void print(raw_ostream &os) const; private: - /// The internal opaque representation of a PDLValue. We use a nested - /// PointerUnion structure here because `Value` only has 1 low bit - /// available, where as the remaining types all have 3. - llvm::PointerUnion impl; + /// Find the index of a given type in a range of other types. + template + struct index_of_t; + template + struct index_of_t : std::integral_constant {}; + template + struct index_of_t + : std::integral_constant::value> {}; + + /// Return the kind used for the given T. + template + static Kind getKindOf() { + return static_cast(index_of_t::value); + } + + /// The internal implementation of `cast`, that returns the underlying value + /// as the given type `T`. + template + std::enable_if_t::value, T> + castImpl() const { + return T::getFromOpaquePointer(value); + } + template + std::enable_if_t::value, T> + castImpl() const { + return *reinterpret_cast(const_cast(value)); + } + template + std::enable_if_t::value, T> castImpl() const { + return reinterpret_cast(const_cast(value)); + } + + /// The internal opaque representation of a PDLValue. + const void *value; + /// The kind of the opaque value. + Kind kind; }; inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { @@ -319,14 +352,66 @@ /// Push a new Type onto the result list. void push_back(Type value) { results.push_back(value); } + /// Push a new TypeRange onto the result list. + void push_back(TypeRange value) { + // The lifetime of a TypeRange can't be guaranteed, so we'll need to + // allocate storage for it.
+ llvm::OwningArrayRef storage(value.size()); + llvm::copy(value, storage.begin()); + allocatedTypeRanges.emplace_back(std::move(storage)); + typeRanges.push_back(allocatedTypeRanges.back()); + results.push_back(&typeRanges.back()); + } + void push_back(ValueTypeRange value) { + typeRanges.push_back(value); + results.push_back(&typeRanges.back()); + } + void push_back(ValueTypeRange value) { + typeRanges.push_back(value); + results.push_back(&typeRanges.back()); + } + /// Push a new Value onto the result list. void push_back(Value value) { results.push_back(value); } + /// Push a new ValueRange onto the result list. + void push_back(ValueRange value) { + // The lifetime of a ValueRange can't be guaranteed, so we'll need to + // allocate storage for it. + llvm::OwningArrayRef storage(value.size()); + llvm::copy(value, storage.begin()); + allocatedValueRanges.emplace_back(std::move(storage)); + valueRanges.push_back(allocatedValueRanges.back()); + results.push_back(&valueRanges.back()); + } + void push_back(OperandRange value) { + valueRanges.push_back(value); + results.push_back(&valueRanges.back()); + } + void push_back(ResultRange value) { + valueRanges.push_back(value); + results.push_back(&valueRanges.back()); + } + protected: - PDLResultList() = default; + /// Create a new result list with the expected number of results. + PDLResultList(unsigned maxNumResults) { + // For now just reserve enough space for all of the results. We could do + // separate counts per range type, but it isn't really worth it unless there + // are a "large" number of results. + typeRanges.reserve(maxNumResults); + valueRanges.reserve(maxNumResults); + } /// The PDL results held by this list. SmallVector results; + /// Memory used to store ranges held by the list. + SmallVector typeRanges; + SmallVector valueRanges; + /// Memory allocated to store ranges in the result list whose lifetime was + /// not guaranteed by the native function that produced them.
+ SmallVector> allocatedTypeRanges; + SmallVector> allocatedValueRanges; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -82,6 +82,12 @@ return ::llvm::hash_combine_range(arg.begin(), arg.end()); } +/// Emit a type range to the given output stream. +inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) { + llvm::interleaveComma(types, os); + return os; +} + //===----------------------------------------------------------------------===// // ValueTypeRange diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -73,22 +73,31 @@ // PDLValue //===----------------------------------------------------------------------===// -void PDLValue::print(raw_ostream &os) { - if (!impl) { - os << ""; +void PDLValue::print(raw_ostream &os) const { + if (!value) { + os << ""; return; } - if (Value val = impl.dyn_cast()) { - os << val; - return; + switch (kind) { + case Kind::Attribute: + os << cast(); + break; + case Kind::Operation: + os << *cast(); + break; + case Kind::Type: + os << cast(); + break; + case Kind::TypeRange: + llvm::interleaveComma(cast(), os); + break; + case Kind::Value: + os << cast(); + break; + case Kind::ValueRange: + llvm::interleaveComma(cast(), os); + break; } - AttrOpTypeImplT aotImpl = impl.get(); - if (Attribute attr = aotImpl.dyn_cast()) - os << attr; - else if (Operation *op = aotImpl.dyn_cast()) - os << *op; - else - os << aotImpl.get(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -25,8 +25,7 @@ class PDLByteCode; /// Use generic bytecode types. 
ByteCodeField refers to the actual bytecode -/// entries (set to uint8_t for "byte" bytecode). ByteCodeAddr refers to size of -/// indices into the bytecode. Correctness is checked with static asserts. +/// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; @@ -62,14 +61,16 @@ /// threads/drivers. class PDLByteCodeMutableState { public: - /// Initialize the state from a bytecode instance. - void initialize(PDLByteCode &bytecode); - /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds /// to the position of the pattern within the range returned by /// `PDLByteCode::getPatterns`. void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); + /// Clean up any allocated state after a match/rewrite has been completed. + /// This method should be called regardless of whether the match+rewrite was + /// a success or not. + void cleanupAfterMatchAndRewrite(); + private: /// Allow access to data fields. friend class PDLByteCode; @@ -78,6 +79,20 @@ /// of the bytecode. std::vector memory; + /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of types. + std::vector typeRangeMemory; + /// A set of type ranges that have been allocated by the byte code interpreter + /// to provide a guaranteed lifetime. + std::vector> allocatedTypeRangeMemory; + + /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of values. + std::vector valueRangeMemory; + /// A set of value ranges that have been allocated by the byte code + /// interpreter to provide a guaranteed lifetime. + std::vector> allocatedValueRangeMemory; + /// The up-to-date benefits of the patterns held by the bytecode. The order /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector currentPatternBenefits; @@ -98,11 +113,19 @@ MatchResult(Location loc, const PDLByteCodePattern &pattern, PatternBenefit benefit) : location(loc), pattern(&pattern), benefit(benefit) {} + MatchResult(const MatchResult &) = delete; + MatchResult &operator=(const MatchResult &) = delete; + MatchResult(MatchResult &&other) = default; + MatchResult &operator=(MatchResult &&) = default; /// The location of operations to be replaced. Location location; /// Memory values defined in the matcher that are passed to the rewriter. - SmallVector values; + SmallVector values; + /// Memory used for the range input values. + SmallVector typeRangeValues; + SmallVector valueRangeValues; + /// The originating pattern that was matched. This is always non-null, but /// represented with a pointer to allow for assignment. const PDLByteCodePattern *pattern; @@ -163,6 +186,10 @@ /// The maximum memory index used by a value. ByteCodeField maxValueMemoryIndex = 0; + + /// The maximum number of type ranges and value ranges, respectively. + ByteCodeField maxTypeRangeCount = 0; + ByteCodeField maxValueRangeCount = 0; }; } // end namespace detail diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -20,6 +20,9 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/FormatVariadic.h" +#include #define DEBUG_TYPE "pdl-bytecode" @@ -60,6 +63,14 @@ currentPatternBenefits[patternIndex] = benefit; } +/// Clean up any allocated state after a full match/rewrite has been completed. +/// This method should be called regardless of whether the match+rewrite was a +/// success or not.
+void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { + allocatedTypeRangeMemory.clear(); + allocatedValueRangeMemory.clear(); +} + //===----------------------------------------------------------------------===// // Bytecode OpCodes //===----------------------------------------------------------------------===// @@ -72,6 +83,8 @@ ApplyRewrite, /// Check if two generic values are equal. AreEqual, + /// Check if two ranges are equal. + AreRangesEqual, /// Unconditional branch. Branch, /// Compare the operand count of an operation with a constant. @@ -80,8 +93,12 @@ CheckOperationName, /// Compare the result count of an operation with a constant. CheckResultCount, + /// Compare a range of types to a constant range of types. + CheckTypes, /// Create an operation. CreateOperation, + /// Create a range of types. + CreateTypes, /// Erase an operation. EraseOp, /// Terminate a matcher or rewrite sequence. @@ -98,14 +115,20 @@ GetOperand2, GetOperand3, GetOperandN, + /// Get a specific operand group of an operation. + GetOperands, /// Get a specific result of an operation. GetResult0, GetResult1, GetResult2, GetResult3, GetResultN, + /// Get a specific result group of an operation. + GetResults, /// Get the type of a value. GetValueType, + /// Get the types of a value range. + GetValueRangeTypes, /// Check if a generic value is not null. IsNotNull, /// Record a successful pattern match. @@ -122,9 +145,9 @@ SwitchResultCount, /// Compare a type with a set of constants. SwitchType, + /// Compare a range of types with a set of constants. 
+ SwitchTypes, }; - -enum class PDLValueKind { Attribute, Operation, Type, Value }; } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -145,11 +168,15 @@ SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxTypeRangeMemoryIndex, + ByteCodeField &maxValueRangeMemoryIndex, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), - maxValueMemoryIndex(maxValueMemoryIndex) { + maxValueMemoryIndex(maxValueMemoryIndex), + maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), + maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) @@ -166,6 +193,13 @@ return valueToMemIndex[value]; } + /// Return the range memory index used to store the given range value. + ByteCodeField &getRangeStorageIndex(Value value) { + assert(valueToRangeIndex.count(value) && + "expected range index to be assigned"); + return valueToRangeIndex[value]; + } + /// Return an index to use when referring to the given data that is uniqued in /// the MLIR context. 
template @@ -197,16 +231,20 @@ void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); @@ -214,6 +252,7 @@ void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer); @@ -221,6 +260,9 @@ /// Mapping from value 
to its corresponding memory index. DenseMap valueToMemIndex; + /// Mapping from a range value to its corresponding range storage index. + DenseMap valueToRangeIndex; + /// Mapping from the name of an externally registered rewrite to its index in /// the bytecode registry. llvm::StringMap externalRewriterToMemIndex; @@ -246,6 +288,8 @@ SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxTypeRangeMemoryIndex; + ByteCodeField &maxValueRangeMemoryIndex; }; /// This class provides utilities for writing a bytecode stream. @@ -281,19 +325,33 @@ /// Append a range of values that will be read as generic PDLValues. void appendPDLValueList(OperandRange values) { bytecode.push_back(values.size()); - for (Value value : values) { - // Append the type of the value in addition to the value itself. - PDLValueKind kind = - TypeSwitch(value.getType()) - .Case( - [](Type) { return PDLValueKind::Attribute; }) - .Case( - [](Type) { return PDLValueKind::Operation; }) - .Case([](Type) { return PDLValueKind::Type; }) - .Case([](Type) { return PDLValueKind::Value; }); - bytecode.push_back(static_cast(kind)); - append(value); - } + for (Value value : values) + appendPDLValue(value); + } + + /// Append a value as a PDLValue. + void appendPDLValue(Value value) { + appendPDLValueKind(value); + append(value); + } + + /// Append the PDLValue::Kind of the given value. + void appendPDLValueKind(Value value) { + // Append the type of the value in addition to the value itself. 
+ PDLValue::Kind kind = + TypeSwitch(value.getType()) + .Case( + [](Type) { return PDLValue::Kind::Attribute; }) + .Case( + [](Type) { return PDLValue::Kind::Operation; }) + .Case([](pdl::RangeType rangeTy) { + if (rangeTy.getElementType().isa()) + return PDLValue::Kind::TypeRange; + return PDLValue::Kind::ValueRange; + }) + .Case([](Type) { return PDLValue::Kind::Type; }) + .Case([](Type) { return PDLValue::Kind::Value; }); + bytecode.push_back(static_cast(kind)); } /// Check if the given class `T` has an iterator type. @@ -334,6 +392,36 @@ /// The main generator producing PDL. Generator &generator; }; + +/// This class represents a live range of PDL Interpreter values, containing +/// information about when values are live within a match/rewrite. +struct ByteCodeLiveRange { + using Set = llvm::IntervalMap; + using Allocator = Set::Allocator; + + ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + + /// Union this live range with the one provided. + void unionWith(const ByteCodeLiveRange &rhs) { + for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) + liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + } + + /// Returns true if this range overlaps with the one provided. + bool overlaps(const ByteCodeLiveRange &rhs) const { + return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + } + + /// A map representing the ranges of the match/rewrite that a value is live in + /// the interpreter. + llvm::IntervalMap liveness; + + /// The type range storage index for this range. + Optional typeRangeIndex; + + /// The value range storage index for this range. + Optional valueRangeIndex; +}; } // end anonymous namespace void Generator::generate(ModuleOp module) { @@ -381,15 +469,30 @@ // Rewriters use simplistic allocation scheme that simply assigns an index to // each result. 
for (FuncOp rewriterFunc : rewriterModule.getOps()) { - ByteCodeField index = 0; + ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0; + auto processRewriterValue = [&](Value val) { + valueToMemIndex.try_emplace(val, index++); + if (pdl::RangeType rangeType = val.getType().dyn_cast()) { + Type elementTy = rangeType.getElementType(); + if (elementTy.isa()) + valueToRangeIndex.try_emplace(val, typeRangeIndex++); + else if (elementTy.isa()) + valueToRangeIndex.try_emplace(val, valueRangeIndex++); + } + }; + for (BlockArgument arg : rewriterFunc.getArguments()) - valueToMemIndex.try_emplace(arg, index++); + processRewriterValue(arg); rewriterFunc.getBody().walk([&](Operation *op) { for (Value result : op->getResults()) - valueToMemIndex.try_emplace(result, index++); + processRewriterValue(result); }); if (index > maxValueMemoryIndex) maxValueMemoryIndex = index; + if (typeRangeIndex > maxTypeRangeMemoryIndex) + maxTypeRangeMemoryIndex = typeRangeIndex; + if (valueRangeIndex > maxValueRangeMemoryIndex) + maxValueRangeMemoryIndex = valueRangeIndex; } // The matcher function uses a more sophisticated numbering that tries to @@ -404,9 +507,8 @@ }); // Liveness info for each of the defs within the matcher. - using LivenessSet = llvm::IntervalMap; - LivenessSet::Allocator allocator; - DenseMap valueDefRanges; + ByteCodeLiveRange::Allocator allocator; + DenseMap valueDefRanges; // Assign the root operation being matched to slot 0. BlockArgument rootOpArg = matcherFunc.getArgument(0); @@ -425,10 +527,19 @@ // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.insert( + defRangeIt->second.liveness.insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); + + // Check to see if this value is a range type. 
+ if (auto rangeTy = value.getType().dyn_cast()) { + Type eleType = rangeTy.getElementType(); + if (eleType.isa()) + defRangeIt->second.typeRangeIndex = 0; + else if (eleType.isa()) + defRangeIt->second.valueRangeIndex = 0; + } }; // Process the live-ins of this block. @@ -442,37 +553,59 @@ } // Greedily allocate memory slots using the computed def live ranges. - std::vector allocatedIndices; + std::vector allocatedIndices; + ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; - LivenessSet &defSet = defIt.second; + ByteCodeLiveRange &defRange = defIt.second; // Try to allocate to an existing index. for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) { - LivenessSet &existingIndex = existingIndexIt.value(); - llvm::IntervalMapOverlaps overlaps( - defIt.second, existingIndex); - if (overlaps.valid()) - continue; - // Union the range of the def within the existing index. - for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) - existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0); - memIndex = existingIndexIt.index() + 1; + ByteCodeLiveRange &existingRange = existingIndexIt.value(); + if (!defRange.overlaps(existingRange)) { + existingRange.unionWith(defRange); + memIndex = existingIndexIt.index() + 1; + + if (defRange.typeRangeIndex) { + if (!existingRange.typeRangeIndex) + existingRange.typeRangeIndex = numTypeRanges++; + valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; + } else if (defRange.valueRangeIndex) { + if (!existingRange.valueRangeIndex) + existingRange.valueRangeIndex = numValueRanges++; + valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex; + } + break; + } } // If no existing index could be used, add a new one. 
if (memIndex == 0) { allocatedIndices.emplace_back(allocator); - for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it) - allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0); + ByteCodeLiveRange &newRange = allocatedIndices.back(); + newRange.unionWith(defRange); + + // Allocate an index for type/value ranges. + if (defRange.typeRangeIndex) { + newRange.typeRangeIndex = numTypeRanges; + valueToRangeIndex[defIt.first] = numTypeRanges++; + } else if (defRange.valueRangeIndex) { + newRange.valueRangeIndex = numValueRanges; + valueToRangeIndex[defIt.first] = numValueRanges++; + } + memIndex = allocatedIndices.size(); + ++numIndices; } } // Update the max number of indices. - ByteCodeField numMatcherIndices = allocatedIndices.size() + 1; - if (numMatcherIndices > maxValueMemoryIndex) - maxValueMemoryIndex = numMatcherIndices; + if (numIndices > maxValueMemoryIndex) + maxValueMemoryIndex = numIndices; + if (numTypeRanges > maxTypeRangeMemoryIndex) + maxTypeRangeMemoryIndex = numTypeRanges; + if (numValueRanges > maxValueRangeMemoryIndex) + maxValueRangeMemoryIndex = numValueRanges; } void Generator::generate(Operation *op, ByteCodeWriter &writer) { @@ -481,17 +614,19 @@ pdl_interp::AreEqualOp, pdl_interp::BranchOp, pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, - pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, + pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, + pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, pdl_interp::EraseOp, pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp, - pdl_interp::GetResultOp, pdl_interp::GetValueTypeOp, + pdl_interp::GetOperandsOp, pdl_interp::GetResultOp, + pdl_interp::GetResultsOp, 
pdl_interp::GetValueTypeOp, pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, - pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, - pdl_interp::SwitchResultCountOp>( + pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, + pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -515,16 +650,31 @@ op.constParamsAttr()); writer.appendPDLValueList(op.args()); + ResultRange results = op.results(); + writer.append(ByteCodeField(results.size())); + for (Value result : results) { + // In debug mode we also record the expected kind of the result, so that we + // can provide extra verification of the native rewrite function. #ifndef NDEBUG - // In debug mode we also append the number of results so that we can assert - // that the native creation function gave us the correct number of results. - writer.append(ByteCodeField(op.results().size())); + writer.appendPDLValueKind(result); #endif - for (Value result : op.results()) + + // Range results also need to append the range storage index. 
+ if (result.getType().isa()) + writer.append(getRangeStorageIndex(result)); writer.append(result); + } } void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) { - writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors()); + Value lhs = op.lhs(); + if (lhs.getType().isa()) { + writer.append(OpCode::AreRangesEqual); + writer.appendPDLValueKind(lhs); + writer.append(op.lhs(), op.rhs(), op.getSuccessors()); + return; + } + + writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors()); } void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) { writer.append(OpCode::Branch, SuccessorRange(op.getOperation())); @@ -537,6 +687,7 @@ void Generator::generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckOperandCount, op.operation(), op.count(), + static_cast(op.compareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckOperationNameOp op, @@ -547,11 +698,15 @@ void Generator::generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckResultCount, op.operation(), op.count(), + static_cast(op.compareAtLeast()), op.getSuccessors()); } void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) { writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors()); } +void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); +} void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. @@ -560,7 +715,8 @@ void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateOperation, op.operation(), - OperationName(op.name(), ctx), op.operands()); + OperationName(op.name(), ctx)); + writer.appendPDLValueList(op.operands()); // Add the attributes. 
OperandRange attributes = op.attributes(); @@ -570,12 +726,16 @@ Identifier::get(std::get<0>(it).cast().getValue(), ctx), std::get<1>(it)); } - writer.append(op.types()); + writer.appendPDLValueList(op.types()); } void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.result()) = getMemIndex(op.value()); } +void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::CreateTypes, op.result(), + getRangeStorageIndex(op.result()), op.value()); +} void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } @@ -593,7 +753,8 @@ } void Generator::generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer) { - writer.append(OpCode::GetDefiningOp, op.operation(), op.value()); + writer.append(OpCode::GetDefiningOp, op.operation()); + writer.appendPDLValue(op.value()); } void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); @@ -603,6 +764,18 @@ writer.append(OpCode::GetOperandN, index); writer.append(op.operation(), op.value()); } +void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) { + Value result = op.value(); + Optional index = op.index(); + writer.append(OpCode::GetOperands, + index.getValueOr(std::numeric_limits::max()), + op.operation()); + if (result.getType().isa()) + writer.append(getRangeStorageIndex(result)); + else + writer.append(std::numeric_limits::max()); + writer.append(result); +} void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) @@ -611,10 +784,29 @@ writer.append(OpCode::GetResultN, index); writer.append(op.operation(), op.value()); } +void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) { + Value result = op.value(); + Optional index = op.index(); + 
writer.append(OpCode::GetResults, + index.getValueOr(std::numeric_limits::max()), + op.operation()); + if (result.getType().isa()) + writer.append(getRangeStorageIndex(result)); + else + writer.append(std::numeric_limits::max()); + writer.append(result); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { - writer.append(OpCode::GetValueType, op.result(), op.value()); + if (op.getType().isa()) { + Value result = op.result(); + writer.append(OpCode::GetValueRangeTypes, result, + getRangeStorageIndex(result), op.value()); + } else { + writer.append(OpCode::GetValueType, op.result(), op.value()); + } } + void Generator::generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer) { // InferType maps to a null type as a marker for inferring result types. @@ -628,11 +820,12 @@ patterns.emplace_back(PDLByteCodePattern::create( op, rewriterToAddr[op.rewriter().getLeafReference()])); writer.append(OpCode::RecordMatch, patternIndex, - SuccessorRange(op.getOperation()), op.matchedOps(), - op.inputs()); + SuccessorRange(op.getOperation()), op.matchedOps()); + writer.appendPDLValueList(op.inputs()); } void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) { - writer.append(OpCode::ReplaceOp, op.operation(), op.replValues()); + writer.append(OpCode::ReplaceOp, op.operation()); + writer.appendPDLValueList(op.replValues()); } void Generator::generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer) { @@ -661,6 +854,10 @@ writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(), op.getSuccessors()); } +void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) { + writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(), + op.getSuccessors()); +} //===----------------------------------------------------------------------===// // PDLByteCode @@ -671,7 +868,8 @@ llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, 
patterns, maxValueMemoryIndex, - constraintFns, rewriteFns); + maxTypeRangeCount, maxValueRangeCount, constraintFns, + rewriteFns); generator.generate(module); // Initialize the external functions. @@ -685,6 +883,8 @@ /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); + state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); @@ -697,17 +897,24 @@ /// This class provides support for executing a bytecode stream. class ByteCodeExecutor { public: - ByteCodeExecutor(const ByteCodeField *curCodeIt, - MutableArrayRef memory, - ArrayRef uniquedMemory, - ArrayRef code, - ArrayRef currentPatternBenefits, - ArrayRef patterns, - ArrayRef constraintFunctions, - ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory), - code(code), currentPatternBenefits(currentPatternBenefits), - patterns(patterns), constraintFunctions(constraintFunctions), + ByteCodeExecutor( + const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef typeRangeMemory, + std::vector> &allocatedTypeRangeMemory, + MutableArrayRef valueRangeMemory, + std::vector> &allocatedValueRangeMemory, + ArrayRef uniquedMemory, ArrayRef code, + ArrayRef currentPatternBenefits, + ArrayRef patterns, + ArrayRef constraintFunctions, + ArrayRef rewriteFunctions) + : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + allocatedTypeRangeMemory(allocatedTypeRangeMemory), + valueRangeMemory(valueRangeMemory), + allocatedValueRangeMemory(allocatedValueRangeMemory), + uniquedMemory(uniquedMemory), code(code), + currentPatternBenefits(currentPatternBenefits), patterns(patterns), + constraintFunctions(constraintFunctions), 
rewriteFunctions(rewriteFunctions) {} /// Start executing the code at the current bytecode index. `matches` is an @@ -722,19 +929,25 @@ void executeApplyConstraint(PatternRewriter &rewriter); void executeApplyRewrite(PatternRewriter &rewriter); void executeAreEqual(); + void executeAreRangesEqual(); void executeBranch(); void executeCheckOperandCount(); void executeCheckOperationName(); void executeCheckResultCount(); + void executeCheckTypes(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); + void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); void executeGetOperand(unsigned index); + void executeGetOperands(); void executeGetResult(unsigned index); + void executeGetResults(); void executeGetValueType(); + void executeGetValueRangeTypes(); void executeIsNotNull(); void executeRecordMatch(PatternRewriter &rewriter, SmallVectorImpl &matches); @@ -744,6 +957,7 @@ void executeSwitchOperationName(); void executeSwitchResultCount(); void executeSwitchType(); + void executeSwitchTypes(); /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point @@ -763,6 +977,19 @@ list.push_back(read()); } + /// Read a list of values from the bytecode buffer. The values may be encoded + /// as either Value or ValueRange elements. + void readValueList(SmallVectorImpl &list) { + for (unsigned i = 0, e = read(); i != e; ++i) { + if (read() == PDLValue::Kind::Value) { + list.push_back(read()); + } else { + ValueRange *values = read(); + list.append(values->begin(), values->end()); + } + } + } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. 
@@ -771,8 +998,8 @@ } /// Handle a switch operation with the provided value and cases. - template - void handleSwitch(const T &value, RangeT &&cases) { + template > + void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { LLVM_DEBUG({ llvm::dbgs() << " * Value: " << value << "\n" << " * Cases: "; @@ -783,7 +1010,7 @@ // Check to see if the attribute value is within the case list. Jump to // the correct successor index based on the result. for (auto it = cases.begin(), e = cases.end(); it != e; ++it) - if (*it == value) + if (cmp(*it, value)) return selectJump(size_t((it - cases.begin()) + 1)); selectJump(size_t(0)); } @@ -795,7 +1022,9 @@ size_t index = *curCodeIt++; // If this type is an SSA value, it can only be stored in non-const memory. - if (llvm::is_one_of::value || index < memory.size()) + if (llvm::is_one_of::value || + index < memory.size()) return memory[index]; // Otherwise, if this index is not inbounds it is uniqued. @@ -813,17 +1042,21 @@ } template std::enable_if_t::value, T> readImpl() { - switch (static_cast(read())) { - case PDLValueKind::Attribute: + switch (read()) { + case PDLValue::Kind::Attribute: return read(); - case PDLValueKind::Operation: + case PDLValue::Kind::Operation: return read(); - case PDLValueKind::Type: + case PDLValue::Kind::Type: return read(); - case PDLValueKind::Value: + case PDLValue::Kind::Value: return read(); + case PDLValue::Kind::TypeRange: + return read(); + case PDLValue::Kind::ValueRange: + return read(); } - llvm_unreachable("unhandled PDLValueKind"); + llvm_unreachable("unhandled PDLValue::Kind"); } template std::enable_if_t::value, T> readImpl() { @@ -838,12 +1071,20 @@ std::enable_if_t::value, T> readImpl() { return *curCodeIt++; } + template + std::enable_if_t::value, T> readImpl() { + return static_cast(readImpl()); + } /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; /// The current execution memory. 
MutableArrayRef memory; + MutableArrayRef typeRangeMemory; + std::vector> &allocatedTypeRangeMemory; + MutableArrayRef valueRangeMemory; + std::vector> &allocatedValueRangeMemory; /// References to ByteCode data necessary for execution. ArrayRef uniquedMemory; @@ -859,8 +1100,21 @@ /// overexposing access to information specific solely to the ByteCode. class ByteCodeRewriteResultList : public PDLResultList { public: + ByteCodeRewriteResultList(unsigned maxNumResults) + : PDLResultList(maxNumResults) {} + /// Return the list of PDL results. MutableArrayRef getResults() { return results; } + + /// Return the type ranges allocated by this list. + MutableArrayRef> getAllocatedTypeRanges() { + return allocatedTypeRanges; + } + + /// Return the value ranges allocated by this list. + MutableArrayRef> getAllocatedValueRanges() { + return allocatedValueRanges; + } }; } // end anonymous namespace @@ -893,21 +1147,46 @@ llvm::interleaveComma(args, llvm::dbgs()); llvm::dbgs() << "\n * Parameters: " << constParams << "\n"; }); - ByteCodeRewriteResultList results; + + // Execute the rewrite function. + ByteCodeField numResults = read(); + ByteCodeRewriteResultList results(numResults); rewriteFn(args, constParams, rewriter, results); - // Store the results in the bytecode memory. -#ifndef NDEBUG - ByteCodeField expectedNumberOfResults = read(); - assert(results.getResults().size() == expectedNumberOfResults && + assert(results.getResults().size() == numResults && "native PDL rewrite function returned unexpected number of results"); -#endif // Store the results in the bytecode memory. for (PDLValue &result : results.getResults()) { LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); - memory[read()] = result.getAsOpaquePointer(); + +// In debug mode we also verify the expected kind of the result. 
+#ifndef NDEBUG + assert(result.getKind() == read() && + "native PDL rewrite function returned an unexpected type of result"); +#endif + + // If the result is a range, we need to copy it over to the bytecodes + // range memory. + if (Optional typeRange = result.dyn_cast()) { + unsigned rangeIndex = read(); + typeRangeMemory[rangeIndex] = *typeRange; + memory[read()] = &typeRangeMemory[rangeIndex]; + } else if (Optional valueRange = + result.dyn_cast()) { + unsigned rangeIndex = read(); + valueRangeMemory[rangeIndex] = *valueRange; + memory[read()] = &valueRangeMemory[rangeIndex]; + } else { + memory[read()] = result.getAsOpaquePointer(); + } } + + // Copy over any underlying storage allocated for result ranges. + for (auto &it : results.getAllocatedTypeRanges()) + allocatedTypeRangeMemory.push_back(std::move(it)); + for (auto &it : results.getAllocatedValueRanges()) + allocatedValueRangeMemory.push_back(std::move(it)); } void ByteCodeExecutor::executeAreEqual() { @@ -919,6 +1198,32 @@ selectJump(lhs == rhs); } +void ByteCodeExecutor::executeAreRangesEqual() { + LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); + PDLValue::Kind valueKind = read(); + const void *lhs = read(); + const void *rhs = read(); + + switch (valueKind) { + case PDLValue::Kind::TypeRange: { + const TypeRange *lhsRange = reinterpret_cast(lhs); + const TypeRange *rhsRange = reinterpret_cast(rhs); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(*lhsRange == *rhsRange); + break; + } + case PDLValue::Kind::ValueRange: { + auto *lhsRange = reinterpret_cast(lhs); + auto *rhsRange = reinterpret_cast(rhs); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + selectJump(*lhsRange == *rhsRange); + break; + } + default: + llvm_unreachable("unexpected `AreRangesEqual` value kind"); + } +} + void ByteCodeExecutor::executeBranch() { LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); curCodeIt = &code[read()]; @@ -928,10 +1233,16 @@ 
LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); + bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" - << " * Expected: " << expectedCount << "\n"); - selectJump(op->getNumOperands() == expectedCount); + << " * Expected: " << expectedCount << "\n" + << " * Comparator: " + << (compareAtLeast ? ">=" : "==") << "\n"); + if (compareAtLeast) + selectJump(op->getNumOperands() >= expectedCount); + else + selectJump(op->getNumOperands() == expectedCount); } void ByteCodeExecutor::executeCheckOperationName() { @@ -948,10 +1259,44 @@ LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); Operation *op = read(); uint32_t expectedCount = read(); + bool compareAtLeast = read(); LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" - << " * Expected: " << expectedCount << "\n"); - selectJump(op->getNumResults() == expectedCount); + << " * Expected: " << expectedCount << "\n" + << " * Comparator: " + << (compareAtLeast ? ">=" : "==") << "\n"); + if (compareAtLeast) + selectJump(op->getNumResults() >= expectedCount); + else + selectJump(op->getNumResults() == expectedCount); +} + +void ByteCodeExecutor::executeCheckTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + TypeRange *lhs = read(); + Attribute rhs = read(); + LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + + selectJump(*lhs == rhs.cast().getAsValueRange()); +} + +void ByteCodeExecutor::executeCreateTypes() { + LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + ArrayAttr typesAttr = read().cast(); + + LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); + + // Allocate a buffer for this type range. 
+ llvm::OwningArrayRef storage(typesAttr.size()); + llvm::copy(typesAttr.getAsValueRange(), storage.begin()); + allocatedTypeRangeMemory.emplace_back(std::move(storage)); + + // Assign this to the range slot and use the range as the value for the + // memory index. + typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back(); + memory[memIndex] = &typeRangeMemory[rangeIndex]; } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, @@ -960,22 +1305,26 @@ unsigned memIndex = read(); OperationState state(mainRewriteLoc, read()); - readList(state.operands); + readValueList(state.operands); for (unsigned i = 0, e = read(); i != e; ++i) { Identifier name = read(); if (Attribute attr = read()) state.addAttribute(name, attr); } - bool hasInferredTypes = false; for (unsigned i = 0, e = read(); i != e; ++i) { - Type resultType = read(); - hasInferredTypes |= !resultType; - state.types.push_back(resultType); - } + if (read() == PDLValue::Kind::Type) { + state.types.push_back(read()); + continue; + } + + // If we find a null range, this signals that the types are infered. + if (TypeRange *resultTypes = read()) { + state.types.append(resultTypes->begin(), resultTypes->end()); + continue; + } - // Handle the case where the operation has inferred types. - if (hasInferredTypes) { + // Handle the case where the operation has inferred types. InferTypeOpInterface::Concept *concept = state.name.getAbstractOperation()->getInterface(); @@ -986,7 +1335,9 @@ state.attributes.getDictionary(state.getContext()), state.regions, state.types))) return; + break; } + Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; @@ -1036,11 +1387,21 @@ void ByteCodeExecutor::executeGetDefiningOp() { LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); unsigned memIndex = read(); - Value value = read(); - Operation *op = value ? 
value.getDefiningOp() : nullptr; + Operation *op = nullptr; + if (read() == PDLValue::Kind::Value) { + Value value = read(); + if (value) + op = value.getDefiningOp(); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + } else { + ValueRange *values = read(); + if (values && !values->empty()) { + op = values->front().getDefiningOp(); + } + LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); + } - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" - << " * Result: " << *op << "\n"); + LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); memory[memIndex] = op; } @@ -1056,6 +1417,75 @@ memory[memIndex] = operand.getAsOpaquePointer(); } +/// This function is the internal implementation of `GetResults` and +/// `GetOperands` that provides support for extracting a value range from the +/// given operation. +template