diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -169,6 +169,10 @@ struct ByteCodeLiveRange; struct ByteCodeWriter; +/// Check if the given class `T` provides a `getAsOpaquePointer()` method. +template +using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); + /// This class represents the main generator for the pattern bytecode. class Generator { public: @@ -387,10 +391,6 @@ bytecode.push_back(static_cast(kind)); } - /// Check if the given class `T` has an iterator type. - template - using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); - /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template @@ -416,6 +416,17 @@ append(field2, fields...); } + /// Appends a value as a pointer, stored inline within the bytecode. + template + std::enable_if_t::value> + appendInline(T value) { + constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); + const void *pointer = value.getAsOpaquePointer(); + ByteCodeField fieldParts[numParts]; + std::memcpy(fieldParts, &pointer, sizeof(const void *)); + bytecode.append(fieldParts, fieldParts + numParts); + } + /// Successor references in the bytecode that have yet to be resolved. DenseMap> unresolvedSuccessorRefs; @@ -692,6 +703,13 @@ } void Generator::generate(Operation *op, ByteCodeWriter &writer) { + LLVM_DEBUG({ + // The following list must contain all the operations that do not + // produce any bytecode. + if (!isa(op)) + writer.appendInline(op->getLoc()); + }); TypeSwitch(op) .Case + std::enable_if_t::value, T> + readInline() { + const void *pointer; + std::memcpy(&pointer, curCodeIt, sizeof(const void *)); + curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); + return T::getFromOpaquePointer(pointer); + } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 
0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -1507,8 +1546,7 @@ void ByteCodeExecutor::executeForEach() { LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); - // Subtract 1 for the op code. - const ByteCodeField *it = curCodeIt - 1; + const ByteCodeField *prevCodeIt = getPrevCodeIt(); unsigned rangeIndex = read(); unsigned memIndex = read(); const void *value = nullptr; @@ -1535,7 +1573,7 @@ // Store the iterate value and the stack address. memory[memIndex] = value; - pushCodeIt(it); + pushCodeIt(prevCodeIt); // Skip over the successor (we will enter the body of the loop). read(); @@ -1945,6 +1983,7 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + LLVM_DEBUG(llvm::dbgs() << readInline() << "\n"); OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: