diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -91,10 +91,10 @@ LogicalResult RangeType::verify(function_ref emitError, Type elementType) { - if (!elementType.isa() || elementType.isa()) { + if (!elementType.isa()) { return emitError() << "expected element of pdl.range to be one of [!pdl.attribute, " - "!pdl.operation, !pdl.type, !pdl.value], but got " + "!pdl.operation, !pdl.type, !pdl.value, !pdl.range], but got " << elementType; } return success(); diff --git a/mlir/lib/Rewrite/ByteCode.h b/mlir/lib/Rewrite/ByteCode.h --- a/mlir/lib/Rewrite/ByteCode.h +++ b/mlir/lib/Rewrite/ByteCode.h @@ -29,6 +29,7 @@ using ByteCodeField = uint16_t; using ByteCodeAddr = uint32_t; using OwningOpRange = llvm::OwningArrayRef; +using RangeRange = llvm::ArrayRef; //===----------------------------------------------------------------------===// // PDLByteCodePattern @@ -86,6 +87,13 @@ /// we get an indexed range (view) of operations. std::vector opRangeMemory; + /// A mutable block of memory used during the matching and rewriting phase of + /// the bytecode to store ranges of value range. + std::vector rangeRangeMemory; + /// A set of value range ranges that have been allocated by the byte code + /// interpreter to provide a guaranteed lifetime. + std::vector> allocatedRangeRangeMemory; + /// A mutable block of memory used during the matching and rewriting phase of /// the bytecode to store ranges of types. std::vector typeRangeMemory; @@ -201,6 +209,7 @@ /// The maximum number of different types of ranges. ByteCodeField maxOpRangeCount = 0; + ByteCodeField maxRangeRangeCount = 0; ByteCodeField maxTypeRangeCount = 0; ByteCodeField maxValueRangeCount = 0; diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -67,6 +67,7 @@ /// This method should be called irregardless of whether the match+rewrite was a /// success or not. void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() { + allocatedRangeRangeMemory.clear(); allocatedTypeRangeMemory.clear(); allocatedValueRangeMemory.clear(); } @@ -95,6 +96,8 @@ CheckResultCount, /// Compare a range of types to a constant range of types. CheckTypes, + /// Performs a nondeterministic choice over a range of orderings. + ChooseRange, /// Continue to the next iteration of a loop. Continue, /// Create an operation. @@ -119,6 +122,8 @@ GetAttributeType, /// Get the defining operation of a value. GetDefiningOp, + /// Get a specific value from a list of values. + GetItem, /// Get a specific operand of an operation. GetOperand0, GetOperand1, @@ -127,6 +132,8 @@ GetOperandN, /// Get a specific operand group of an operation. GetOperands, + /// Get a permutation of a list of values. + GetPermutations, /// Get a specific result of an operation. GetResult0, GetResult1, @@ -141,6 +148,8 @@ GetValueType, /// Get the types of a value range. GetValueRangeTypes, + /// Check if a given operation is commutative. + IsCommutative, /// Check if a generic value is not null. IsNotNull, /// Record a successful pattern match. @@ -186,6 +195,7 @@ SmallVectorImpl &patterns, ByteCodeField &maxValueMemoryIndex, ByteCodeField &maxOpRangeMemoryIndex, + ByteCodeField &maxRangeRangeMemoryIndex, ByteCodeField &maxTypeRangeMemoryIndex, ByteCodeField &maxValueRangeMemoryIndex, ByteCodeField &maxLoopLevel, @@ -195,6 +205,7 @@ rewriterByteCode(rewriterByteCode), patterns(patterns), maxValueMemoryIndex(maxValueMemoryIndex), maxOpRangeMemoryIndex(maxOpRangeMemoryIndex), + maxRangeRangeMemoryIndex(maxRangeRangeMemoryIndex), maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex), maxValueRangeMemoryIndex(maxValueRangeMemoryIndex), maxLoopLevel(maxLoopLevel) { @@ -241,6 +252,11 @@ /// and rewriters. void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule); + /// Extend the lifetime of a value to its block and block descendants. + void extendLifetime(ByteCodeLiveRange &range, Value value); + void extendLifetime(ByteCodeLiveRange &range, Block *block, + DenseSet &visited); + /// Generate the bytecode for the given operation. void generate(Region *region, ByteCodeWriter &writer); void generate(Operation *op, ByteCodeWriter &writer); @@ -254,6 +270,7 @@ void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::ChooseRangeOp op, ByteCodeWriter &writer); void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer); @@ -266,13 +283,16 @@ void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetItemOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer); + void generate(pdl_interp::GetPermutationsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer); void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer); void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer); + void generate(pdl_interp::IsCommutativeOp op, ByteCodeWriter &writer); void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer); void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer); void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer); @@ -311,6 +331,10 @@ /// The current MLIR context. MLIRContext *ctx; + /// Mapping from operation to its first and last indices. + DenseMap opToFirstIndex; + DenseMap opToLastIndex; + /// Mapping from block to its address. DenseMap blockToAddr; @@ -321,6 +345,7 @@ SmallVectorImpl &patterns; ByteCodeField &maxValueMemoryIndex; ByteCodeField &maxOpRangeMemoryIndex; + ByteCodeField &maxRangeRangeMemoryIndex; ByteCodeField &maxTypeRangeMemoryIndex; ByteCodeField &maxValueRangeMemoryIndex; ByteCodeField &maxLoopLevel; @@ -473,6 +498,9 @@ /// The operation range storage index for this range. Optional opRangeIndex; + /// The range range storage index for this range. + Optional rangeRangeIndex; + /// The type range storage index for this range. Optional typeRangeIndex; @@ -551,14 +579,16 @@ // finding the minimal number of overlapping live ranges. This is essentially // a simplified form of register allocation where we don't necessarily have a // limited number of registers, but we still want to minimize the number used. - DenseMap opToFirstIndex; - DenseMap opToLastIndex; // A custom walk that marks the first and the last index of each operation. // The entry marks the beginning of the liveness range for this operation, // followed by nested operations, followed by the end of the liveness range. unsigned index = 0; + SmallVector> nondeterministicIndices; llvm::unique_function walk = [&](Operation *op) { + // Extract the indices of nondeterministic operations. + if (isa(op)) + nondeterministicIndices.emplace_back(op, index); opToFirstIndex.try_emplace(op, index++); for (Region ®ion : op->getRegions()) for (Block &block : region.getBlocks()) @@ -599,6 +629,8 @@ Type eleType = rangeTy.getElementType(); if (eleType.isa()) defRangeIt->second.opRangeIndex = 0; + else if (eleType.isa()) + defRangeIt->second.rangeRangeIndex = 0; else if (eleType.isa()) defRangeIt->second.typeRangeIndex = 0; else if (eleType.isa()) @@ -628,6 +660,30 @@ processValue(result, &op); }); + // For the values that are alive at the time a nondeterministic operation is + // performed, extend their lifetime until the end. + bool first = true; + for (auto &valueDefRange : valueDefRanges) { + Value value = valueDefRange.first; + ByteCodeLiveRange &range = valueDefRange.second; + unsigned valueIndex = opToFirstIndex[value.getDefiningOp()]; + for (auto p : nondeterministicIndices) { + Operation *op = p.first; + unsigned index = p.second; + if (index != valueIndex && range.liveness->overlaps(index, index)) { + if (first) { + LLVM_DEBUG(llvm::dbgs() << "Extending lifetime:\n"); + first = false; + } + LLVM_DEBUG(llvm::dbgs() + << " * Value: " << valueIndex << " " << value + << "\n * Overlaps: " << index << " " << *op << "\n"); + extendLifetime(range, value); + break; + } + } + } + // Greedily allocate memory slots using the computed def live ranges. std::vector allocatedIndices; @@ -636,7 +692,8 @@ ByteCodeField numIndices = 1; // The number of memory ranges of various types (and their next values). - ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0; + ByteCodeField numOpRanges = 0, numRangeRanges = 0, numTypeRanges = 0, + numValueRanges = 0; for (auto &defIt : valueDefRanges) { ByteCodeField &memIndex = valueToMemIndex[defIt.first]; @@ -653,6 +710,10 @@ if (!existingRange.opRangeIndex) existingRange.opRangeIndex = numOpRanges++; valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex; + } else if (defRange.rangeRangeIndex) { + if (!existingRange.rangeRangeIndex) + existingRange.rangeRangeIndex = numRangeRanges++; + valueToRangeIndex[defIt.first] = *existingRange.rangeRangeIndex; } else if (defRange.typeRangeIndex) { if (!existingRange.typeRangeIndex) existingRange.typeRangeIndex = numTypeRanges++; @@ -672,10 +733,13 @@ ByteCodeLiveRange &newRange = allocatedIndices.back(); newRange.unionWith(defRange); - // Allocate an index for op/type/value ranges. + // Allocate an index for op/value range/type/value ranges. if (defRange.opRangeIndex) { newRange.opRangeIndex = numOpRanges; valueToRangeIndex[defIt.first] = numOpRanges++; + } else if (defRange.rangeRangeIndex) { + newRange.rangeRangeIndex = numRangeRanges; + valueToRangeIndex[defIt.first] = numRangeRanges++; } else if (defRange.typeRangeIndex) { newRange.typeRangeIndex = numTypeRanges; valueToRangeIndex[defIt.first] = numTypeRanges++; @@ -702,12 +766,62 @@ maxValueMemoryIndex = numIndices; if (numOpRanges > maxOpRangeMemoryIndex) maxOpRangeMemoryIndex = numOpRanges; + if (numRangeRanges > maxRangeRangeMemoryIndex) + maxRangeRangeMemoryIndex = numRangeRanges; if (numTypeRanges > maxTypeRangeMemoryIndex) maxTypeRangeMemoryIndex = numTypeRanges; if (numValueRanges > maxValueRangeMemoryIndex) maxValueRangeMemoryIndex = numValueRanges; } +void Generator::extendLifetime(ByteCodeLiveRange &range, Value value) { + DenseSet visited; + Block *block = value.getParentBlock(); + visited.insert(block); + // Determine the endpoints of the value in its parent block + unsigned start = opToFirstIndex[value.getDefiningOp()]; + unsigned stop = opToFirstIndex[&block->back()]; + // Start is in the range. Check for the rare case when the value is not alive + // at the end of its parent block. TODO: is this even possible? + auto it = range.liveness->find(start); + assert(it.valid() && "Value not found in live range"); + if (stop > it.stop()) { + range.liveness->insert(it.stop() + 1, stop, /*dummyValue*/ 0); + } + // Descend to the successor blocks + for (Block *successor : block->getSuccessors()) + extendLifetime(range, successor, visited); +} +void Generator::extendLifetime(ByteCodeLiveRange &range, Block *block, + DenseSet &visited) { + // Mark the block as visited; return if previously visited. + bool newlyVisited = visited.insert(block).second; + if (!newlyVisited) + return; + // Skip if the block contains the sole finalize operation + if (block->getOperations().size() == 1 && + isa(block->front())) + return; + // Extend the liveness to this entire block + unsigned start = opToFirstIndex[&block->front()]; + unsigned stop = opToFirstIndex[&block->back()]; + auto it = range.liveness->find(start); + if (!it.valid() || stop < it.start()) + // No overlap => insert the entire range. + ; + else if (stop > it.stop()) + // Extend past the original liveness. + start = it.stop() + 1; + else + // Already covering the interval + start = stop + 1; + if (start <= stop) { + range.liveness->insert(start, stop, /*dummyValue*/ 0); + } + for (Block *successor : block->getSuccessors()) + extendLifetime(range, successor, visited); +} + void Generator::generate(Region *region, ByteCodeWriter &writer) { llvm::ReversePostOrderTraversal rpot(region); for (Block *block : rpot) { @@ -732,20 +846,22 @@ pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp, pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp, pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp, - pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp, - pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp, - pdl_interp::CreateTypesOp, pdl_interp::EraseOp, - pdl_interp::ExtractOp, pdl_interp::FinalizeOp, + pdl_interp::ChooseRangeOp, pdl_interp::ContinueOp, + pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp, + pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp, + pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp, pdl_interp::ForEachOp, pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp, - pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp, + pdl_interp::GetItemOp, pdl_interp::GetOperandOp, + pdl_interp::GetOperandsOp, pdl_interp::GetPermutationsOp, pdl_interp::GetResultOp, pdl_interp::GetResultsOp, pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp, - pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp, - pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp, - pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp, - pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp, - pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>( + pdl_interp::InferredTypesOp, pdl_interp::IsCommutativeOp, + pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp, + pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp, + pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp, + pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp, + pdl_interp::SwitchResultCountOp>( [&](auto interpOp) { this->generate(interpOp, writer); }) .Default([](Operation *) { llvm_unreachable("unknown `pdl_interp` operation"); @@ -826,6 +942,10 @@ void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) { writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors()); } +void Generator::generate(pdl_interp::ChooseRangeOp op, ByteCodeWriter &writer) { + writer.append(OpCode::ChooseRange, getRangeStorageIndex(op.list()), + getRangeStorageIndex(op.range())); +} void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) { assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level"); writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1)); @@ -898,6 +1018,11 @@ writer.append(OpCode::GetDefiningOp, op.operation()); writer.appendPDLValue(op.value()); } +void Generator::generate(pdl_interp::GetItemOp op, ByteCodeWriter &writer) { + uint32_t index = op.index(); + writer.append(OpCode::GetItem, index, getRangeStorageIndex(op.values())); + writer.append(op.value()); +} void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) @@ -918,6 +1043,11 @@ writer.append(std::numeric_limits::max()); writer.append(result); } +void Generator::generate(pdl_interp::GetPermutationsOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::GetPermutations, getRangeStorageIndex(op.list()), + getRangeStorageIndex(op.permutations())); +} void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) { uint32_t index = op.index(); if (index < 4) @@ -960,6 +1090,10 @@ // InferType maps to a null type as a marker for inferring result types. getMemIndex(op.type()) = getMemIndex(Type()); } +void Generator::generate(pdl_interp::IsCommutativeOp op, + ByteCodeWriter &writer) { + writer.append(OpCode::IsCommutative, op.op(), op.getSuccessors()); +} void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) { writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors()); } @@ -1016,8 +1150,9 @@ llvm::StringMap rewriteFns) { Generator generator(module.getContext(), uniquedData, matcherByteCode, rewriterByteCode, patterns, maxValueMemoryIndex, - maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount, - maxLoopLevel, constraintFns, rewriteFns); + maxOpRangeCount, maxRangeRangeCount, maxTypeRangeCount, + maxValueRangeCount, maxLoopLevel, constraintFns, + rewriteFns); generator.generate(module); // Initialize the external functions. @@ -1032,6 +1167,7 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const { state.memory.resize(maxValueMemoryIndex, nullptr); state.opRangeMemory.resize(maxOpRangeCount); + state.rangeRangeMemory.resize(maxRangeRangeCount, RangeRange()); state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange()); state.valueRangeMemory.resize(maxValueRangeCount, ValueRange()); state.loopIndex.resize(maxLoopLevel, 0); @@ -1050,6 +1186,8 @@ ByteCodeExecutor( const ByteCodeField *curCodeIt, MutableArrayRef memory, MutableArrayRef> opRangeMemory, + MutableArrayRef rangeRangeMemory, + std::vector> &allocatedRangeRangeMemory, MutableArrayRef typeRangeMemory, std::vector> &allocatedTypeRangeMemory, MutableArrayRef valueRangeMemory, @@ -1061,6 +1199,8 @@ ArrayRef constraintFunctions, ArrayRef rewriteFunctions) : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory), + rangeRangeMemory(rangeRangeMemory), + allocatedRangeRangeMemory(allocatedRangeRangeMemory), typeRangeMemory(typeRangeMemory), allocatedTypeRangeMemory(allocatedTypeRangeMemory), valueRangeMemory(valueRangeMemory), @@ -1088,6 +1228,7 @@ void executeCheckOperationName(); void executeCheckResultCount(); void executeCheckTypes(); + bool executeChooseRange(); void executeContinue(); void executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc); @@ -1095,18 +1236,21 @@ void executeEraseOp(PatternRewriter &rewriter); template void executeExtract(); - void executeFinalize(); + bool executeFinalize(); void executeForEach(); void executeGetAttribute(); void executeGetAttributeType(); void executeGetDefiningOp(); + void executeGetItem(unsigned index); void executeGetOperand(unsigned index); void executeGetOperands(); + void executeGetPermutations(); void executeGetResult(unsigned index); void executeGetResults(); void executeGetUsers(); void executeGetValueType(); void executeGetValueRangeTypes(); + void executeIsCommutative(); void executeIsNotNull(); void executeRecordMatch(PatternRewriter &rewriter, SmallVectorImpl &matches); @@ -1287,6 +1431,8 @@ /// The current execution memory. MutableArrayRef memory; MutableArrayRef opRangeMemory; + MutableArrayRef rangeRangeMemory; + std::vector> &allocatedRangeRangeMemory; MutableArrayRef typeRangeMemory; std::vector> &allocatedTypeRangeMemory; MutableArrayRef valueRangeMemory; @@ -1489,6 +1635,28 @@ selectJump(*lhs == rhs.cast().getAsValueRange()); } +bool ByteCodeExecutor::executeChooseRange() { + LLVM_DEBUG(llvm::dbgs() << "Executing ChooseRange:\n"); + const ByteCodeField *it = curCodeIt - 1; // Subtract 1 for the op code. + unsigned permutationsIndex = read(); + unsigned rangeIndex = read(); + RangeRange &permutation = rangeRangeMemory[permutationsIndex]; + if (permutation.empty()) { + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + if (resumeCodeIt.empty()) + return false; + popCodeIt(); + return true; + } + + ValueRange range = permutation.front(); + + permutation = permutation.drop_front(); + valueRangeMemory[rangeIndex] = range; + pushCodeIt(it); + return true; +} + void ByteCodeExecutor::executeContinue() { ByteCodeField level = read(); LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" @@ -1555,6 +1723,21 @@ break; } + // Determine the insertion point. + // This is the last of the operands, default to beginning of the block. + Block *block = rewriter.getBlock(); + rewriter.setInsertionPointToStart(block); + DenseSet operationSet; + for (Value value : state.operands) + operationSet.insert(value.getDefiningOp()); + for (auto it = block->rbegin(); it != block->rend(); ++it) + if (operationSet.count(&*it)) { + LLVM_DEBUG(llvm::dbgs() + << "Inserting " << state.name << " after " << *it << "\n"); + rewriter.setInsertionPointAfter(&*it); + break; + } + Operation *resultOp = rewriter.createOperation(state); memory[memIndex] = resultOp; @@ -1596,8 +1779,12 @@ storeToMemory(memIndex, result); } -void ByteCodeExecutor::executeFinalize() { +bool ByteCodeExecutor::executeFinalize() { LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); + if (resumeCodeIt.empty()) + return false; + popCodeIt(); + return true; } void ByteCodeExecutor::executeForEach() { @@ -1680,6 +1867,16 @@ memory[memIndex] = op; } +void ByteCodeExecutor::executeGetItem(unsigned index) { + unsigned rangeIndex = read(); + unsigned memindex = read(); + ValueRange range = valueRangeMemory[rangeIndex]; + Value operand = index < range.size() ? range[index] : Value(); + LLVM_DEBUG(llvm::dbgs() << " * Index " << index << "\n" + << " * Result: " << operand << "\n"); + memory[memindex] = operand.getAsOpaquePointer(); +} + void ByteCodeExecutor::executeGetOperand(unsigned index) { Operation *op = read(); unsigned memIndex = read(); @@ -1761,6 +1958,71 @@ memory[read()] = result; } +static int64_t factorial(unsigned num) { + return (num == 1 || num == 0) ? 1 : num * factorial(num - 1); +} + +/// Populates the initial ordering of `operandsVector`. This gets stored in +/// `ordering`. +static void getOperandInitialOrdering(const ArrayRef &operandsVector, + std::vector &ordering) { + llvm::DenseMap valueToIndex; + for (unsigned i = 0, e = operandsVector.size(); i < e; i++) { + auto it = valueToIndex.try_emplace(operandsVector[i], i); + if (!it.second) + ordering.push_back(valueToIndex[operandsVector[i]]); + else + ordering.push_back(i); + } +} + +/// Generates permutations for a list of values. +static void getPermutationSequences( + ValueRange operands, ByteCodeField rangeIndex, unsigned numOperands, + MutableArrayRef &rangeRangeMemory, + std::vector> &allocatedRangeRangeMemory, + std::vector> &allocatedValueRangeMemory) { + auto operandsVector = llvm::to_vector<4>(operands); + int64_t numResults = factorial(numOperands); + SmallVector results; + results.reserve(numResults); + std::vector orderings; + getOperandInitialOrdering(operandsVector, orderings); + std::sort(orderings.begin(), orderings.end()); + do { + SmallVector values; + values.reserve(numOperands); + for (unsigned i = 0; i < numOperands; i++) { + values.push_back(operandsVector[orderings[i]]); + } + ValueRange range(values); + llvm::OwningArrayRef storage(range.size()); + llvm::copy(range, storage.begin()); + allocatedValueRangeMemory.emplace_back(std::move(storage)); + results.push_back(allocatedValueRangeMemory.back()); + } while (std::next_permutation(orderings.begin(), orderings.end())); + + // If the range index is valid, we are returning a range. + if (rangeIndex != std::numeric_limits::max()) { + llvm::OwningArrayRef storage(results.size()); + llvm::copy(results, storage.begin()); + allocatedRangeRangeMemory.emplace_back(std::move(storage)); + + // Assign this to the range slot. + rangeRangeMemory[rangeIndex] = allocatedRangeRangeMemory.back(); + } +} + +void ByteCodeExecutor::executeGetPermutations() { + LLVM_DEBUG(llvm::dbgs() << "Executing GetPermutations:\n"); + unsigned listRangeIndex = read(); + unsigned permutationsRangeIndex = read(); + ValueRange operands = valueRangeMemory[listRangeIndex]; + getPermutationSequences(operands, permutationsRangeIndex, operands.size(), + rangeRangeMemory, allocatedRangeRangeMemory, + allocatedValueRangeMemory); +} + void ByteCodeExecutor::executeGetResult(unsigned index) { Operation *op = read(); unsigned memIndex = read(); @@ -1860,6 +2122,15 @@ memory[memIndex] = &typeRangeMemory[rangeIndex]; } +void ByteCodeExecutor::executeIsCommutative() { + LLVM_DEBUG(llvm::dbgs() << "Executing IsCommutative:\n"); + Operation *op = read(); + + LLVM_DEBUG(llvm::dbgs() << " * Operation: " << op << "\n"); + selectJump(op && (op->hasTrait() || + op->hasAttr("commutative"))); +} + void ByteCodeExecutor::executeIsNotNull() { LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); const void *value = read(); @@ -2052,6 +2323,10 @@ case CheckTypes: executeCheckTypes(); break; + case ChooseRange: + if (executeChooseRange()) + break; + return; case Continue: executeContinue(); break; @@ -2074,7 +2349,8 @@ executeExtract(); break; case Finalize: - executeFinalize(); + if (executeFinalize()) + break; LLVM_DEBUG(llvm::dbgs() << "\n"); return; case ForEach: @@ -2089,6 +2365,9 @@ case GetDefiningOp: executeGetDefiningOp(); break; + case GetItem: + executeGetItem(read()); + break; case GetOperand0: case GetOperand1: case GetOperand2: @@ -2105,6 +2384,9 @@ case GetOperands: executeGetOperands(); break; + case GetPermutations: + executeGetPermutations(); + break; case GetResult0: case GetResult1: case GetResult2: @@ -2130,6 +2412,9 @@ case GetValueRangeTypes: executeGetValueRangeTypes(); break; + case IsCommutative: + executeIsCommutative(); + break; case IsNotNull: executeIsNotNull(); break; @@ -2175,6 +2460,7 @@ // The matcher function always starts at code address 0. ByteCodeExecutor executor( matcherByteCode.data(), state.memory, state.opRangeMemory, + state.rangeRangeMemory, state.allocatedRangeRangeMemory, state.typeRangeMemory, state.allocatedTypeRangeMemory, state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, matcherByteCode, state.currentPatternBenefits, patterns, @@ -2197,7 +2483,8 @@ ByteCodeExecutor executor( &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory, - state.opRangeMemory, state.typeRangeMemory, + state.opRangeMemory, state.rangeRangeMemory, + state.allocatedRangeRangeMemory, state.typeRangeMemory, state.allocatedTypeRangeMemory, state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex, uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns, diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -255,8 +255,8 @@ } // CHECK-LABEL: test.are_equal_2 -// CHECK: "test.not_equal" // CHECK: "test.success" +// CHECK: "test.not_equal" // CHECK-NOT: "test.op" module @ir attributes { test.are_equal_2 } { "test.not_equal"() : () -> (i32) @@ -361,8 +361,8 @@ } // CHECK-LABEL: test.check_operand_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success" +// CHECK: "test.op"() : () -> i32 module @ir attributes { test.check_operand_count_1 } { %operand = "test.op"() : () -> i32 "test.op"(%operand, %operand) : (i32, i32) -> () @@ -430,8 +430,8 @@ } // CHECK-LABEL: test.check_result_count_1 -// CHECK: "test.op"() : () -> i32 // CHECK: "test.success"() : () -> () +// CHECK: "test.op"() : () -> i32 // CHECK-NOT: "test.op"() : () -> (i32, i32) module @ir attributes { test.check_result_count_1 } { "test.op"() : () -> i32 @@ -504,8 +504,8 @@ } // CHECK-LABEL: test.check_types_1 -// CHECK: "test.op"() : () -> (i32, i64) // CHECK: "test.success" +// CHECK: "test.op"() : () -> (i32, i64) // CHECK-NOT: "test.op"() : () -> i32 module @ir attributes { test.check_types_1 } { "test.op"() : () -> (i32, i64) @@ -514,6 +514,12 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl_interp::ChooseRangeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + //===----------------------------------------------------------------------===// // pdl_interp::ContinueOp //===----------------------------------------------------------------------===// @@ -641,8 +647,8 @@ } // CHECK-LABEL: test.extract_type -// CHECK: %[[OPERAND:.*]] = "test.op" // CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" // CHECK: "test.op"(%[[OPERAND]]) module @ir attributes { test.extract_type } { %operand = "test.op"() : () -> i32 @@ -673,8 +679,8 @@ } // CHECK-LABEL: test.extract_value -// CHECK: %[[OPERAND:.*]] = "test.op" // CHECK: "test.success" +// CHECK: %[[OPERAND:.*]] = "test.op" // CHECK: "test.op"(%[[OPERAND]]) module @ir attributes { test.extract_value } { %operand = "test.op"() : () -> i32 @@ -897,9 +903,9 @@ } // CHECK-LABEL: test.get_defining_op_1 +// CHECK: "test.success" // CHECK: %[[OPERAND0:.*]] = "test.op" // CHECK: %[[OPERAND1:.*]] = "test.op" -// CHECK: "test.success" // CHECK: "test.op"(%[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND0]], %[[OPERAND1]]) module @ir attributes { test.get_defining_op_1 } { %operand = "test.op"() : () -> i32 @@ -910,6 +916,12 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl_interp::GetItemOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + //===----------------------------------------------------------------------===// // pdl_interp::GetOperandOp //===----------------------------------------------------------------------===// @@ -1008,11 +1020,11 @@ } // CHECK-LABEL: test.get_operands_2 -// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"() : () -> () -// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () +// CHECK-NEXT: %[[INPUTS:.*]]:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () // CHECK-NEXT: "test.success"(%[[INPUTS]]#4) : (i32) -> () +// CHECK-NEXT: "test.success"(%[[INPUTS]]#0, %[[INPUTS]]#1, %[[INPUTS]]#2, %[[INPUTS]]#3) : (i32, i32, i32, i32) -> () module @ir attributes { test.get_operands_2 } { %inputs:5 = "test.producer"() : () -> (i32, i32, i32, i32, i32) "test.attr_sized_operands"(%inputs#0, %inputs#1, %inputs#2, %inputs#3, %inputs#4) {operand_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : (i32, i32, i32, i32, i32) -> () @@ -1020,6 +1032,100 @@ // ----- +//===----------------------------------------------------------------------===// +// pdl_interp::GetPermutationsOp +//===----------------------------------------------------------------------===// + +module @patterns { + func @matcher(%arg0: !pdl.operation) { + pdl_interp.is_commutative %arg0 : !pdl.operation -> ^bb2, ^bb1 + ^bb1: // 22 preds: ^bb0, ^bb2, ^bb3, ^bb4, ^bb5, ^bb6, ^bb7, ^bb8, ^bb9, ^bb10, ^bb11, ^bb12, ^bb13, ^bb14, ^bb15, ^bb16, ^bb17, ^bb18, ^bb19, ^bb20, ^bb21, ^bb22 + pdl_interp.finalize + ^bb2: // pred: ^bb0 + %0 = pdl_interp.get_operands of %arg0 : !pdl.range + %1 = pdl_interp.get_permutations for %0 + %2 = pdl_interp.choose_range from %1 + %3 = pdl_interp.get_item 1 from %2 + %4 = pdl_interp.get_defining_op of %3 : !pdl.value + pdl_interp.is_not_null %4 : !pdl.operation -> ^bb3, ^bb1 + ^bb3: // pred: ^bb2 + %5 = pdl_interp.get_item 0 from %2 + %6 = pdl_interp.get_defining_op of %5 : !pdl.value + pdl_interp.is_not_null %6 : !pdl.operation -> ^bb4, ^bb1 + ^bb4: // pred: ^bb3 + pdl_interp.check_operation_name of %arg0 is "test.op" -> ^bb5, ^bb1 + ^bb5: // pred: ^bb4 + pdl_interp.check_operand_count of %arg0 is 2 -> ^bb6, ^bb1 + ^bb6: // pred: ^bb5 + pdl_interp.check_result_count of %arg0 is 1 -> ^bb7, ^bb1 + ^bb7: // pred: ^bb6 + pdl_interp.is_not_null %3 : !pdl.value -> ^bb8, ^bb1 + ^bb8: // pred: ^bb7 + pdl_interp.is_not_null %5 : !pdl.value -> ^bb9, ^bb1 + ^bb9: // pred: ^bb8 + %7 = pdl_interp.get_result 0 of %arg0 + pdl_interp.is_not_null %7 : !pdl.value -> ^bb10, ^bb1 + ^bb10: // pred: ^bb9 + pdl_interp.check_operand_count of %4 is 0 -> ^bb11, ^bb1 + ^bb11: // pred: ^bb10 + pdl_interp.check_operand_count of %6 is 0 -> ^bb12, ^bb1 + ^bb12: // pred: ^bb11 + pdl_interp.check_result_count of %6 is 1 -> ^bb13, ^bb1 + ^bb13: // pred: ^bb12 + pdl_interp.check_result_count of %4 is 1 -> ^bb14, ^bb1 + ^bb14: // pred: ^bb13 + %8 = pdl_interp.get_attribute "attr" of %6 + pdl_interp.is_not_null %8 : !pdl.attribute -> ^bb15, ^bb1 + ^bb15: // pred: ^bb14 + %9 = pdl_interp.get_attribute "attr1" of %4 + pdl_interp.is_not_null %9 : !pdl.attribute -> ^bb16, ^bb1 + ^bb16: // pred: ^bb15 + %10 = pdl_interp.get_result 0 of %6 + pdl_interp.is_not_null %10 : !pdl.value -> ^bb17, ^bb1 + ^bb17: // pred: ^bb16 + %11 = pdl_interp.get_result 0 of %4 + pdl_interp.is_not_null %11 : !pdl.value -> ^bb18, ^bb1 + ^bb18: // pred: ^bb17 + pdl_interp.are_equal %10, %5 : !pdl.value -> ^bb19, ^bb1 + ^bb19: // pred: ^bb18 + pdl_interp.are_equal %11, %3 : !pdl.value -> ^bb20, ^bb1 + ^bb20: // pred: ^bb19 + %12 = pdl_interp.get_value_type of %10 : !pdl.type + %13 = pdl_interp.get_value_type of %11 : !pdl.type + pdl_interp.are_equal %12, %13 : !pdl.type -> ^bb21, ^bb1 + ^bb21: // pred: ^bb20 + %14 = pdl_interp.get_value_type of %7 : !pdl.type + pdl_interp.are_equal %12, %14 : !pdl.type -> ^bb22, ^bb1 + ^bb22: // pred: ^bb21 + pdl_interp.record_match @rewriters::@success(%arg0 : !pdl.operation) : benefit(1), loc([%arg0, %4, %6]), root("test.op") -> ^bb1 + } + module @rewriters { + func @success(%root : !pdl.operation) { + %op = pdl_interp.create_operation "test.success" + pdl_interp.erase %root + pdl_interp.finalize + } + } + +} + +// We can see that if the ops were not commutative, only one "test.success" +// would have been obtained, but, since the ops are commutative, their chance of +// matching has increased (both permutations of their operands are matched +// above), which results in two "test.success" ops, one for each commutative +// "test.op". +// CHECK-LABEL: test.commutativity +// CHECK: "test.success" +// CHECK: "test.success" +module @ir attributes { test.commutativity} { + %0 = "test.op"() {"attr1" = 2} : () -> i32 + %1 = "test.op"() {"attr" = 1} : () -> i32 + %2 = "test.op"(%1, %0) {"commutative"} : (i32, i32) -> i32 + %3 = "test.op"(%0, %1) {"commutative"} : (i32, i32) -> i32 +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::GetResultOp //===----------------------------------------------------------------------===// @@ -1162,10 +1268,10 @@ } // CHECK-LABEL: test.get_results_2 -// CHECK: "test.success"() : () -> () -// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) -// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 // CHECK: %[[RESULTS_2_SINGLE:.*]] = "test.success"() : () -> i32 +// CHECK: %[[RESULTS_2:.*]] = "test.success"() : () -> i32 +// CHECK: %[[RESULTS_1:.*]]:4 = "test.success"() : () -> (i32, i32, i32, i32) +// CHECK: "test.success"() : () -> () // CHECK: "test.consumer"(%[[RESULTS_1]]#0, %[[RESULTS_1]]#1, %[[RESULTS_1]]#2, %[[RESULTS_1]]#3, %[[RESULTS_2]]) : (i32, i32, i32, i32, i32) -> () module @ir attributes { test.get_results_2 } { %results:5 = "test.attr_sized_results"() {result_segment_sizes = dense<[0, 4, 1, 0]> : vector<4xi32>} : () -> (i32, i32, i32, i32, i32) @@ -1186,6 +1292,12 @@ // Fully tested within the tests for other operations. +//===----------------------------------------------------------------------===// +// pdl_interp::IsCommutativeOp +//===----------------------------------------------------------------------===// + +// Fully tested within the tests for other operations. + //===----------------------------------------------------------------------===// // pdl_interp::IsNotNullOp //===----------------------------------------------------------------------===//