diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -380,7 +380,7 @@ let summary = "Breaks the current iteration"; let description = [{ `pdl_interp.continue` operation breaks the current iteration within the - `pdl_interp.iterate` region and continues with the next iteration from + `pdl_interp.foreach` region and continues with the next iteration from the beginning of the region. Example: @@ -823,20 +823,44 @@ : PDLInterp_Op<"get_users", [NoSideEffect]> { let summary = "Get the users of a `Value`"; let description = [{ - `pdl_interp.get_users` returns the operations using a value (or first value - from a range of values) at the specified operand position. + `pdl_interp.get_users` extracts the users that accept this value or a range + of values at the specified optional operand group. If an index is provided, + the index is the group as defined by the ODS definition of the operation. + If an index is not provided and a single value is given, this operation + extracts all the users of the value. If an index is not provided and a range + of values is given, this operation extracts all the users that accept this + range of values as all their operands. Example: ```mlir - %ops = pdl_interp.get_users of %value : !pdl.value at 2 + // Get the users of a single value for the given index. + %ops = pdl_interp.get_users 1 of %value : !pdl.value + + // Get the users of a range of values for the given index. + %ops = pdl_interp.get_users 2 of %values : !pdl.range + + // Get all the users of a single value. + %ops = pdl_interp.get_users of %value : !pdl.value + + // Get the users of a range of values for all operands. 
+ %ops = pdl_interp.get_users of %values : !pdl.range ``` }]; - let arguments = (ins PDL_InstOrRangeOf:$value, - Confined:$index); + let arguments = (ins + PDL_InstOrRangeOf:$value, + OptionalAttr>:$index + ); let results = (outs PDL_RangeOf:$operations); - let assemblyFormat = "`of` $value `:` type($value) `at` $index attr-dict"; + let assemblyFormat = "($index^)? `of` $value `:` type($value) attr-dict"; + let builders = [ + OpBuilder<(ins "Value":$value, "Optional":$index), [{ + build($_builder, $_state, + pdl::RangeType::get($_builder.getType()), + value, index ? $_builder.getI32IntegerAttr(*index) : IntegerAttr()); + }]>, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -28,6 +28,7 @@ /// entries. ByteCodeAddr refers to size of indices into the bytecode. using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; +using OpRange = llvm::OwningArrayRef; //===----------------------------------------------------------------------===// // PDLByteCodePattern @@ -80,6 +81,12 @@ std::vector memory; /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of operations. These are always stored by + /// owning references, because at no point in the execution of the byte code + /// we get an indexed range (view) of operations. + std::vector opRangeMemory; + + /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter @@ -93,6 +100,11 @@ /// interpreter to provide a guaranteed lifetime. std::vector> allocatedValueRangeMemory; + /// The current index of ranges being iterated over for each level of nesting. 
+ /// These are always maintained at 0 for the loops that are not active, so we + /// do not need to have a separate initialization phase for each loop. + std::vector loopIndex; + /// The up-to-date benefits of the patterns held by the bytecode. The order /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. std::vector currentPatternBenefits; @@ -188,8 +200,12 @@ ByteCodeField maxValueMemoryIndex = 0; /// The maximum number of different types of ranges. + ByteCodeField maxOpRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; + + /// The maximum number of nested loops. + ByteCodeField maxLoopLevel = 0; }; } // end namespace detail diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -95,6 +95,8 @@ CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Continue to the next iteration of a loop. + Continue, /// Create an operation. CreateOperation, /// Create a range of types. @@ -103,6 +105,8 @@ EraseOp, /// Terminate a matcher or rewrite sequence. Finalize, + /// Iterate over a range of values. + ForEach, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. @@ -125,12 +129,16 @@ GetResultN, /// Get a specific result group of an operation. GetResults, + /// Get all users of a value and store them in a range. + GetUsers, /// Get the type of a value. GetValueType, /// Get the types of a value range. GetValueRangeTypes, /// Check if a generic value is not null. IsNotNull, + /// No operation, used in the debug mode to consume the line number. + NoOp, /// Record a successful pattern match. RecordMatch, /// Replace an operation. @@ -158,6 +166,7 @@ // Generator namespace { +struct ByteCodeLiveRange; struct ByteCodeWriter; /// This class represents the main generator for the pattern bytecode. 
@@ -168,15 +177,19 @@ SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, + ByteCodeField &maxLoopLevel, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), + maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), - maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { + maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), + maxLoopLevel(maxLoopLevel) { for (auto it : llvm::enumerate(constraintFns)) constraintToMemIndex.try_emplace(it.value().first(), it.index()); for (auto it : llvm::enumerate(rewriteFns)) @@ -221,6 +234,7 @@ void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); /// Generate the bytecode for the given operation. 
+ void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer); @@ -232,6 +246,7 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); @@ -245,9 +260,11 @@ void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer); @@ -279,17 +296,25 @@ /// `uniquedData`. DenseMap uniquedDataToMemIndex; + /// The current level of the foreach loop. + ByteCodeField curLoopLevel = 0; + /// The current MLIR context. MLIRContext *ctx; + /// Mapping from block to its address. + DenseMap blockToAddr; + /// Data of the ByteCode class to be populated. 
std::vector &uniquedData; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; + ByteCodeField &maxLoopLevel; }; /// This class provides utilities for writing a bytecode stream. @@ -311,15 +336,20 @@ bytecode.append({fieldParts[0], fieldParts[1]}); } + /// Append a single successor to the bytecode, the exact address will need to + /// be resolved later. + void append(Block *successor) { + // Add back a reference to the successor so that the address can be resolved + // later. + unresolvedSuccessorRefs[successor].push_back(bytecode.size()); + append(ByteCodeAddr(0)); + } + /// Append a successor range to the bytecode, the exact address will need to /// be resolved later. void append(SuccessorRange successors) { - // Add back references to the any successors so that the address can be - // resolved later. - for (Block *successor : successors) { - unresolvedSuccessorRefs[successor].push_back(bytecode.size()); - append(ByteCodeAddr(0)); - } + for (Block *successor : successors) + append(successor); } /// Append a range of values that will be read as generic PDLValues. @@ -336,10 +366,12 @@ } /// Append the PDLValue::Kind of the given value. - void appendPDLValueKind(Value value) { - // Append the type of the value in addition to the value itself. + void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); } + + /// Append the PDLValue::Kind of the given type. + void appendPDLValueKind(Type type) { PDLValue::Kind kind = - TypeSwitch(value.getType()) + TypeSwitch(type) .Case( [](Type) { return PDLValue::Kind::Attribute; }) .Case( @@ -396,25 +428,35 @@ /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. 
struct ByteCodeLiveRange { - using Set = llvm::IntervalMap; + using Set = llvm::IntervalMap; using Allocator = Set::Allocator; - ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { - for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) - liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; + ++it) + liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { - return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) + .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. - llvm::IntervalMap liveness; + /// + /// We use std::unique_ptr here, because the move constructor of IntervalMap + /// appears to have a bug that shows up when it is included here directly. + /// Specifically, the start bound of the interval gets populated by wrong + /// value; this only shows up in release builds and is hard to reproduce. + std::unique_ptr> liveness; + + /// The operation range storage index for this range. + Optional opRangeIndex; /// The type range storage index for this range. Optional typeRangeIndex; @@ -446,15 +488,8 @@ "unexpected branches in rewriter function"); // Generate code for the matcher function. - DenseMap blockToAddr; - llvm::ReversePostOrderTraversal rpot(&matcherFunc.getBody()); ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this); - for (Block *block : rpot) { - // Keep track of where this block begins within the matcher function. 
- blockToAddr.try_emplace(block, matcherByteCode.size()); - for (Operation &op : *block) - generate(&op, matcherByteCodeWriter); - } + generate(&matcherFunc.getBody(), matcherByteCodeWriter); // Resolve successor references in the matcher. for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) { @@ -501,7 +536,7 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToIndex; + DenseMap opToIndex; matcherFunc.getBody().walk([&](Operation *op) { opToIndex.insert(std::make_pair(op, opToIndex.size())); }); @@ -516,8 +551,8 @@ // Walk each of the blocks, computing the def interval that the value is used. Liveness matcherLiveness(matcherFunc); - for (Block &block : matcherFunc.getBody()) { - const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block); + matcherFunc->walk([&](Block *block) { + const LivenessBlockInfo *info = matcherLiveness.getLiveness(block); assert(info && "expected liveness info for block"); auto processValue = [&](Value value, Operation *firstUseOrDef) { // We don't need to process the root op argument, this value is always @@ -527,7 +562,7 @@ // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.liveness.insert( + defRangeIt->second.liveness->insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); @@ -535,7 +570,9 @@ // Check to see if this value is a range type. 
if (auto rangeTy = value.getType().dyn_cast()) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (eleType.isa()) + defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.valueRangeIndex = 0; @@ -543,18 +580,38 @@ }; // Process the live-ins of this block. - for (Value liveIn : info->in()) - processValue(liveIn, &block.front()); + for (Value liveIn : info->in()) { + // Only process the value if it has been defined in the current region. + // Other values that span across pdl_interp.foreach will be added higher + // up. This ensures that we keep them alive for the entire duration + // of the loop. + if (liveIn.getParentRegion() == block->getParent()) + processValue(liveIn, &block->front()); + } + + // Process the block arguments for the entry block (those are not live-in). + if (block->isEntryBlock()) { + for (Value argument : block->getArguments()) { + processValue(argument, &block->front()); + } + } // Process any new defs within this block. - for (Operation &op : block) + for (Operation &op : *block) for (Value result : op.getResults()) processValue(result, &op); - } + }); // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; - ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; + + // The number of memory indices currently allocated (and its next value). + // Recall that the root gets allocated memory index 0. + ByteCodeField numIndices = 1; + + // The number of memory ranges of various types (and their next values). 
+ ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; @@ -566,7 +623,11 @@ existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; - if (defRange.typeRangeIndex) { + if (defRange.opRangeIndex) { + if (!existingRange.opRangeIndex) + existingRange.opRangeIndex = numOpRanges++; + valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; @@ -585,8 +646,11 @@ ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for type/value ranges. - if (defRange.typeRangeIndex) { + // Allocate an index for op/type/value ranges. + if (defRange.opRangeIndex) { + newRange.opRangeIndex = numOpRanges; + valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { @@ -599,34 +663,62 @@ } } + // Print the index usage and ensure that we did not run out of index space. + LLVM_DEBUG({ + llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ").\n"; + }); + assert(allocatedIndices.size() <= std::numeric_limits::max() && + "Ran out of memory for allocated indices"); + // Update the max number of indices. 
if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; + if (numOpRanges > maxOpRangeMemoryIndex) + maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::generate(Region *region, ByteCodeWriter &writer) { + llvm::ReversePostOrderTraversal rpot(region); + for (Block *block : rpot) { + // Keep track of where this block begins within the matcher function. + blockToAddr.try_emplace(block, matcherByteCode.size()); + for (Operation &op : *block) + generate(&op, writer); + } +} + void Generator::generate(Operation *op, ByteCodeWriter &writer) { + LLVM_DEBUG({ + if (auto loc = op->getLoc().dyn_cast()) + writer.append(ByteCodeAddr(loc.getLine())); + else + writer.append(ByteCodeAddr(0)); + }); TypeSwitch(op) .Case( + pdl_interp::GetResultsOp, pdl_interp::GetUsersOp, + pdl_interp::GetValueTypeOp, pdl_interp::InferredTypesOp, + pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, + pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, + pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, + pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, + pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -707,10 +799,15 @@ void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); } +void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { + assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); + writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); +} void Generator::generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. 
getMemIndex(op.attribute()) = getMemIndex(op.value()); + LLVM_DEBUG(writer.append(OpCode::NoOp)); } void Generator::generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer) { @@ -731,6 +828,7 @@ void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) { // Simply repoint the memory index of the result to the constant. getMemIndex(op.result()) = getMemIndex(op.value()); + LLVM_DEBUG(writer.append(OpCode::NoOp)); } void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) { writer.append(OpCode::CreateTypes, op.result(), @@ -796,6 +894,13 @@ writer.append(std::numeric_limits::max()); writer.append(result); } +void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) { + Value operations = op.operations(); + Optional index = op.index(); + writer.append(OpCode::GetUsers, operations, getRangeStorageIndex(operations), + index.getValueOr(std::numeric_limits::max())); + writer.appendPDLValue(op.value()); +} void Generator::generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer) { if (op.getType().isa()) { @@ -811,10 +916,22 @@ ByteCodeWriter &writer) { // InferType maps to a null type as a marker for inferring result types. 
getMemIndex(op.type()) = getMemIndex(Type()); + LLVM_DEBUG(writer.append(OpCode::NoOp)); } void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); } +void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) { + BlockArgument arg = op.getLoopVariable(); + writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg); + writer.appendPDLValueKind(arg.getType()); + writer.append(curLoopLevel, op.successor()); + ++curLoopLevel; + if (curLoopLevel > maxLoopLevel) + maxLoopLevel = curLoopLevel; + generate(&op.region(), writer); + --curLoopLevel; +} void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( @@ -868,8 +985,8 @@ llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - maxTypeRangeCount, maxValueRangeCount, constraintFns, - rewriteFns); + maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, + maxLoopLevel, constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. @@ -883,8 +1000,10 @@ /// bytecode. 
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.opRangeMemory.resize(maxOpRangeCount); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); + state.loopIndex.resize(maxLoopLevel, 0); state.currentPatternBenefits.reserve(patterns.size()); for (const PDLByteCodePattern &pattern : patterns) state.currentPatternBenefits.push_back(pattern.getBenefit()); @@ -899,20 +1018,23 @@ public: ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef> opRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, std::vector> &allocatedValueRangeMemory, - ArrayRef uniquedMemory, ArrayRef code, + MutableArrayRef loopIndex, ArrayRef uniquedMemory, + ArrayRef code, ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), allocatedValueRangeMemory(allocatedValueRangeMemory), - uniquedMemory(uniquedMemory), code(code), + loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), rewriteFunctions(rewriteFunctions) {} @@ -935,10 +1057,13 @@ void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + void executeContinue(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); + void executeFinalize(); + void executeForEach(); void executeGetAttribute(); void 
executeGetAttributeType(); void executeGetDefiningOp(); @@ -946,9 +1071,11 @@ void executeGetOperands(); void executeGetResult(unsigned index); void executeGetResults(); + void executeGetUsers(); void executeGetValueType(); void executeGetValueRangeTypes(); void executeIsNotNull(); + void executeNoOp(); void executeRecordMatch(PatternRewriter &rewriter, SmallVectorImpl &matches); void executeReplaceOp(PatternRewriter &rewriter); @@ -959,6 +1086,16 @@ void executeSwitchType(); void executeSwitchTypes(); + /// Pushes a code iterator to the stack. + void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } + + /// Pops a code iterator from the stack and makes it the current position. + void popCodeIt() { + assert(!resumeCodeIt.empty() && "Attempt to pop code off empty stack"); + curCodeIt = resumeCodeIt.back(); + resumeCodeIt.pop_back(); + } + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -1079,13 +1216,20 @@ /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The stack of bytecode positions at which to resume operation. + SmallVector resumeCodeIt; + /// The current execution memory. MutableArrayRef memory; + MutableArrayRef opRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; std::vector> &allocatedValueRangeMemory; + /// The current loop indices. + MutableArrayRef loopIndex; + /// References to ByteCode data necessary for execution. 
ArrayRef uniquedMemory; ArrayRef code; @@ -1280,6 +1424,14 @@ selectJump(*lhs == rhs.cast().getAsValueRange()); } +void ByteCodeExecutor::executeContinue() { + ByteCodeField level = read(); + LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" + << " * Level: " << level << "\n"); + ++loopIndex[level]; + popCodeIt(); +} + void ByteCodeExecutor::executeCreateTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); unsigned memIndex = read(); @@ -1360,6 +1512,46 @@ rewriter.eraseOp(op); } +void ByteCodeExecutor::executeFinalize() { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); +} + +void ByteCodeExecutor::executeForEach() { + LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); + const ByteCodeField *it = curCodeIt - 1; // Subtract 1 for the op code. + LLVM_DEBUG(it -= 2); // Subtract 2 for the line number. + unsigned rangeIndex = read(); + unsigned memIndex = read(); + const void *value = nullptr; + + switch (read()) { + case PDLValue::Kind::Operation: { + unsigned &index = loopIndex[read()]; + const ArrayRef &array = opRangeMemory[rangeIndex]; + assert(index <= array.size() && "iterated past the end"); + if (index < array.size()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); + value = array[index]; + break; + } + + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + index = 0; + selectJump(size_t(0)); + return; + } + default: + llvm_unreachable("unexpected `ForEach` value kind"); + } + + // Store the iterate value and the stack address. + memory[memIndex] = value; + pushCodeIt(it); + + // Skip over the successor (we will enter the body of the loop). 
+ read(); +} + void ByteCodeExecutor::executeGetAttribute() { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); @@ -1424,7 +1616,7 @@ static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, - MutableArrayRef &valueRangeMemory) { + const MutableArrayRef &valueRangeMemory) { // Check for the sentinel index that signals that all values should be // returned. if (index == std::numeric_limits::max()) { @@ -1512,6 +1704,79 @@ memory[read()] = result; } +void ByteCodeExecutor::executeGetUsers() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); + unsigned memIndex = read(); + unsigned rangeIndex = read(); + unsigned operandNumber = read(); + OpRange &range = opRangeMemory[rangeIndex]; + memory[memIndex] = ⦥ + bool single = false; + + // A single value or a representative value of a range. + Value value; + // A (possibly empty) range of all the values. + ValueRange values; + + // Read the value(s). + if (read() == PDLValue::Kind::Value) { + single = true; + value = read(); + if (value) + values = ValueRange(value); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + } else { + if (auto *value_ptr = read()) + values = *value_ptr; + if (!values.empty()) + value = values.front(); + LLVM_DEBUG({ + llvm::dbgs() << " * Values (" << values.size() << "): "; + llvm::interleaveComma(values, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + } + + // Extract the users. + if (!value) { + // No value or empty range of values given, so no users. + range = OpRange(); + } else if (single && operandNumber == std::numeric_limits::max()) { + // Special case: all users of a single value. + range = OpRange(std::distance(value.user_begin(), value.user_end())); + llvm::copy(value.getUsers(), range.begin()); + } else { + // Default case: users for a specific group or the entire range. 
+ std::vector users; + users.reserve(std::distance(value.user_begin(), value.user_end())); + + // Iterate over all the users of the representative value, extract the + // operands of the specified group, and compare with the value range. + for (Operation *op : value.getUsers()) { + // Extract the operands for the specified group. + ValueRange operands; + void *result = + executeGetOperandsResults( + op->getOperands(), op, operandNumber, 0, "operand_segment_sizes", + MutableArrayRef(operands)); + if (result) { + // Either the extracted operands are the same as the values, or + // we were given a single value and we are accessing a single operand. + if (llvm::equal(operands, values) || + (single && !op->hasTrait() && + !operands.empty() && operands[0] == value)) + users.push_back(op); + } + } + + // Populate the user range. + range = OpRange(users.size()); + llvm::copy(users, range.begin()); + } + + LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); +} + void ByteCodeExecutor::executeGetValueType() { LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); unsigned memIndex = read(); @@ -1553,6 +1818,10 @@ selectJump(value != nullptr); } +void ByteCodeExecutor::executeNoOp() { + LLVM_DEBUG(llvm::dbgs() << "Executing NoOp\n"); +} + void ByteCodeExecutor::executeRecordMatch( PatternRewriter &rewriter, SmallVectorImpl &matches) { @@ -1705,6 +1974,10 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + LLVM_DEBUG({ + if (unsigned line = read()) + llvm::dbgs() << line << ": "; + }); OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: @@ -1734,6 +2007,9 @@ case CheckTypes: executeCheckTypes(); break; + case Continue: + executeContinue(); + break; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; @@ -1744,8 +2020,12 @@ executeEraseOp(rewriter); break; case Finalize: - LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + executeFinalize(); + LLVM_DEBUG(llvm::dbgs() << 
"\n"); return; + case ForEach: + executeForEach(); + break; case GetAttribute: executeGetAttribute(); break; @@ -1787,6 +2067,9 @@ case GetResults: executeGetResults(); break; + case GetUsers: + executeGetUsers(); + break; case GetValueType: executeGetValueType(); break; @@ -1796,6 +2079,9 @@ case IsNotNull: executeIsNotNull(); break; + case NoOp: + executeNoOp(); + break; case RecordMatch: assert(matches && "expected matches to be provided when executing the matcher"); @@ -1837,11 +2123,11 @@ // The matcher function always starts at code address 0. ByteCodeExecutor executor( - matcherByteCode.data(), state.memory, state.typeRangeMemory, - state.allocatedTypeRangeMemory, state.valueRangeMemory, - state.allocatedValueRangeMemory, uniquedData, matcherByteCode, - state.currentPatternBenefits, patterns, constraintFunctions, - rewriteFunctions); + matcherByteCode.data(), state.memory, state.opRangeMemory, + state.typeRangeMemory, state.allocatedTypeRangeMemory, + state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, + uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, + constraintFunctions, rewriteFunctions); executor.execute(rewriter, &matches); // Order the found matches by benefit. 
@@ -1860,8 +2146,9 @@ ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - state.typeRangeMemory, state.allocatedTypeRangeMemory, - state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, + state.opRangeMemory, state.typeRangeMemory, + state.allocatedTypeRangeMemory, state.valueRangeMemory, + state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, constraintFunctions, rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); diff --git a/mlir/test/Dialect/PDLInterp/ops.mlir b/mlir/test/Dialect/PDLInterp/ops.mlir --- a/mlir/test/Dialect/PDLInterp/ops.mlir +++ b/mlir/test/Dialect/PDLInterp/ops.mlir @@ -39,12 +39,18 @@ // ----- -func @users(%input: !pdl.value, %inputs: !pdl.range) { - // single value - %ops1 = pdl_interp.get_users of %input : !pdl.value at 2 +func @users(%value: !pdl.value, %values: !pdl.range) { + // the users of a single value for the given index + %ops1 = pdl_interp.get_users 1 of %value : !pdl.value - // a range of values - %ops2 = pdl_interp.get_users of %inputs : !pdl.range at 3 + // the users of a range of values for the given index + %ops2 = pdl_interp.get_users 2 of %values : !pdl.range + + // all the users of a single value + %ops3 = pdl_interp.get_users of %value : !pdl.value + + // the users of a range of values for all operands + %ops4 = pdl_interp.get_users of %values : !pdl.range pdl_interp.finalize } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -515,6 +515,12 @@ // ----- //===----------------------------------------------------------------------===// +// pdl_interp::ContinueOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. 
+ +//===----------------------------------------------------------------------===// // pdl_interp::CreateAttributeOp //===----------------------------------------------------------------------===// @@ -583,6 +589,204 @@ // Fully tested within the tests for other operations. //===----------------------------------------------------------------------===// +// pdl_interp::ForEachOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val1 = pdl_interp.get_result 0 of %root + %ops1 = pdl_interp.get_users of %val1 : !pdl.value + pdl_interp.foreach %op1 : !pdl.operation in %ops1 { + %val2 = pdl_interp.get_result 0 of %op1 + %ops2 = pdl_interp.get_users of %val2 : !pdl.value + pdl_interp.foreach %op2 : !pdl.operation in %ops2 { + pdl_interp.record_match @rewriters::@success(%op2 : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.foreach +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[ROOT:.*]] = "test.op" +// CHECK: %[[VALA:.*]] = "test.op"(%[[ROOT]]) +// CHECK: %[[VALB:.*]] = "test.op"(%[[ROOT]]) +module @ir attributes { test.foreach } { + %root = "test.op"() : () -> i32 + %valA = "test.op"(%root) : (i32) -> (i32) + "test.op"(%valA) : (i32) -> (i32) + "test.op"(%valA) : (i32) -> (i32) + %valB = "test.op"(%root) : (i32) -> (i32) + "test.op"(%valB) : (i32) -> (i32) + "test.op"(%valB) : (i32) -> (i32) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl_interp::GetUsersOp 
+//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_users of %val : !pdl.value + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_all_users_of_value +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" +module @ir attributes { test.get_all_users_of_value } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> (i32) + "test.op"(%operand, %operand) : (i32, i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + %val = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_users 1 of %val : !pdl.value + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_specific_users_of_value +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" +// CHECK: "test.op"(%[[OPERAND]]) +module @ir attributes { test.get_specific_users_of_value } { + %operand = "test.op"() : () -> i32 + "test.op"(%operand) : (i32) -> (i32) + "test.op"(%operand, %operand) : (i32, i32) -> (i32) + "test.op"(%operand, %operand, %operand) : (i32, i32, i32) 
-> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end + ^next: + %vals = pdl_interp.get_results of %root : !pdl.range + %ops = pdl_interp.get_users of %vals : !pdl.range + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_all_users_of_range +// CHECK: "test.success" +// CHECK: %[[OPERANDS:.*]]:2 = "test.op" +// CHECK: "test.op"(%[[OPERANDS]]#0, %[[OPERANDS]]#0) +module @ir attributes { test.get_all_users_of_range } { + %operands:2 = "test.op"() : () -> (i32, i32) + "test.op"(%operands#0, %operands#1) : (i32, i32) -> (i32) + "test.op"(%operands#0, %operands#0) : (i32, i32) -> (i32) +} + +// ----- + +module @patterns { + func @matcher(%root : !pdl.operation) { + pdl_interp.check_result_count of %root is at_least 2 -> ^next, ^end + ^next: + %vals = pdl_interp.get_results of %root : !pdl.range + %ops = pdl_interp.get_users 1 of %vals : !pdl.range + pdl_interp.foreach %op : !pdl.operation in %ops { + pdl_interp.record_match @rewriters::@success(%op : !pdl.operation) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%matched : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %matched + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_specific_users_of_range +// CHECK: "test.success" +// CHECK: %[[OPERANDS:.*]]:2 = "test.op" +// CHECK: "test.op"(%[[OPERANDS]]#0, %[[OPERANDS]]#1) +module @ir attributes { test.get_specific_users_of_range } 
{ + %operands:2 = "test.op"() : () -> (i32, i32) + "test.op"(%operands#0, %operands#1) : (i32, i32) -> (i32) + "test.op"(%operands#0, %operands#0, %operands#1) : (i32, i32, i32) -> (i32) +} + +// ----- + +//===----------------------------------------------------------------------===// // pdl_interp::GetAttributeOp //===----------------------------------------------------------------------===//