diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -29,6 +29,9 @@ using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; +/// A range over operations. +using OpRange = llvm::ArrayRef; + //===----------------------------------------------------------------------===// // PDLByteCodePattern //===----------------------------------------------------------------------===// @@ -80,6 +83,13 @@ std::vector memory; /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of operations. + std::vector opRangeMemory; + /// A set of operation ranges that have been allocated by the byte code + /// interpreter to provide a guaranteed lifetime. + std::vector> allocatedOpRangeMemory; + + /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter @@ -171,6 +181,9 @@ /// * Attribute, Identifier, OperationName, Type std::vector uniquedData; + /// A map from address (index into matcherByteCode) to corresponding block. + DenseMap addrToBlock; + /// A vector containing the generated bytecode for the matcher. SmallVector matcherByteCode; @@ -188,6 +201,7 @@ ByteCodeField maxValueMemoryIndex = 0; /// The maximum number of different types of ranges. + ByteCodeField maxOpRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; }; diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -67,6 +67,7 @@ /// This method should be called irregardless of whether the match+rewrite was a /// success or not. 
void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { + allocatedOpRangeMemory.clear(); allocatedTypeRangeMemory.clear(); allocatedValueRangeMemory.clear(); } @@ -95,6 +96,8 @@ CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Continue to the next iteration of a loop. + Continue, /// Create an operation. CreateOperation, /// Create a range of types. @@ -103,6 +106,8 @@ EraseOp, /// Terminate a matcher or rewrite sequence. Finalize, + /// Iterate over a range of values. + ForEach, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. @@ -125,6 +130,8 @@ GetResultN, /// Get a specific result group of an operation. GetResults, + /// Get all users of a value and store them in a range. + GetUsers, /// Get the type of a value. GetValueType, /// Get the types of a value range. @@ -158,23 +165,27 @@ // Generator namespace { +struct ByteCodeLiveRange; struct ByteCodeWriter; /// This class represents the main generator for the pattern bytecode. 
class Generator { public: Generator(MLIRContext *ctx, std::vector &uniquedData, + DenseMap &addrToBlock, SmallVectorImpl &matcherByteCode, SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) - : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), - rewriterByteCode(rewriterByteCode), patterns(patterns), - maxValueMemoryIndex(maxValueMemoryIndex), + : ctx(ctx), uniquedData(uniquedData), addrToBlock(addrToBlock), + matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), + patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), + maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) @@ -221,6 +232,7 @@ void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. 
+ void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); @@ -232,6 +244,7 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); @@ -245,9 +258,11 @@ void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); @@ -282,12 +297,17 @@ /// The current MLIR context. MLIRContext *ctx; + /// Mapping from block to its address. + DenseMap blockToAddr; + /// Data of the ByteCode class to be populated. 
std::vector &uniquedData; + DenseMap &addrToBlock; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; }; @@ -311,14 +331,20 @@ bytecode.append({fieldParts[0], fieldParts[1]}); } + /// Append a single successor to the bytecode, the exact address will need to + /// be resolved later. + void append(Block *successor) { + // Add back a reference to the successor so that the address can be resolved + // later. + unresolvedSuccessorRefs[successor].push_back(bytecode.size()); + append(ByteCodeAddr(0)); + } + /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { - // Add back references to the any successors so that the address can be - // resolved later. for (Block *successor : successors) { - unresolvedSuccessorRefs[successor].push_back(bytecode.size()); - append(ByteCodeAddr(0)); + append(successor); } } @@ -336,10 +362,12 @@ } /// Append the PDLValue::Kind of the given value. - void appendPDLValueKind(Value value) { - // Append the type of the value in addition to the value itself. + void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } + + /// Append the PDLValue::Kind of the given type. + void appendPDLValueKind(Type type) { PDLValue::Kind kind = - TypeSwitch(value.getType()) + TypeSwitch(type) .Case( [](Type) { return PDLValue::Kind::Attribute; }) .Case( @@ -396,25 +424,35 @@ /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. 
struct ByteCodeLiveRange { - using Set = llvm::IntervalMap; + using Set = llvm::IntervalMap; using Allocator = Set::Allocator; - ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { - for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) - liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; + ++it) + liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { - return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) + .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. - llvm::IntervalMap liveness; + /// + /// We use std::unique_ptr here, because the move constructor of IntervalMap + /// appears to have a bug that shows up when it is included here directly. + /// Specifically, the start bound of the interval gets populated by wrong + /// value; this only shows up in release builds and is hard to reproduce. + std::unique_ptr> liveness; + + /// The operation range storage index for this range. + Optional opRangeIndex; /// The type range storage index for this range. Optional typeRangeIndex; @@ -446,15 +484,8 @@ "unexpected branches in rewriter function"); // Generate code for the matcher function. - DenseMap blockToAddr; - llvm::ReversePostOrderTraversal rpot(&matcherFunc.getBody()); ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); - for (Block *block : rpot) { - // Keep track of where this block begins within the matcher function. 
- blockToAddr.try_emplace(block, matcherByteCode.size()); - for (Operation &op : *block) - generate(&op, matcherByteCodeWriter); - } + generate(&matcherFunc.getBody(), matcherByteCodeWriter); // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { @@ -501,7 +532,7 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToIndex; + DenseMap opToIndex; matcherFunc.getBody().walk([&](Operation *op) { opToIndex.insert(std::make_pair(op, opToIndex.size())); }); @@ -516,8 +547,8 @@ // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); - for (Block &block : matcherFunc.getBody()) { - const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); + matcherFunc->walk([&](Block *block) { + const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always @@ -527,7 +558,7 @@ // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.liveness.insert( + defRangeIt->second.liveness->insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); @@ -535,7 +566,9 @@ // Check to see if this value is a range type. 
if (auto rangeTy = value.getType().dyn_cast()) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (eleType.isa()) + defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.valueRangeIndex = 0; @@ -544,17 +577,29 @@ // Process the live-ins of this block. for (Value liveIn : info->in()) - processValue(liveIn, &block.front()); + // Only process the value if it has been defined in the current region. + // Other values that span across pdl_interp.foreach will be added higher + // up. This ensures that we keep them alive for the entire duration + // of the loop. + if (liveIn.getParentRegion() == block->getParent()) + processValue(liveIn, &block->front()); // Process any new defs within this block. - for (Operation &op : block) + for (Operation &op : *block) for (Value result : op.getResults()) processValue(result, &op); - } + }); // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; - ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; + + // The number of memory indices currently allocated (and its next value). + // Recall that the root gets allocated memory index 0. + ByteCodeField numIndices = 1; + + // The number of memory ranges of various types (and their next values). 
+ ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; @@ -566,7 +611,11 @@ existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; - if (defRange.typeRangeIndex) { + if (defRange.opRangeIndex) { + if (!existingRange.opRangeIndex) + existingRange.opRangeIndex = numOpRanges++; + valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; @@ -585,8 +634,11 @@ ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for type/value ranges. - if (defRange.typeRangeIndex) { + // Allocate an index for op/type/value ranges. + if (defRange.opRangeIndex) { + newRange.opRangeIndex = numOpRanges; + valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { @@ -599,15 +651,37 @@ } } + // Ensure that we did not run out of index space. + LLVM_DEBUG({ + llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ").\n"; + }); + assert(allocatedIndices.size() < std::numeric_limits::max() && + "Ran out of memory for allocated indices"); + // Update the max number of indices. 
if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; + if (numOpRanges > maxOpRangeMemoryIndex) + maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::generate(Region *region, ByteCodeWriter &writer) { + llvm::ReversePostOrderTraversal rpot(region); + for (Block *block : rpot) { + // Keep track of where this block begins within the matcher function. + blockToAddr.try_emplace(block, matcherByteCode.size()); + for (Operation &op : *block) { + LLVM_DEBUG(addrToBlock.try_emplace(matcherByteCode.size(), block)); + generate(&op, writer); + } + } +} + void Generator::generate(Operation *op, ByteCodeWriter &writer) { TypeSwitch(op) .Case( + pdl_interp::GetResultsOp, pdl_interp::GetUsersOp, + pdl_interp::GetValueTypeOp, pdl_interp::InferredTypesOp, + pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, + pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, + pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, + pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, + pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -736,6 +812,9 @@ writer.append(OpCode::CreateTypes, op.result(), getRangeStorageIndex(op.result()), op.value()); } +void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { + writer.append(OpCode::Continue); +} void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) { writer.append(OpCode::EraseOp, op.operation()); } @@ -796,6 +875,11 @@ writer.append(std::numeric_limits::max()); writer.append(result); } +void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { + writer.append(OpCode::GetUsers, getRangeStorageIndex(op.operations()), + ByteCodeField(op.index())); + 
writer.appendPDLValue(op.value()); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { if (op.getType().isa()) { @@ -815,6 +899,13 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); } +void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { + BlockArgument arg = op.getLoopVariable(); + writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); + writer.appendPDLValueKind(arg.getType()); + writer.append(op.successor()); + generate(&op.region(), writer); +} void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( @@ -866,10 +957,10 @@ PDLByteCode::PDLByteCode(ModuleOp module, llvm::StringMap constraintFns, llvm::StringMap rewriteFns) { - Generator generator(module.getContext(), uniquedData, matcherByteCode, - rewriterByteCode, patterns, maxValueMemoryIndex, - maxTypeRangeCount, maxValueRangeCount, constraintFns, - rewriteFns); + Generator generator(module.getContext(), uniquedData, addrToBlock, + matcherByteCode, rewriterByteCode, patterns, + maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, + maxValueRangeCount, constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. @@ -883,6 +974,7 @@ /// bytecode. 
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.opRangeMemory.resize(maxOpRangeCount, OpRange()); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.currentPatternBenefits.reserve(patterns.size()); @@ -899,6 +991,8 @@ public: ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef opRangeMemory, + std::vector> &allocatedOpRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, @@ -907,15 +1001,18 @@ ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, - ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + ArrayRef rewriteFunctions, + const DenseMap *addrToBlock = nullptr) + : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + allocatedOpRangeMemory(allocatedOpRangeMemory), + typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), allocatedValueRangeMemory(allocatedValueRangeMemory), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), - rewriteFunctions(rewriteFunctions) {} + rewriteFunctions(rewriteFunctions), addrToBlock(addrToBlock) {} /// Start executing the code at the current bytecode index. 
`matches` is an /// optional field provided when this function is executed in a matching @@ -935,10 +1032,13 @@ void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + void executeContinue(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); + void executeFinalize(); + void executeForEach(); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); @@ -946,6 +1046,7 @@ void executeGetOperands(); void executeGetResult(unsigned index); void executeGetResults(); + void executeGetUsers(); void executeGetValueType(); void executeGetValueRangeTypes(); void executeIsNotNull(); @@ -959,6 +1060,16 @@ void executeSwitchType(); void executeSwitchTypes(); + /// Pushes a code iterator to the stack. + void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } + + /// Pops the top code iterator off the stack and resumes execution there. + void popCodeIt() { + assert(!resumeCodeIt.empty() && "Attempt to pop code off empty stack"); + curCodeIt = resumeCodeIt.back(); + resumeCodeIt.pop_back(); + } + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -1079,8 +1190,13 @@ /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The stack of bytecode positions at which to resume operation. + SmallVector resumeCodeIt; + /// The current execution memory. MutableArrayRef memory; + MutableArrayRef opRangeMemory; + std::vector> &allocatedOpRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; @@ -1093,6 +1209,9 @@ ArrayRef patterns; ArrayRef constraintFunctions; ArrayRef rewriteFunctions; + + /// Optional map from code index to the corresponding block. 
+ const DenseMap *addrToBlock; }; /// This class is an instantiation of the PDLResultList that provides access to @@ -1280,6 +1399,11 @@ selectJump(*lhs == rhs.cast().getAsValueRange()); } +void ByteCodeExecutor::executeContinue() { + LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"); + popCodeIt(); +} + void ByteCodeExecutor::executeCreateTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); unsigned memIndex = read(); @@ -1338,6 +1462,21 @@ break; } + // Determine the insertion point. + // This is the last of the operands, default to beginning of the block. + Block *block = rewriter.getBlock(); + rewriter.setInsertionPointToStart(block); + DenseSet operationSet; + for (Value value : state.operands) + operationSet.insert(value.getDefiningOp()); + for (auto it = block->rbegin(); it != block->rend(); ++it) + if (operationSet.count(&*it)) { + LLVM_DEBUG(llvm::dbgs() + << "Inserting " << state.name << " after " << *it << "\n"); + rewriter.setInsertionPointAfter(&*it); + break; + } + Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; @@ -1360,6 +1499,42 @@ rewriter.eraseOp(op); } +void ByteCodeExecutor::executeFinalize() { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); +} + +void ByteCodeExecutor::executeForEach() { + LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); + const ByteCodeField *it = curCodeIt - 1; // Subtract 1 for the op code + unsigned rangeIndex = read(); + unsigned memIndex = read(); + const void *value = nullptr; + + switch (read()) { + case PDLValue::Kind::Operation: { + OpRange &range = opRangeMemory[rangeIndex]; + if (range.empty()) { + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + selectJump(size_t(0)); + return; + } + LLVM_DEBUG(llvm::dbgs() << " * Result: " << *range.front() << "\n"); + value = range.front(); + range = range.drop_front(); + break; + } + default: + llvm_unreachable("unexpected `ForEach` value kind"); + } + + // Store the iterate value and the stack address. 
+ memory[memIndex] = value; + pushCodeIt(it); + + // Skip over the successor (we will enter the body of the loop). + read(); +} + void ByteCodeExecutor::executeGetAttribute() { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); @@ -1512,6 +1687,42 @@ memory[read()] = result; } +void ByteCodeExecutor::executeGetUsers() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); + unsigned rangeIndex = read(); + unsigned operandNumber = read(); + LLVM_DEBUG(llvm::dbgs() << " * Operand: " << operandNumber << "\n"); + + // Read the value. + Value value; + if (read() == PDLValue::Kind::Value) { + value = read(); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + } else { + ValueRange *values = read(); + if (values && !values->empty()) + value = values->front(); + LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); + } + + // Determine all uses at the specified operand number. + SmallVector users; + if (value) + for (OpOperand &use : value.getUses()) + if (use.getOperandNumber() == operandNumber) + users.push_back(use.getOwner()); + LLVM_DEBUG(llvm::dbgs() << " * Result: " << users.size() << " operations\n"); + + // Allocate a buffer for the operation range. + llvm::OwningArrayRef storage(users.size()); + llvm::copy(users, storage.begin()); + allocatedOpRangeMemory.emplace_back(std::move(storage)); + + // Assign this to the range slot. + // TODO: do we need to store this in memory[] as well? 
+ opRangeMemory[rangeIndex] = allocatedOpRangeMemory.back(); +} + void ByteCodeExecutor::executeGetValueType() { LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); unsigned memIndex = read(); @@ -1705,6 +1916,14 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + LLVM_DEBUG({ + if (addrToBlock) { + Block *block = addrToBlock->lookup(curCodeIt - code.data()); + assert(block && "Unknown address"); + block->printAsOperand(llvm::dbgs()); + llvm::dbgs() << " "; + } + }); OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: @@ -1734,6 +1953,9 @@ case CheckTypes: executeCheckTypes(); break; + case Continue: + executeContinue(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; @@ -1744,8 +1966,12 @@ executeEraseOp(rewriter); break; case Finalize: - LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + executeFinalize(); + LLVM_DEBUG(llvm::dbgs() << "\n"); return; + case ForEach: + executeForEach(); + break; case GetAttribute: executeGetAttribute(); break; @@ -1787,6 +2013,9 @@ case GetResults: executeGetResults(); break; + case GetUsers: + executeGetUsers(); + break; case GetValueType: executeGetValueType(); break; @@ -1837,11 +2066,12 @@ // The matcher function always starts at code address 0. ByteCodeExecutor executor( - matcherByteCode.data(), state.memory, state.typeRangeMemory, + matcherByteCode.data(), state.memory, state.opRangeMemory, + state.allocatedOpRangeMemory, state.typeRangeMemory, state.allocatedTypeRangeMemory, state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, - rewriteFunctions); + rewriteFunctions, &addrToBlock); executor.execute(rewriter, &matches); // Order the found matches by benefit. 
@@ -1860,9 +2090,10 @@ ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - state.typeRangeMemory, state.allocatedTypeRangeMemory, - state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, - rewriterByteCode, state.currentPatternBenefits, patterns, - constraintFunctions, rewriteFunctions); + state.opRangeMemory, state.allocatedOpRangeMemory, state.typeRangeMemory, + state.allocatedTypeRangeMemory, state.valueRangeMemory, + state.allocatedValueRangeMemory, uniquedData, rewriterByteCode, + state.currentPatternBenefits, patterns, constraintFunctions, + rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -255,8 +255,8 @@ } // CHECK-LABEL: test.are_equal_2 -// CHECK: "test.not_equal" // CHECK: "test.success" +// CHECK: "test.not_equal" // CHECK-NOT: "test.op" module @ir attributes { test.are_equal_2 } { "test.not_equal"() : () -> (i32) @@ -361,8 +361,8 @@ } // CHECK-LABEL: test.check_operand_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success" +// CHECK: "test.op"() : () -> i32 module @ir attributes { test.check_operand_count_1 } { %operand = "test.op"() : () -> i32 "test.op"(%operand, %operand) : (i32, i32) -> () @@ -430,8 +430,8 @@ } // CHECK-LABEL: test.check_result_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success"() : () -> () +// CHECK: "test.op"() : () -> i32 // CHECK-NOT: "test.op"() : () -> (i32, i32) module @ir attributes { test.check_result_count_1 } { "test.op"() : () -> i32 @@ -504,8 +504,8 @@ } // CHECK-LABEL: test.check_types_1 -// CHECK: "test.op"() : () -> (i32, i64) // CHECK: "test.success" +// CHECK: "test.op"() : () -> (i32, i64) // CHECK-NOT: "test.op"() : () -> i32 module @ir attributes { test.check_types_1 } { "test.op"() : () -> (i32, i64) @@ -515,6 
+515,12 @@ // ----- //===----------------------------------------------------------------------===// +// pdl_interp::ContinueOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// // pdl_interp::CreateAttributeOp //===----------------------------------------------------------------------===// @@ -583,6 +589,61 @@ // Fully tested within the tests for other operations. //===----------------------------------------------------------------------===// +// pdl_interp::ForEachOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + +//===----------------------------------------------------------------------===// +// pdl_interp::GetUsersOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %result0 = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_users of %result0 : !pdl.value at 1 + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.check_operand_count of %op is 2 -> ^pat1, ^cont + ^pat1: + pdl_interp.record_match @rewriters::@success(%result0, %op : !pdl.value, !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%result0 :!pdl.value, %matched : !pdl.operation) { + %type = pdl_interp.create_type i32 + %rep = pdl_interp.create_operation "test.op"(%result0 : !pdl.value) -> (%type : !pdl.type) + %val = pdl_interp.get_result 0 of %rep + pdl_interp.replace %matched with (%val : !pdl.value) + %op = pdl_interp.create_operation "test.success" + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_users +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERAND0:.*]] = 
"test.op" +// CHECK: "test.op"(%[[OPERAND0]]) +// CHECK: %[[OPERAND1:.*]] = "test.op" +// CHECK: "test.op"(%[[OPERAND1]]) +// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND1]], %[[OPERAND1]]) +module @ir attributes { test.get_users } { + %operand1 = "test.op"() : () -> i32 + %operand2 = "test.op"() : () -> i32 + "test.op"(%operand1, %operand2) : (i32, i32) -> (i32) + "test.op"(%operand1, %operand2, %operand2) : (i32, i32, i32) -> (i32) + "test.op"(%operand2, %operand1) : (i32, i32) -> (i32) +} + +// ----- + +//===----------------------------------------------------------------------===// // pdl_interp::GetAttributeOp //===----------------------------------------------------------------------===// @@ -626,9 +687,9 @@ } // CHECK-LABEL: test.get_defining_op_1 +// CHECK: "test.success" // CHECK: %[[OPERAND0:.*]] = "test.op" // CHECK: %[[OPERAND1:.*]] = "test.op" -// CHECK: "test.success" // CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]]) module @ir attributes { test.get_defining_op_1 } { %operand = "test.op"() : () -> i32 @@ -737,11 +798,11 @@ } // CHECK-LABEL: test.get_operands_2 -// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"() : () -> () -// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () +// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () +// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () module @ir attributes { test.get_operands_2 } { %inputs:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) "test.attr_sized_operands"(%inputs#0, %inputs#1, %inputs#2, %inputs#3, %inputs#4) {operand_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : (i32, i32, i32, i32, i32) -> () 
@@ -891,10 +952,10 @@ } // CHECK-LABEL: test.get_results_2 -// CHECK: "test.success"() : () -> () -// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) -// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 // CHECK: %[[RESULTS_2_SINGLE:.*]] = "test.success"() : () -> i32 +// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 +// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) +// CHECK: "test.success"() : () -> () // CHECK: "test.consumer"(%[[RESULTS_1]]#0, %[[RESULTS_1]]#1, %[[RESULTS_1]]#2, %[[RESULTS_1]]#3, %[[RESULTS_2]]) : (i32, i32, i32, i32, i32) -> () module @ir attributes { test.get_results_2 } { %results:5 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : () -> (i32, i32, i32, i32, i32)