NOTE(review): this preamble was added during review. `git apply` and GNU
`patch` skip any text before the first `diff --git` line, so the patch
content below is byte-for-byte unchanged.

What the patch does (as far as the surviving text shows):
  - Moves the `has_pointer_traits` alias template from inside
    `ByteCodeWriter` to namespace scope, next to the forward declarations
    of `ByteCodeLiveRange`/`ByteCodeWriter`. Presumably this lets the
    executor's new `readInline` use it as well; its `enable_if_t`
    condition was stripped in this copy, so confirm against the original.
  - Adds `ByteCodeWriter::appendInline(T value)`. It copies the result of
    `value.getAsOpaquePointer()` directly into the bytecode with memcpy,
    as `sizeof(const void *) / sizeof(ByteCodeField)` fields, instead of
    storing it in a memory slot.
  - In `Generator::generate`, inside `LLVM_DEBUG`, writes `op->getLoc()`
    inline ahead of each operation's bytecode. Ops in an `isa<...>` list
    are skipped; per the patch's own comment, that list must contain
    exactly the ops that produce no bytecode.
  - Adds `ByteCodeExecutor::readInline()`, the inverse of `appendInline`.
    At the top of every `execute` loop iteration, inside `LLVM_DEBUG`, it
    is used to read and print that location before the opcode is read.
  - Adds a `prevCodeIt` member, set to `curCodeIt` at the top of every
    `execute` loop iteration. `executeForEach` now pushes `prevCodeIt` as
    the loop re-entry point instead of `curCodeIt - 1`. Once a location
    precedes the opcode, `curCodeIt - 1` points at the opcode rather than
    at the start of the ForEach instruction. `prevCodeIt` is correct
    whether or not the location is present.

Review notes:
  - This copy is corrupted and will not apply as-is. Every `<...>` span
    was stripped: `template` lines have no parameter lists, and the patch
    contains `std::enable_if_t::value`, `static_cast(kind)`, `isa(op)`,
    `DenseMap> unresolvedSuccessorRefs`, `SmallVector resumeCodeIt`,
    `readInline()` with no template argument, `SmallVectorImpl *matches`
    and `Optional mainRewriteLoc`.
    A whole run of text is also missing. After the `.Case` context line
    in `Generator::generate`, an added executor line follows with no
    intervening `@@` hunk header.
    Everything has also been collapsed onto two lines. Re-take the patch
    from the original commit rather than hand-repairing it.
  - The doc comment on `has_pointer_traits` ("has an iterator type") is
    stale. The alias detects a `getAsOpaquePointer()` member, not an
    iterator type.
  - `numParts` silently truncates if `sizeof(const void *)` is not a
    multiple of `sizeof(ByteCodeField)`. `appendInline` and `readInline`
    both rely on this division and must agree. Consider a shared
    `static_assert`.
  - The bytecode layout now depends on whether `LLVM_DEBUG` bodies
    execute. Generation and execution must observe the same debug
    setting. Likewise, the `isa<...>` exclusion list must stay in sync
    with the set of ops whose `generate` emits nothing. If either
    diverges, the executor will read a location where it expects an
    opcode, or the reverse.

diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -167,6 +167,10 @@ struct ByteCodeLiveRange; struct ByteCodeWriter; +/// Check if the given class `T` has an iterator type. +template +using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); + /// This class represents the main generator for the pattern bytecode. class Generator { public: @@ -384,10 +388,6 @@ bytecode.push_back(static_cast(kind)); } - /// Check if the given class `T` has an iterator type. - template - using has_pointer_traits = decltype(std::declval().getAsOpaquePointer()); - /// Append a value that will be stored in a memory slot and not inline within /// the bytecode. template @@ -413,6 +413,17 @@ append(field2, fields...); } + /// Appends a value as a pointer, stored inline within the bytecode. + template + std::enable_if_t::value> + appendInline(T value) { + constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField); + const void *pointer = value.getAsOpaquePointer(); + ByteCodeField fieldParts[numParts]; + std::memcpy(fieldParts, &pointer, sizeof(const void *)); + bytecode.append(fieldParts, fieldParts + numParts); + } + /// Successor references in the bytecode that have yet to be resolved. DenseMap> unresolvedSuccessorRefs; @@ -689,6 +700,13 @@ } void Generator::generate(Operation *op, ByteCodeWriter &writer) { + LLVM_DEBUG({ + // The following list must contain all the operations that do not + // produce any bytecode. + if (!isa(op)) + writer.appendInline(op->getLoc()); + }); TypeSwitch(op) .Case + std::enable_if_t::value, T> + readInline() { + const void *pointer; + std::memcpy(&pointer, curCodeIt, sizeof(const void *)); + curCodeIt += sizeof(const void *) / sizeof(ByteCodeField); + return T::getFromOpaquePointer(pointer); + } + /// Jump to a specific successor based on a predicate value. void selectJump(bool isTrue) { selectJump(size_t(isTrue ?
0 : 1)); } /// Jump to a specific successor based on a destination index. @@ -1201,6 +1229,9 @@ /// The underlying bytecode buffer. const ByteCodeField *curCodeIt; + /// The pointer to bytecode before the current operation got executed. + const ByteCodeField *prevCodeIt; + /// The stack of bytecode positions at which to resume operation. SmallVector resumeCodeIt; @@ -1503,8 +1534,6 @@ void ByteCodeExecutor::executeForEach() { LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); - // Subtract 1 for the op code. - const ByteCodeField *it = curCodeIt - 1; unsigned rangeIndex = read(); unsigned memIndex = read(); const void *value = nullptr; @@ -1531,7 +1560,7 @@ // Store the iterate value and the stack address. memory[memIndex] = value; - pushCodeIt(it); + pushCodeIt(prevCodeIt); // Skip over the successor (we will enter the body of the loop). read(); @@ -1911,6 +1940,10 @@ SmallVectorImpl *matches, Optional mainRewriteLoc) { while (true) { + // Store the current code iterator, needed for ForEach. + prevCodeIt = curCodeIt; + + LLVM_DEBUG(llvm::dbgs() << readInline() << "\n"); OpCode opCode = static_cast(read()); switch (opCode) { case ApplyConstraint: