diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -420,6 +420,17 @@ append(field2, fields...); } + /// Appends a value as a pointer, stored inline within the bytecode. + template + std::enable_if_t::value> + appendInline(T value) { + constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); + const void *pointer = value.getAsOpaquePointer(); + ByteCodeField fieldParts[numParts]; + std::memcpy(fieldParts, &pointer, sizeof(const void *)); + bytecode.append(fieldParts, fieldParts + numParts); + } + /// Successor references in the bytecode that have yet to be resolved. DenseMap> unresolvedSuccessorRefs; @@ -696,6 +707,13 @@ } void Generator::generate(Operation *op, ByteCodeWriter &writer) { + LLVM_DEBUG({ + // The following list must contain all the operations that do not + // produce any bytecode. + if (!isa(op)) + writer.appendInline(op->getLoc()); + }); TypeSwitch(op) .Case + std::enable_if_t::value, T> + readInline() { + const void *pointer; + std::memcpy(&pointer, curCodeIt, sizeof(const void *)); + curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); + return T::getFromOpaquePointer(pointer); + } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -1553,8 +1592,7 @@ void ByteCodeExecutor::executeForEach() { LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); - // Subtract 1 for the op code. - const ByteCodeField *it = curCodeIt - 1; + const ByteCodeField *prevCodeIt = getPrevCodeIt(); unsigned rangeIndex = read(); unsigned memIndex = read(); const void *value = nullptr; @@ -1581,7 +1619,7 @@ // Store the iterate value and the stack address. memory[memIndex] = value; - pushCodeIt(it); + pushCodeIt(prevCodeIt); // Skip over the successor (we will enter the body of the loop). read(); @@ -1973,6 +2011,9 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + // Print the location of the operation being executed. + LLVM_DEBUG(llvm::dbgs() << readInline() << "\n"); + OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: