diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -446,6 +446,9 @@ /// Print this value to the provided output stream. void print(raw_ostream &os) const; + /// Print the specified value kind to an output stream. + static void print(raw_ostream &os, Kind kind); + private: /// Find the index of a given type in a range of other types. template @@ -491,6 +494,11 @@ return os; } +inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { + PDLValue::print(os, kind); + return os; +} + //===----------------------------------------------------------------------===// // PDLResultList diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -126,6 +126,29 @@ } } +void PDLValue::print(raw_ostream &os, Kind kind) { + switch (kind) { + case Kind::Attribute: + os << "Attribute"; + break; + case Kind::Operation: + os << "Operation"; + break; + case Kind::Type: + os << "Type"; + break; + case Kind::TypeRange: + os << "TypeRange"; + break; + case Kind::Value: + os << "Value"; + break; + case Kind::ValueRange: + os << "ValueRange"; + break; + } +} + //===----------------------------------------------------------------------===// // PDLPatternModule //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -28,6 +28,7 @@ /// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; +using OwningOpRange = llvm::OwningArrayRef; //===----------------------------------------------------------------------===// // PDLByteCodePattern @@ -79,6 +80,12 @@ /// of the bytecode. 
std::vector memory; + /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of operations. These are always stored by + /// owning references, because at no point in the execution of the byte code + /// do we get an indexed range (view) of operations. + std::vector opRangeMemory; + /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; @@ -93,6 +100,11 @@ /// interpreter to provide a guaranteed lifetime. std::vector> allocatedValueRangeMemory; + /// The current index of ranges being iterated over for each level of nesting. + /// These are always maintained at 0 for the loops that are not active, so we + /// do not need to have a separate initialization phase for each loop. + std::vector loopIndex; + /// The up-to-date benefits of the patterns held by the bytecode. The order /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. std::vector currentPatternBenefits; @@ -188,8 +200,12 @@ ByteCodeField maxValueMemoryIndex = 0; /// The maximum number of different types of ranges. + ByteCodeField maxOpRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; + + /// The maximum number of nested loops. + ByteCodeField maxLoopLevel = 0; }; } // end namespace detail diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -95,14 +95,24 @@ CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Continue to the next iteration of a loop. + Continue, /// Create an operation. CreateOperation, /// Create a range of types. CreateTypes, /// Erase an operation. EraseOp, + /// Extract the op from a range at the specified index. + ExtractOp, + /// Extract the type from a range at the specified index.
+ ExtractType, + /// Extract the value from a range at the specified index. + ExtractValue, /// Terminate a matcher or rewrite sequence. Finalize, + /// Iterate over a range of values. + ForEach, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. @@ -125,6 +135,8 @@ GetResultN, /// Get a specific result group of an operation. GetResults, + /// Get the users of a value or a range of values. + GetUsers, /// Get the type of a value. GetValueType, /// Get the types of a value range. @@ -158,8 +170,13 @@ // Generator namespace { +struct ByteCodeLiveRange; struct ByteCodeWriter; +/// Check if the given class `T` can be converted to an opaque pointer. +template +using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); + /// This class represents the main generator for the pattern bytecode. class Generator { public: @@ -168,15 +185,19 @@ SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, + ByteCodeField &maxLoopLevel, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), + maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), - maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { + maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), + maxLoopLevel(maxLoopLevel) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) @@ -221,6 +242,7 @@ void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. 
+ void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); @@ -232,12 +254,15 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); @@ -245,6 +270,7 @@ void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); @@ -279,17 +305,25 @@ /// `uniquedData`. DenseMap uniquedDataToMemIndex; + /// The current level of the foreach loop. + ByteCodeField curLoopLevel = 0; + /// The current MLIR context. 
MLIRContext *ctx; + /// Mapping from block to its address. + DenseMap blockToAddr; + /// Data of the ByteCode class to be populated. std::vector &uniquedData; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; + ByteCodeField &maxLoopLevel; }; /// This class provides utilities for writing a bytecode stream. @@ -311,15 +345,20 @@ bytecode.append({fieldParts[0], fieldParts[1]}); } + /// Append a single successor to the bytecode, the exact address will need to + /// be resolved later. + void append(Block *successor) { + // Add back a reference to the successor so that the address can be resolved + // later. + unresolvedSuccessorRefs[successor].push_back(bytecode.size()); + append(ByteCodeAddr(0)); + } + /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { - // Add back references to the any successors so that the address can be - // resolved later. - for (Block *successor : successors) { - unresolvedSuccessorRefs[successor].push_back(bytecode.size()); - append(ByteCodeAddr(0)); - } + for (Block *successor : successors) + append(successor); } /// Append a range of values that will be read as generic PDLValues. @@ -336,10 +375,12 @@ } /// Append the PDLValue::Kind of the given value. - void appendPDLValueKind(Value value) { - // Append the type of the value in addition to the value itself. + void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } + + /// Append the PDLValue::Kind of the given type. 
+ void appendPDLValueKind(Type type) { PDLValue::Kind kind = - TypeSwitch(value.getType()) + TypeSwitch(type) .Case( [](Type) { return PDLValue::Kind::Attribute; }) .Case( @@ -354,10 +395,6 @@ bytecode.push_back(static_cast(kind)); } - /// Check if the given class `T` has an iterator type. - template - using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); - /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template @@ -396,25 +433,34 @@ /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. struct ByteCodeLiveRange { - using Set = llvm::IntervalMap; + using Set = llvm::IntervalMap; using Allocator = Set::Allocator; - ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { - for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) - liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; + ++it) + liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { - return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) + .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. - llvm::IntervalMap liveness; + /// + /// We use std::unique_ptr here, because IntervalMap does not provide a + /// correct copy or move constructor. We can eliminate the pointer once + /// https://reviews.llvm.org/D113240 lands. + std::unique_ptr> liveness; + + /// The operation range storage index for this range. 
+ Optional opRangeIndex; /// The type range storage index for this range. Optional typeRangeIndex; @@ -446,15 +492,8 @@ "unexpected branches in rewriter function"); // Generate code for the matcher function. - DenseMap blockToAddr; - llvm::ReversePostOrderTraversal rpot(&matcherFunc.getBody()); ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); - for (Block *block : rpot) { - // Keep track of where this block begins within the matcher function. - blockToAddr.try_emplace(block, matcherByteCode.size()); - for (Operation &op : *block) - generate(&op, matcherByteCodeWriter); - } + generate(&matcherFunc.getBody(), matcherByteCodeWriter); // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { @@ -501,7 +540,7 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToIndex; + DenseMap opToIndex; matcherFunc.getBody().walk([&](Operation *op) { opToIndex.insert(std::make_pair(op, opToIndex.size())); }); @@ -516,8 +555,8 @@ // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); - for (Block &block : matcherFunc.getBody()) { - const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); + matcherFunc->walk([&](Block *block) { + const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always @@ -527,7 +566,7 @@ // Set indices for the range of this block that the value is used. 
auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.liveness.insert( + defRangeIt->second.liveness->insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); @@ -535,7 +574,9 @@ // Check to see if this value is a range type. if (auto rangeTy = value.getType().dyn_cast()) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (eleType.isa()) + defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.valueRangeIndex = 0; @@ -543,18 +584,37 @@ }; // Process the live-ins of this block. - for (Value liveIn : info->in()) - processValue(liveIn, &block.front()); + for (Value liveIn : info->in()) { + // Only process the value if it has been defined in the current region. + // Other values that span across pdl_interp.foreach will be added higher + // up. This ensures that we keep them alive for the entire duration + // of the loop. + if (liveIn.getParentRegion() == block->getParent()) + processValue(liveIn, &block->front()); + } + + // Process the block arguments for the entry block (those are not live-in). + if (block->isEntryBlock()) { + for (Value argument : block->getArguments()) + processValue(argument, &block->front()); + } // Process any new defs within this block. - for (Operation &op : block) + for (Operation &op : *block) for (Value result : op.getResults()) processValue(result, &op); - } + }); // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; - ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; + + // The number of memory indices currently allocated (and its next value). + // Recall that the root gets allocated memory index 0. + ByteCodeField numIndices = 1; + + // The number of memory ranges of various types (and their next values).
+ ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; @@ -566,7 +626,11 @@ existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; - if (defRange.typeRangeIndex) { + if (defRange.opRangeIndex) { + if (!existingRange.opRangeIndex) + existingRange.opRangeIndex = numOpRanges++; + valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; @@ -585,8 +649,11 @@ ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for type/value ranges. - if (defRange.typeRangeIndex) { + // Allocate an index for op/type/value ranges. + if (defRange.opRangeIndex) { + newRange.opRangeIndex = numOpRanges; + valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { @@ -599,15 +666,35 @@ } } + // Print the index usage and ensure that we did not run out of index space. + LLVM_DEBUG({ + llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ").\n"; + }); + assert(allocatedIndices.size() <= std::numeric_limits::max() && + "Ran out of memory for allocated indices"); + // Update the max number of indices. 
if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; + if (numOpRanges > maxOpRangeMemoryIndex) + maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::generate(Region *region, ByteCodeWriter &writer) { + llvm::ReversePostOrderTraversal rpot(region); + for (Block *block : rpot) { + // Keep track of where this block begins within the matcher function. + blockToAddr.try_emplace(block, matcherByteCode.size()); + for (Operation &op : *block) + generate(&op, writer); + } +} + void Generator::generate(Operation *op, ByteCodeWriter &writer) { TypeSwitch(op) .Case 0 && "encountered pdl_interp.continue at top level"); + writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); +} void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. 
@@ -736,9 +829,31 @@ void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } +void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) { + OpCode opCode = + TypeSwitch(op.result().getType()) + .Case([](pdl::OperationType) { return OpCode::ExtractOp; }) + .Case([](pdl::ValueType) { return OpCode::ExtractValue; }) + .Case([](pdl::TypeType) { return OpCode::ExtractType; }) + .Default([](Type) -> OpCode { + llvm_unreachable("unsupported element type"); + }); + writer.append(opCode, op.range(), op.index(), op.result()); +} void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) { writer.append(OpCode::Finalize); } +void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { + BlockArgument arg = op.getLoopVariable(); + writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); + writer.appendPDLValueKind(arg.getType()); + writer.append(curLoopLevel, op.successor()); + ++curLoopLevel; + if (curLoopLevel > maxLoopLevel) + maxLoopLevel = curLoopLevel; + generate(&op.region(), writer); + --curLoopLevel; +} void Generator::generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer) { writer.append(OpCode::GetAttribute, op.attribute(), op.operation(), @@ -793,6 +908,12 @@ writer.append(std::numeric_limits::max()); writer.append(result); } +void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { + Value operations = op.operations(); + ByteCodeField rangeIndex = getRangeStorageIndex(operations); + writer.append(OpCode::GetUsers, operations, rangeIndex); + writer.appendPDLValue(op.value()); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { if (op.getType().isa()) { @@ -865,8 +986,8 @@ llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - maxTypeRangeCount, maxValueRangeCount, 
constraintFns, - rewriteFns); + maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, + maxLoopLevel, constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. @@ -880,8 +1001,10 @@ /// bytecode. void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.opRangeMemory.resize(maxOpRangeCount); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); + state.loopIndex.resize(maxLoopLevel, 0); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); @@ -896,20 +1019,23 @@ public: ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef> opRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, std::vector> &allocatedValueRangeMemory, - ArrayRef uniquedMemory, ArrayRef code, + MutableArrayRef loopIndex, ArrayRef uniquedMemory, + ArrayRef code, ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), allocatedValueRangeMemory(allocatedValueRangeMemory), - uniquedMemory(uniquedMemory), code(code), + loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), rewriteFunctions(rewriteFunctions) {} @@ -932,10 +1058,15 @@ void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + void executeContinue(); void 
executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); + template + void executeExtract(); + void executeFinalize(); + void executeForEach(); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); @@ -943,6 +1074,7 @@ void executeGetOperands(); void executeGetResult(unsigned index); void executeGetResults(); + void executeGetUsers(); void executeGetValueType(); void executeGetValueRangeTypes(); void executeIsNotNull(); @@ -956,6 +1088,16 @@ void executeSwitchType(); void executeSwitchTypes(); + /// Pushes a code iterator to the stack. + void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } + + /// Pops a code iterator from the stack and resumes execution at it. + void popCodeIt() { + assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack"); + curCodeIt = resumeCodeIt.back(); + resumeCodeIt.pop_back(); + } + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -1012,6 +1154,18 @@ selectJump(size_t(0)); } + /// Store a pointer to memory. + void storeToMemory(unsigned index, const void *value) { + memory[index] = value; + } + + /// Store a value to memory as an opaque pointer. + template + std::enable_if_t::value> + storeToMemory(unsigned index, T value) { + memory[index] = value.getAsOpaquePointer(); + } + /// Internal implementation of reading various data types from the bytecode /// stream. template @@ -1076,13 +1230,20 @@ /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The stack of bytecode positions at which to resume operation. + SmallVector resumeCodeIt; + /// The current execution memory.
MutableArrayRef memory; + MutableArrayRef opRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; std::vector> &allocatedValueRangeMemory; + /// The current loop indices. + MutableArrayRef loopIndex; + /// References to ByteCode data necessary for execution. ArrayRef uniquedMemory; ArrayRef code; @@ -1277,6 +1438,14 @@ selectJump(*lhs == rhs.cast().getAsValueRange()); } +void ByteCodeExecutor::executeContinue() { + ByteCodeField level = read(); + LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" + << " * Level: " << level << "\n"); + ++loopIndex[level]; + popCodeIt(); +} + void ByteCodeExecutor::executeCreateTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); unsigned memIndex = read(); @@ -1357,6 +1526,65 @@ rewriter.eraseOp(op); } +template +void ByteCodeExecutor::executeExtract() { + LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); + Range *range = read(); + unsigned index = read(); + unsigned memIndex = read(); + + if (!range) { + memory[memIndex] = nullptr; + return; + } + + T result = index < range->size() ? (*range)[index] : T(); + LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" + << " * Index: " << index << "\n" + << " * Result: " << result << "\n"); + storeToMemory(memIndex, result); +} + +void ByteCodeExecutor::executeFinalize() { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); +} + +void ByteCodeExecutor::executeForEach() { + LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); + // Subtract 1 for the op code. 
+ const ByteCodeField *it = curCodeIt - 1; + unsigned rangeIndex = read(); + unsigned memIndex = read(); + const void *value = nullptr; + + switch (read()) { + case PDLValue::Kind::Operation: { + unsigned &index = loopIndex[read()]; + ArrayRef array = opRangeMemory[rangeIndex]; + assert(index <= array.size() && "iterated past the end"); + if (index < array.size()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); + value = array[index]; + break; + } + + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + index = 0; + selectJump(size_t(0)); + return; + } + default: + llvm_unreachable("unexpected `ForEach` value kind"); + } + + // Store the iterated value and the stack address. + memory[memIndex] = value; + pushCodeIt(it); + + // Skip over the successor (we will enter the body of the loop). + read(); +} + void ByteCodeExecutor::executeGetAttribute() { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); @@ -1421,7 +1649,7 @@ static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, - MutableArrayRef &valueRangeMemory) { + MutableArrayRef valueRangeMemory) { // Check for the sentinel index that signals that all values should be // returned. if (index == std::numeric_limits::max()) { @@ -1509,6 +1737,46 @@ memory[read()] = result; } +void ByteCodeExecutor::executeGetUsers() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + OwningOpRange &range = opRangeMemory[rangeIndex]; + memory[memIndex] = &range; + + range = OwningOpRange(); + if (read() == PDLValue::Kind::Value) { + // Read the value. + Value value = read(); + if (!value) + return; + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + + // Extract the users of a single value.
+ range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); + llvm::copy(value.getUsers(), range.begin()); + } else { + // Read a range of values. + ValueRange *values = read(); + if (!values) + return; + LLVM_DEBUG({ + llvm::dbgs() << " * Values (" << values->size() << "): "; + llvm::interleaveComma(*values, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + // Extract all the users of a range of values. + SmallVector users; + for (Value value : *values) + users.append(value.user_begin(), value.user_end()); + range = OwningOpRange(users.size()); + llvm::copy(users, range.begin()); + } + + LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); +} + void ByteCodeExecutor::executeGetValueType() { LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); unsigned memIndex = read(); @@ -1731,6 +1999,9 @@ case CheckTypes: executeCheckTypes(); break; + case Continue: + executeContinue(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; @@ -1740,9 +2011,22 @@ case EraseOp: executeEraseOp(rewriter); break; + case ExtractOp: + executeExtract(); + break; + case ExtractType: + executeExtract(); + break; + case ExtractValue: + executeExtract(); + break; case Finalize: - LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + executeFinalize(); + LLVM_DEBUG(llvm::dbgs() << "\n"); return; + case ForEach: + executeForEach(); + break; case GetAttribute: executeGetAttribute(); break; @@ -1784,6 +2068,9 @@ case GetResults: executeGetResults(); break; + case GetUsers: + executeGetUsers(); + break; case GetValueType: executeGetValueType(); break; @@ -1834,11 +2121,11 @@ // The matcher function always starts at code address 0. 
ByteCodeExecutor executor( - matcherByteCode.data(), state.memory, state.typeRangeMemory, - state.allocatedTypeRangeMemory, state.valueRangeMemory, - state.allocatedValueRangeMemory, uniquedData, matcherByteCode, - state.currentPatternBenefits, patterns, constraintFunctions, - rewriteFunctions); + matcherByteCode.data(), state.memory, state.opRangeMemory, + state.typeRangeMemory, state.allocatedTypeRangeMemory, + state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, + uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, + constraintFunctions, rewriteFunctions); executor.execute(rewriter, &matches); // Order the found matches by benefit. @@ -1857,8 +2144,9 @@ ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - state.typeRangeMemory, state.allocatedTypeRangeMemory, - state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, + state.opRangeMemory, state.typeRangeMemory, + state.allocatedTypeRangeMemory, state.valueRangeMemory, + state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -514,6 +514,12 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl_interp::ContinueOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + //===----------------------------------------------------------------------===// // pdl_interp::CreateAttributeOp //===----------------------------------------------------------------------===// @@ -576,12 +582,277 @@ // Fully tested within the tests for other operations. 
+//===----------------------------------------------------------------------===// +// pdl_interp::ExtractOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_users of %val : !pdl.value + %op1 = pdl_interp.extract 1 of %ops : !pdl.operation + pdl_interp.is_not_null %op1 : !pdl.operation -> ^success, ^end + ^success: + pdl_interp.record_match @rewriters::@success(%op1 : !pdl.operation) : benefit(1), loc([%root]) -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.extract_op +// CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" +// CHECK: "test.op"(%[[OPERAND]]) +module @ir attributes { test.extract_op } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> (i32) + "test.op"(%operand, %operand) : (i32, i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + %vals = pdl_interp.get_results of %root : !pdl.range<value> + %types = pdl_interp.get_value_type of %vals : !pdl.range<type> + %type1 = pdl_interp.extract 1 of %types : !pdl.type + pdl_interp.is_not_null %type1 : !pdl.type -> ^success, ^end + ^success: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.extract_type +// CHECK: %[[OPERAND:.*]] = "test.op" +// CHECK: "test.success" +// CHECK: "test.op"(%[[OPERAND]]) +module @ir attributes { test.extract_type } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : 
(i32) -> (i32, i32) + "test.op"(%operand) : (i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + %vals = pdl_interp.get_results of %root : !pdl.range<value> + %val1 = pdl_interp.extract 1 of %vals : !pdl.value + pdl_interp.is_not_null %val1 : !pdl.value -> ^success, ^end + ^success: + pdl_interp.record_match @rewriters::@success(%root : !pdl.operation) : benefit(1), loc([%root]) -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.extract_value +// CHECK: %[[OPERAND:.*]] = "test.op" +// CHECK: "test.success" +// CHECK: "test.op"(%[[OPERAND]]) +module @ir attributes { test.extract_value } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> (i32, i32) + "test.op"(%operand) : (i32) -> (i32) +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::FinalizeOp //===----------------------------------------------------------------------===// // Fully tested within the tests for other operations.
+//===----------------------------------------------------------------------===// +// pdl_interp::ForEachOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val1 = pdl_interp.get_result 0 of %root + %ops1 = pdl_interp.get_users of %val1 : !pdl.value + pdl_interp.foreach %op1 : !pdl.operation in %ops1 { + %val2 = pdl_interp.get_result 0 of %op1 + %ops2 = pdl_interp.get_users of %val2 : !pdl.value + pdl_interp.foreach %op2 : !pdl.operation in %ops2 { + pdl_interp.record_match @rewriters::@success(%op2 : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.foreach +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[ROOT:.*]] = "test.op" +// CHECK: %[[VALA:.*]] = "test.op"(%[[ROOT]]) +// CHECK: %[[VALB:.*]] = "test.op"(%[[ROOT]]) +module @ir attributes { test.foreach } { + %root = "test.op"() : () -> i32 + %valA = "test.op"(%root) : (i32) -> (i32) + "test.op"(%valA) : (i32) -> (i32) + "test.op"(%valA) : (i32) -> (i32) + %valB = "test.op"(%root) : (i32) -> (i32) + "test.op"(%valB) : (i32) -> (i32) + "test.op"(%valB) : (i32) -> (i32) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetUsersOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_users of %val : !pdl.value + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match 
@rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_users_of_value +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" +module @ir attributes { test.get_users_of_value } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> (i32) + "test.op"(%operand, %operand) : (i32, i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end + ^next: + %vals = pdl_interp.get_results of %root : !pdl.range<value> + %ops = pdl_interp.get_users of %vals : !pdl.range<value> + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_all_users_of_range +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERANDS:.*]]:2 = "test.op" +module @ir attributes { test.get_all_users_of_range } { + %operands:2 = "test.op"() : () -> (i32, i32) + "test.op"(%operands#0) : (i32) -> (i32) + "test.op"(%operands#1) : (i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end + ^next: + %vals = pdl_interp.get_results of %root : !pdl.range<value> + %val = pdl_interp.extract 0 of %vals : !pdl.value + %ops = pdl_interp.get_users of %val : !pdl.value + pdl_interp.foreach 
%op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_first_users_of_range +// CHECK: "test.success" +// CHECK: %[[OPERANDS:.*]]:2 = "test.op" +// CHECK: "test.op" +module @ir attributes { test.get_first_users_of_range } { + %operands:2 = "test.op"() : () -> (i32, i32) + "test.op"(%operands#0) : (i32) -> (i32) + "test.op"(%operands#1) : (i32) -> (i32) +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::GetAttributeOp //===----------------------------------------------------------------------===//