diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -29,6 +29,9 @@ using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; +/// A non-owning range over operations. +using OpRange = llvm::ArrayRef; + //===----------------------------------------------------------------------===// // PDLByteCodePattern //===----------------------------------------------------------------------===// @@ -80,6 +83,13 @@ std::vector memory; /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of operations. + std::vector opRangeMemory; + /// A set of operation ranges that have been allocated by the byte code + /// interpreter to provide a guaranteed lifetime. + std::vector> allocatedOpRangeMemory; + + /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; /// A set of type ranges that have been allocated by the byte code interpreter @@ -171,6 +181,9 @@ /// * Attribute, Identifier, OperationName, Type std::vector uniquedData; + /// A map from address (index into matcherByteCode) to the corresponding block; only populated under LLVM_DEBUG, for debug tracing. + DenseMap addrToBlock; + /// A vector containing the generated bytecode for the matcher. SmallVector matcherByteCode; @@ -188,6 +201,7 @@ ByteCodeField maxValueMemoryIndex = 0; /// The maximum number of different types of ranges. + ByteCodeField maxOpRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; }; diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -67,6 +67,7 @@ /// This method should be called irregardless of whether the match+rewrite was a /// success or not. 
void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { + allocatedOpRangeMemory.clear(); allocatedTypeRangeMemory.clear(); allocatedValueRangeMemory.clear(); } @@ -95,6 +96,8 @@ CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Performs an iterative choice over a range of operations: each execution binds the next operation of the range and pushes itself as a resume point. + ChooseOp, /// Create an operation. CreateOperation, /// Create a range of types. @@ -103,6 +106,8 @@ EraseOp, /// Terminate a matcher or rewrite sequence. Finalize, + /// Gets all operations that use a value at a given operand index and stores them in a range. + GetAcceptingOps, /// Get a specific attribute of an operation. GetAttribute, /// Get the type of an attribute. @@ -158,23 +163,27 @@ // Generator namespace { +struct ByteCodeLiveRange; struct ByteCodeWriter; /// This class represents the main generator for the pattern bytecode. class Generator { public: Generator(MLIRContext *ctx, std::vector &uniquedData, + DenseMap &addrToBlock, SmallVectorImpl &matcherByteCode, SmallVectorImpl &rewriterByteCode, SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, + ByteCodeField &maxOpRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, llvm::StringMap &constraintFns, llvm::StringMap &rewriteFns) - : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode), - rewriterByteCode(rewriterByteCode), patterns(patterns), - maxValueMemoryIndex(maxValueMemoryIndex), + : ctx(ctx), uniquedData(uniquedData), addrToBlock(addrToBlock), + matcherByteCode(matcherByteCode), rewriterByteCode(rewriterByteCode), + patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), + maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) { for (auto it : llvm::enumerate(constraintFns)) @@ -220,6 +229,11 @@ /// and rewriters. 
void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); + /// Extend the lifetime of a value to the end of its parent block and to all blocks reachable from it through successors (blocks holding only a finalize op are skipped). + void extendLifetime(ByteCodeLiveRange &range, Value value); + void extendLifetime(ByteCodeLiveRange &range, Block *block, + DenseSet &visited); + /// Generate the bytecode for the given operation. void generate(Operation *op, ByteCodeWriter &writer); void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer); @@ -232,12 +246,14 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ChooseOpOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer); void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer); void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetAcceptingOpsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); @@ -282,12 +298,17 @@ /// The current MLIR context. MLIRContext *ctx; + /// Mapping from each operation of the matcher function to its index in walk order; live ranges are expressed in these indices. + DenseMap opToIndex; + /// Data of the ByteCode class to be populated. 
std::vector &uniquedData; + DenseMap &addrToBlock; SmallVectorImpl &matcherByteCode; SmallVectorImpl &rewriterByteCode; SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; + ByteCodeField &maxOpRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; }; @@ -396,25 +417,35 @@ /// This class represents a live range of PDL Interpreter values, containing /// information about when values are live within a match/rewrite. struct ByteCodeLiveRange { - using Set = llvm::IntervalMap; + using Set = llvm::IntervalMap; using Allocator = Set::Allocator; - ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {} + ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {} /// Union this live range with the one provided. void unionWith(const ByteCodeLiveRange &rhs) { - for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it) - liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0); + for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e; + ++it) + liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0); } /// Returns true if this range overlaps with the one provided. bool overlaps(const ByteCodeLiveRange &rhs) const { - return llvm::IntervalMapOverlaps(liveness, rhs.liveness).valid(); + return llvm::IntervalMapOverlaps(*liveness, *rhs.liveness) + .valid(); } /// A map representing the ranges of the match/rewrite that a value is live in /// the interpreter. - llvm::IntervalMap liveness; + /// + /// We use std::unique_ptr here, because the move constructor of IntervalMap + /// appears to have a bug that shows up when it is included here directly. + /// Specifically, the start bound of the interval gets populated with a wrong + /// value; this only shows up in release builds and is hard to reproduce. + std::unique_ptr> liveness; + + /// The operation range storage index for this range. + Optional opRangeIndex; /// The type range storage index for this range. 
Optional typeRangeIndex; @@ -452,8 +483,10 @@ for (Block *block : rpot) { // Keep track of where this block begins within the matcher function. blockToAddr.try_emplace(block, matcherByteCode.size()); - for (Operation &op : *block) + for (Operation &op : *block) { + LLVM_DEBUG(addrToBlock.try_emplace(matcherByteCode.size(), block)); generate(&op, matcherByteCodeWriter); + } } // Resolve successor references in the matcher. @@ -501,8 +534,11 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToIndex; + SmallVector> iterativeIndices; matcherFunc.getBody().walk([&](Operation *op) { + // Record the walk index of each iterative operation for the lifetime extension below. + if (isa(op)) + iterativeIndices.emplace_back(op, opToIndex.size()); opToIndex.insert(std::make_pair(op, opToIndex.size())); }); @@ -527,7 +563,7 @@ // Set indices for the range of this block that the value is used. auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first; - defRangeIt->second.liveness.insert( + defRangeIt->second.liveness->insert( opToIndex[firstUseOrDef], opToIndex[info->getEndOperation(value, firstUseOrDef)], /*dummyValue*/ 0); @@ -535,7 +571,9 @@ // Check to see if this value is a range type. if (auto rangeTy = value.getType().dyn_cast()) { Type eleType = rangeTy.getElementType(); - if (eleType.isa()) + if (eleType.isa()) + defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.valueRangeIndex = 0; @@ -552,9 +590,40 @@ processValue(result, &op); } + // For the values that are alive at the time an iterative operation is + // performed, extend their lifetime to the end of their defining block and all blocks reachable from it: execution later resumes at the iterative operation, so their memory slots must not be reused in the meantime. 
+ bool first = true; + for (auto &valueDefRange : valueDefRanges) { + Value value = valueDefRange.first; + ByteCodeLiveRange &range = valueDefRange.second; + unsigned valueIndex = opToIndex[value.getDefiningOp()]; + for (auto p : iterativeIndices) { + Operation *op = p.first; + unsigned index = p.second; + if (index != valueIndex && range.liveness->overlaps(index, index)) { + if (first) { + LLVM_DEBUG(llvm::dbgs() << "Extending lifetime:\n"); + first = false; + } + LLVM_DEBUG(llvm::dbgs() + << " * Value: " << valueIndex << " " << value + << "\n * Overlaps: " << index << " " << *op << "\n"); + extendLifetime(range, value); + break; + } + } + } + // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; - ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0; + + // The number of memory indices currently allocated (and its next value). + // Recall that the root gets allocated memory index 0. + ByteCodeField numIndices = 1; + + // The number of memory ranges of various types (and their next values). + ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; ByteCodeLiveRange &defRange = defIt.second; @@ -566,7 +635,11 @@ existingRange.unionWith(defRange); memIndex = existingIndexIt.index() + 1; - if (defRange.typeRangeIndex) { + if (defRange.opRangeIndex) { + if (!existingRange.opRangeIndex) + existingRange.opRangeIndex = numOpRanges++; + valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex; @@ -585,8 +658,11 @@ ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for type/value ranges. 
- if (defRange.typeRangeIndex) { + // Allocate an index for op/type/value ranges. + if (defRange.opRangeIndex) { + newRange.opRangeIndex = numOpRanges; + valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; } else if (defRange.valueRangeIndex) { @@ -599,15 +675,81 @@ } } + // Ensure that we did not run out of index space. + LLVM_DEBUG({ + llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ").\n"; + }); + assert(allocatedIndices.size() < std::numeric_limits::max() && + "Ran out of memory for allocated indices"); + // Update the max number of indices. if (numIndices > maxValueMemoryIndex) maxValueMemoryIndex = numIndices; + if (numOpRanges > maxOpRangeMemoryIndex) + maxOpRangeMemoryIndex = numOpRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::extendLifetime(ByteCodeLiveRange &range, Value value) { + DenseSet visited; + Block *block = value.getParentBlock(); + visited.insert(block); + + // Determine the endpoints of the value in its parent block: from its defining op (so `value` is assumed to be an op result) to the block's last op. + unsigned start = opToIndex[value.getDefiningOp()]; + unsigned stop = opToIndex[&block->back()]; + + // Start is in the range. Check for the rare case when the value is not alive + // at the end of its parent block. TODO: is this even possible? 
+ auto it = range.liveness->find(start); + assert(it.valid() && "Value not found in live range"); + if (stop > it.stop()) { + range.liveness->insert(it.stop() + 1, stop, /*dummyValue*/ 0); + } + + // Descend to the successor blocks + for (Block *successor : block->getSuccessors()) + extendLifetime(range, successor, visited); +} + +void Generator::extendLifetime(ByteCodeLiveRange &range, Block *block, + DenseSet &visited) { + // Mark the block as visited; return if previously visited. + bool newlyVisited = visited.insert(block).second; + if (!newlyVisited) + return; + + // Skip if the block contains the sole finalize operation + if (block->getOperations().size() == 1 && + isa(block->front())) + return; + + // Extend the liveness to this entire block. NOTE(review): if the interval found by find(start) begins after `start`, the prefix [start, it.start() - 1] is never inserted; IntervalMap::insert is documented to coalesce overlapping intervals that carry the same value, so inserting [start, stop] directly may be simpler and complete -- confirm. + unsigned start = opToIndex[&block->front()]; + unsigned stop = opToIndex[&block->back()]; + auto it = range.liveness->find(start); + if (!it.valid() || stop < it.start()) + // No overlap => insert the entire range. + ; + else if (stop > it.stop()) + // Extend past the original liveness. + start = it.stop() + 1; + else + // Already covering the interval + start = stop + 1; + + if (start <= stop) { + range.liveness->insert(start, stop, /*dummyValue*/ 0); + } + + for (Block *successor : block->getSuccessors()) + extendLifetime(range, successor, visited); +} + void Generator::generate(Operation *op, ByteCodeWriter &writer) { TypeSwitch(op) .Case constraintFns, llvm::StringMap rewriteFns) { - Generator generator(module.getContext(), uniquedData, matcherByteCode, - rewriterByteCode, patterns, maxValueMemoryIndex, - maxTypeRangeCount, maxValueRangeCount, constraintFns, - rewriteFns); + Generator generator(module.getContext(), uniquedData, addrToBlock, + matcherByteCode, rewriterByteCode, patterns, + maxValueMemoryIndex, maxOpRangeCount, maxTypeRangeCount, + maxValueRangeCount, constraintFns, rewriteFns); generator.generate(module); // Initialize the external functions. @@ -883,6 +1036,7 @@ /// bytecode. 
void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); + state.opRangeMemory.resize(maxOpRangeCount, OpRange()); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.currentPatternBenefits.reserve(patterns.size()); @@ -899,6 +1053,8 @@ public: ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, + MutableArrayRef opRangeMemory, + std::vector> &allocatedOpRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, @@ -907,15 +1063,18 @@ ArrayRef currentPatternBenefits, ArrayRef patterns, ArrayRef constraintFunctions, - ArrayRef rewriteFunctions) - : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory), + ArrayRef rewriteFunctions, + const DenseMap *addrToBlock = nullptr) + : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + allocatedOpRangeMemory(allocatedOpRangeMemory), + typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), allocatedValueRangeMemory(allocatedValueRangeMemory), uniquedMemory(uniquedMemory), code(code), currentPatternBenefits(currentPatternBenefits), patterns(patterns), constraintFunctions(constraintFunctions), - rewriteFunctions(rewriteFunctions) {} + rewriteFunctions(rewriteFunctions), addrToBlock(addrToBlock) {} /// Start executing the code at the current bytecode index. 
`matches` is an /// optional field provided when this function is executed in a matching @@ -935,10 +1094,13 @@ void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + bool executeChooseOp(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); void executeCreateTypes(); void executeEraseOp(PatternRewriter &rewriter); + bool executeFinalize(); + void executeGetAcceptingOps(); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); @@ -959,6 +1121,19 @@ void executeSwitchType(); void executeSwitchTypes(); + /// Pushes a bytecode position onto the resume stack; popCodeIt() later continues execution there. + void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); } + + /// Pops the most recent resume position into curCodeIt, returning true on success; returns false (curCodeIt unchanged) if the stack is empty. + bool popCodeIt() { + if (resumeCodeIt.empty()) + return false; + + curCodeIt = resumeCodeIt.back(); + resumeCodeIt.pop_back(); + return true; + } + /// Read a value from the bytecode buffer, optionally skipping a certain /// number of prefix values. These methods always update the buffer to point /// to the next field after the read data. @@ -1079,8 +1254,13 @@ /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The stack of bytecode positions at which to resume operation (pushed by ChooseOp; popped by Finalize and by an exhausted ChooseOp). + SmallVector resumeCodeIt; + /// The current execution memory. MutableArrayRef memory; + MutableArrayRef opRangeMemory; + std::vector> &allocatedOpRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; @@ -1093,6 +1273,9 @@ ArrayRef patterns; ArrayRef constraintFunctions; ArrayRef rewriteFunctions; + + /// Optional map from code index to the corresponding block; may be null and is only consulted for debug output. 
+ const DenseMap *addrToBlock; }; /// This class is an instantiation of the PDLResultList that provides access to @@ -1280,6 +1463,26 @@ selectJump(*lhs == rhs.cast().getAsValueRange()); } +bool ByteCodeExecutor::executeChooseOp() { + LLVM_DEBUG(llvm::dbgs() << "Executing ChooseOp:\n"); + const ByteCodeField *it = curCodeIt - 1; // Subtract 1 for the op code so that resuming re-executes this ChooseOp + unsigned rangeIndex = read(); + unsigned memIndex = read(); + OpRange &range = opRangeMemory[rangeIndex]; + if (range.empty()) { + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + return popCodeIt(); + } + + Operation *op = range.front(); + // Consume the chosen op from the stored range in place so that the next execution of this ChooseOp yields the following op. TODO: confirm nothing else reads this range slot expecting the full range. + range = range.drop_front(); + LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); + memory[memIndex] = op; + pushCodeIt(it); + return true; +} + void ByteCodeExecutor::executeCreateTypes() { LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n"); unsigned memIndex = read(); @@ -1338,6 +1541,21 @@ break; } + // Determine the insertion point: just after the last op in the current block that defines one of the operands. + // If no operand is defined in this block, default to the beginning of the block. 
+ Block *block = rewriter.getBlock(); + rewriter.setInsertionPointToStart(block); + DenseSet operationSet; + for (Value value : state.operands) + operationSet.insert(value.getDefiningOp()); + for (auto it = block->rbegin(); it != block->rend(); ++it) + if (operationSet.count(&*it)) { + LLVM_DEBUG(llvm::dbgs() + << "Inserting " << state.name << " after " << *it << "\n"); + rewriter.setInsertionPointAfter(&*it); + break; + } + Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; @@ -1360,6 +1578,47 @@ rewriter.eraseOp(op); } +bool ByteCodeExecutor::executeFinalize() { + LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); + return popCodeIt(); +} + +void ByteCodeExecutor::executeGetAcceptingOps() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetAcceptingOps:\n"); + unsigned rangeIndex = read(); + unsigned operandNumber = read(); + LLVM_DEBUG(llvm::dbgs() << " * Operand: " << operandNumber << "\n"); + + // Read the value; for a value range, only its first value is used. + Value value; + if (read() == PDLValue::Kind::Value) { + value = read(); + LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + } else { + ValueRange *values = read(); + if (values && !values->empty()) + value = values->front(); + LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); + } + + // Determine all uses at the specified operand number. + SmallVector users; + if (value) + for (OpOperand &use : value.getUses()) + if (use.getOperandNumber() == operandNumber) + users.push_back(use.getOwner()); + LLVM_DEBUG(llvm::dbgs() << " * Result: " << users.size() << " operations\n"); + + // Allocate owned storage for the operation range; it stays alive until cleanupAfterMatchAndRewrite() clears allocatedOpRangeMemory. + llvm::OwningArrayRef storage(users.size()); + llvm::copy(users, storage.begin()); + allocatedOpRangeMemory.emplace_back(std::move(storage)); + + // Assign this to the range slot. + // TODO: do we need to store this in memory[] as well? 
+ opRangeMemory[rangeIndex] = allocatedOpRangeMemory.back(); + } + void ByteCodeExecutor::executeGetAttribute() { LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); unsigned memIndex = read(); @@ -1705,6 +1964,14 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + LLVM_DEBUG({ + if (addrToBlock) { + Block *block = addrToBlock->lookup(curCodeIt - code.data()); + assert(block && "Unknown address"); + block->printAsOperand(llvm::dbgs()); + llvm::dbgs() << " "; + } + }); OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: @@ -1734,6 +2001,11 @@ case CheckTypes: executeCheckTypes(); break; + case ChooseOp: + if (executeChooseOp()) + break; + LLVM_DEBUG(llvm::dbgs() << "\n"); + return; case CreateOperation: executeCreateOperation(rewriter, *mainRewriteLoc); break; @@ -1744,8 +2016,13 @@ executeEraseOp(rewriter); break; case Finalize: - LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n"); + if (executeFinalize()) + break; + LLVM_DEBUG(llvm::dbgs() << "\n"); + return; + case GetAcceptingOps: + executeGetAcceptingOps(); + break; case GetAttribute: executeGetAttribute(); break; @@ -1837,11 +2114,12 @@ // The matcher function always starts at code address 0. ByteCodeExecutor executor( - matcherByteCode.data(), state.memory, state.typeRangeMemory, + matcherByteCode.data(), state.memory, state.opRangeMemory, + state.allocatedOpRangeMemory, state.typeRangeMemory, state.allocatedTypeRangeMemory, state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, constraintFunctions, - rewriteFunctions); + rewriteFunctions, &addrToBlock); executor.execute(rewriter, &matches); // Order the found matches by benefit. 
@@ -1860,9 +2138,10 @@ ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - state.typeRangeMemory, state.allocatedTypeRangeMemory, - state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData, - rewriterByteCode, state.currentPatternBenefits, patterns, - constraintFunctions, rewriteFunctions); + state.opRangeMemory, state.allocatedOpRangeMemory, state.typeRangeMemory, + state.allocatedTypeRangeMemory, state.valueRangeMemory, + state.allocatedValueRangeMemory, uniquedData, rewriterByteCode, + state.currentPatternBenefits, patterns, constraintFunctions, + rewriteFunctions); executor.execute(rewriter, /*matches=*/nullptr, match.location); } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -255,8 +255,8 @@ } // CHECK-LABEL: test.are_equal_2 -// CHECK: "test.not_equal" // CHECK: "test.success" +// CHECK: "test.not_equal" // CHECK-NOT: "test.op" module @ir attributes { test.are_equal_2 } { "test.not_equal"() : () -> (i32) @@ -361,8 +361,8 @@ } // CHECK-LABEL: test.check_operand_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success" +// CHECK: "test.op"() : () -> i32 module @ir attributes { test.check_operand_count_1 } { %operand = "test.op"() : () -> i32 "test.op"(%operand, %operand) : (i32, i32) -> () @@ -430,8 +430,8 @@ } // CHECK-LABEL: test.check_result_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success"() : () -> () +// CHECK: "test.op"() : () -> i32 // CHECK-NOT: "test.op"() : () -> (i32, i32) module @ir attributes { test.check_result_count_1 } { "test.op"() : () -> i32 @@ -504,8 +504,8 @@ } // CHECK-LABEL: test.check_types_1 -// CHECK: "test.op"() : () -> (i32, i64) // CHECK: "test.success" +// CHECK: "test.op"() : () -> (i32, i64) // CHECK-NOT: "test.op"() : () -> i32 module @ir attributes { test.check_types_1 } { "test.op"() : () -> (i32, i64) @@ -515,6 
+515,12 @@ // ----- //===----------------------------------------------------------------------===// +// pdl_interp::ChooseOpOp +//===----------------------------------------------------------------------===// + +// Tested together with pdl_interp::GetAcceptingOpsOp (see the test below). + +//===----------------------------------------------------------------------===// // pdl_interp::CreateAttributeOp //===----------------------------------------------------------------------===// @@ -583,6 +589,55 @@ // Fully tested within the tests for other operations. //===----------------------------------------------------------------------===// +// pdl_interp::GetAcceptingOpsOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%root : !pdl.operation) { + %result0 = pdl_interp.get_result 0 of %root + %ops = pdl_interp.get_accepting_ops of %result0 : !pdl.value at 1 + %op = pdl_interp.choose_op from %ops + + pdl_interp.check_operand_count of %op is 2 -> ^pat1, ^end + + ^pat1: + pdl_interp.record_match @rewriters::@success(%result0, %op : !pdl.value, !pdl.operation) : benefit(1), loc([%root]) -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + func @success(%result0 :!pdl.value, %matched : !pdl.operation) { + %type = pdl_interp.create_type i32 + %rep = pdl_interp.create_operation "test.op"(%result0 : !pdl.value) -> (%type : !pdl.type) + %val = pdl_interp.get_result 0 of %rep + pdl_interp.replace %matched with (%val : !pdl.value) + %op = pdl_interp.create_operation "test.success" + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.get_accepting_ops +// CHECK: "test.success" +// CHECK: "test.success" +// CHECK: %[[OPERAND0:.*]] = "test.op" +// CHECK: "test.op"(%[[OPERAND0]]) +// CHECK: %[[OPERAND1:.*]] = "test.op" +// CHECK: "test.op"(%[[OPERAND1]]) +// CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND1]], %[[OPERAND1]]) +module @ir attributes { test.get_accepting_ops } { + %operand1 = "test.op"() : () -> i32 + 
%operand2 = "test.op"() : () -> i32 + "test.op"(%operand1, %operand2) : (i32, i32) -> (i32) + "test.op"(%operand1, %operand2, %operand2) : (i32, i32, i32) -> (i32) + "test.op"(%operand2, %operand1) : (i32, i32) -> (i32) +} + +// ----- + +//===----------------------------------------------------------------------===// // pdl_interp::GetAttributeOp //===----------------------------------------------------------------------===// @@ -626,9 +681,9 @@ } // CHECK-LABEL: test.get_defining_op_1 +// CHECK: "test.success" // CHECK: %[[OPERAND0:.*]] = "test.op" // CHECK: %[[OPERAND1:.*]] = "test.op" -// CHECK: "test.success" // CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]]) module @ir attributes { test.get_defining_op_1 } { %operand = "test.op"() : () -> i32 @@ -737,11 +792,11 @@ } // CHECK-LABEL: test.get_operands_2 -// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"() : () -> () -// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () +// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () +// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () module @ir attributes { test.get_operands_2 } { %inputs:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) "test.attr_sized_operands"(%inputs#0, %inputs#1, %inputs#2, %inputs#3, %inputs#4) {operand_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : (i32, i32, i32, i32, i32) -> () @@ -891,10 +946,10 @@ } // CHECK-LABEL: test.get_results_2 -// CHECK: "test.success"() : () -> () -// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) -// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 // CHECK: %[[RESULTS_2_SINGLE:.*]] 
= "test.success"() : () -> i32 +// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 +// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) +// CHECK: "test.success"() : () -> () // CHECK: "test.consumer"(%[[RESULTS_1]]#0, %[[RESULTS_1]]#1, %[[RESULTS_1]]#2, %[[RESULTS_1]]#3, %[[RESULTS_2]]) : (i32, i32, i32, i32, i32) -> () module @ir attributes { test.get_results_2 } { %results:5 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : () -> (i32, i32, i32, i32, i32)