diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -26,10 +26,15 @@ /// `producer` is an optional string that can be used to identify the producer /// of the bytecode when reading. It has no functional effect on the bytecode /// serialization. - BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING); + /// `shouldEmitUseListOrder` is an optional boolean that controls whether to + /// emit use-list order info to the bytecode to preserve walk behaviours after + /// a bytecode roundtrip. + BytecodeWriterConfig(bool shouldEmitUseListOrder = false, + StringRef producer = "MLIR" LLVM_VERSION_STRING); /// `map` is a fallback resource map, which when provided will attach resource /// printers for the fallback resources within the map. BytecodeWriterConfig(FallbackAsmResourceMap &map, + bool shouldEmitUseListOrder = false, StringRef producer = "MLIR" LLVM_VERSION_STRING); ~BytecodeWriterConfig(); diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -459,11 +459,15 @@ public: /// Construct a parser configuration with the given context. /// `verifyAfterParse` indicates if the IR should be verified after parsing. + /// `preserveUseListOrders` indicates if the in-memory use-list orders should + /// be preserved when interacting with bytecode. /// `fallbackResourceMap` is an optional fallback handler that can be used to /// parse external resources not explicitly handled by another parser.
ParserConfig(MLIRContext *context, bool verifyAfterParse = true, + bool preserveUseListOrders = false, FallbackAsmResourceMap *fallbackResourceMap = nullptr) : context(context), verifyAfterParse(verifyAfterParse), + preserveUseListOrders(preserveUseListOrders), fallbackResourceMap(fallbackResourceMap) { assert(context && "expected valid MLIR context"); } @@ -474,6 +478,10 @@ /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns if the parser should apply the use-list orders encoded in the + /// bytecode, when present. + bool shouldPreserveUseListOrders() const { return preserveUseListOrders; } + /// Return the resource parser registered to the given name, or nullptr if no /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { @@ -506,6 +514,7 @@ private: MLIRContext *context; bool verifyAfterParse; + bool preserveUseListOrders; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; }; diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -187,6 +187,10 @@ /// Returns true if the value is used outside of the given block. bool isUsedOutsideOfBlock(Block *block); + /// Reorder the use-list so that the use currently at position `i` moves to + /// position `order[i]`; `order` must be a permutation of [0, #uses).
+ void setUseListOrder(ArrayReforder); + //===--------------------------------------------------------------------===// // Uses diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -153,6 +153,8 @@ } bool shouldVerifyPasses() const { return verifyPassesFlag; } + bool shouldPreserveUseListOrders() const { return preserveUseListOrdersFlag; } + protected: /// Allow operation with no registered dialects. /// This option is for convenience during testing only and discouraged in @@ -199,6 +201,9 @@ /// Run the verifier after each transformation pass. bool verifyPassesFlag = true; + + /// Preserve use-list orders when interacting with bytecode. + bool preserveUseListOrdersFlag = false; }; /// This defines the function type used to setup the pass manager. This can be diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -87,6 +87,7 @@ kHasOperands = 0b00000100, kHasSuccessors = 0b00001000, kHasInlineRegions = 0b00010000, + kHasUseListOrders = 0b00100000, // clang-format on }; } // namespace OpEncodingMask diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -24,6 +24,7 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include #include #define DEBUG_TYPE "mlir-bytecode-reader" @@ -1192,6 +1193,17 @@ /// Create a value to use for a forward reference. Value createForwardRef(); + /// Parse use-list orders from bytecode if available. On success, return a map + /// between the op result number and the use-list encoding.
+ using UseListMapT = + DenseMap<unsigned, std::pair<bool, SmallVector<unsigned, 4>>>; + FailureOr<UseListMapT> parseOpResultsUseListOrder(EncodingReader &reader, + uint64_t numResults); + + /// Apply the use-list orders collected while parsing. Fails if an encoded + /// order is not a valid permutation of the uses of its value. + LogicalResult processUseListOrderWorklist(); + //===--------------------------------------------------------------------===// // Fields @@ -1220,6 +1232,23 @@ SmallVector nextValueIDs; }; + struct UseListOrder { + UseListOrder(Value val, bool isIndexPairEncoding, + SmallVector<unsigned, 4> &&indices) + : value(val), isIndexPairEncoding(isIndexPairEncoding), + indices(std::move(indices)) {} + // The value whose use-list is to be reordered. + Value value; + + // Whether `indices` holds explicit `(src, dst)` pairs or is indexed directly + // (`dst = indices[src]`); `src` is a use's position after parsing. + bool isIndexPairEncoding; + + // The vector containing the order information required to reorder the + // use-list of value. + SmallVector<unsigned, 4> indices; + }; + /// The configuration of the parser. const ParserConfig &config; @@ -1242,17 +1271,24 @@ /// The reader used to process resources within the bytecode. ResourceSectionReader resourceReader; + /// Worklist of values with custom use-list orders to process before the end + /// of the parsing. + SmallVector<UseListOrder> useListOrderWorklist; + /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; /// The current set of available IR value scopes. std::vector valueScopes; + /// A block containing the set of operations defined to create forward /// references. Block forwardRefOps; + /// A block containing previously created, and no longer used, forward /// reference operations. Block openForwardRefOps; + /// An operation state used when instantiating forward references.
OperationState forwardRefOpState; @@ -1487,6 +1523,109 @@ dialectReader, bufferOwnerRef); } +//===----------------------------------------------------------------------===// +// UseListOrder Helpers + +FailureOr<BytecodeReader::UseListMapT> +BytecodeReader::parseOpResultsUseListOrder(EncodingReader &reader, + uint64_t numResults) { + BytecodeReader::UseListMapT map; + uint64_t numValuesToRead = 1; + if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead))) + return failure(); + + for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) { + uint64_t resultIdx = 0; + if (numResults > 1 && failed(reader.parseVarInt(resultIdx))) + return failure(); + if (resultIdx >= numResults) { + reader.emitError("invalid use-list order result index: ", resultIdx); + return failure(); + } + + uint64_t numValues; + bool indexPairEncoding; + if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding))) + return failure(); + + SmallVector<unsigned, 4> useListOrders; + for (size_t idx = 0; idx < numValues; idx++) { + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + // Reject indices that would be silently truncated to `unsigned`. + if (static_cast<unsigned>(index) != index) { + reader.emitError("invalid use-list order index: ", index); + return failure(); + } + useListOrders.push_back(index); + } + + // Each result may carry at most one use-list order. + if (!map.try_emplace(static_cast<unsigned>(resultIdx), + std::make_pair(indexPairEncoding, + std::move(useListOrders))) + .second) { + reader.emitError("duplicate use-list order for result #", resultIdx); + return failure(); + } + } + + return map; +} + +LogicalResult BytecodeReader::processUseListOrderWorklist() { + while (!useListOrderWorklist.empty()) { + UseListOrder entry = useListOrderWorklist.pop_back_val(); + uint64_t numUses = std::distance(entry.value.getUses().begin(), + entry.value.getUses().end()); + SmallVector<unsigned, 4> shuffle = std::move(entry.indices); + + // NOTE(review): both encodings assume the use-list built while parsing + // lists uses in reverse creation order; confirm this also holds for + // values whose uses were first attached to a forward reference and then + // moved with replaceAllUsesWith. + + // If the encoding was a pair of indices `(src, dst)` for every + // permutation, reconstruct the shuffle vector for every use. Initialize + // the shuffle vector as identity, and then apply the mapping encoded in + // the indices. + if (entry.isIndexPairEncoding) { + // Return failure if the number of indices was not representing pairs. + if (shuffle.size() & 1) + return failure(); + + SmallVector<unsigned, 4> newShuffle(numUses); + std::iota(newShuffle.begin(), newShuffle.end(), 0u); + for (size_t idx = 0; idx < shuffle.size(); idx += 2) { + // Validate `src` before using it as an index so that malformed + // bytecode cannot write out of bounds. + if (shuffle[idx] >= numUses) + return failure(); + newShuffle[shuffle[idx]] = shuffle[idx + 1]; + } + + shuffle = std::move(newShuffle); + } + + // Make sure that the indices represent a valid permutation of + // [0, numUses): one entry per use, every entry in range and no + // duplicates. Checking the range explicitly, unlike comparing the sum of + // the entries, cannot be defeated by integer wrap-around. + if (shuffle.size() != numUses) + return failure(); + SmallVector<bool> seen(numUses, false); + for (unsigned elem : shuffle) { + if (elem >= numUses || seen[elem]) + return failure(); + seen[elem] = true; + } + + entry.value.setUseListOrder(shuffle); + } + return success(); +} +
//===----------------------------------------------------------------------===// // IR Section @@ -1516,6 +1655,11 @@ "not all forward unresolved forward operand references"); } + if (config.shouldPreserveUseListOrders() && + failed(processUseListOrderWorklist())) + return reader.emitError( + "parsed use-list orders were invalid and could not be applied"); + // Resolve dialect version. for (const BytecodeDialect &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the @@ -1665,6 +1809,16 @@ } } + /// Parse the use-list orders for the results of the operation. + std::optional<UseListMapT> useListOrderMap; + if (opMask & bytecode::OpEncodingMask::kHasUseListOrders) { + size_t numResults = opState.types.size(); + auto parseResult = parseOpResultsUseListOrder(reader, numResults); + if (failed(parseResult)) + return failure(); + useListOrderMap = std::move(*parseResult); + } + /// Parse the regions of the operation.
if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { uint64_t numRegions; @@ -1684,6 +1838,17 @@ if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) return failure(); + // We can't update the use-list at this point, since the uses of the results + // don't exist yet; queue the orders only if they are going to be applied. + if (useListOrderMap && config.shouldPreserveUseListOrders()) { + for (unsigned idx = 0, e = op->getNumResults(); idx < e; ++idx) { + auto it = useListOrderMap->find(idx); + if (it == useListOrderMap->end()) + continue; + useListOrderWorklist.emplace_back(op->getResult(idx), it->second.first, + std::move(it->second.second)); + } + } return op; } diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -25,7 +25,8 @@ //===----------------------------------------------------------------------===// struct BytecodeWriterConfig::Impl { - Impl(StringRef producer) : producer(producer) {} + Impl(bool shouldEmitUseListOrder, StringRef producer) + : producer(producer), shouldEmitUseListOrder(shouldEmitUseListOrder) {} /// Version to use when writing. /// Note: This only differs from kVersion if a specific version is set. @@ -34,15 +35,20 @@ /// The producer of the bytecode. StringRef producer; + /// Flag indicating if we should emit use-list order for operations. + bool shouldEmitUseListOrder; + /// A collection of non-dialect resource printers.
SmallVector> externalResourcePrinters; }; -BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) - : impl(std::make_unique(producer)) {} +BytecodeWriterConfig::BytecodeWriterConfig(bool shouldEmitUseListOrder, + StringRef producer) + : impl(std::make_unique<Impl>(shouldEmitUseListOrder, producer)) {} BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map, + bool shouldEmitUseListOrder, StringRef producer) - : BytecodeWriterConfig(producer) { + : BytecodeWriterConfig(shouldEmitUseListOrder, producer) { attachFallbackResourcePrinter(map); } BytecodeWriterConfig::~BytecodeWriterConfig() = default; @@ -470,6 +476,12 @@ void writeStringSection(EncodingEmitter &emitter); + //===--------------------------------------------------------------------===// + // Helpers + + void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask, + Operation *op); + //===--------------------------------------------------------------------===// // Fields @@ -718,6 +730,13 @@ emitter.emitVarInt(numberingState.getNumber(successor)); } + // If use-list order preservation is requested, record the indices that map + // the current use-list order of the op results to the order in which the + // reader rebuilds it, so that the original order can be restored when + // parsing the bytecode. + if (config.shouldEmitUseListOrder) + writeUseListOrders(emitter, opEncodingMask, op); + // Check for regions. unsigned numRegions = op->getNumRegions(); if (numRegions) @@ -739,6 +758,104 @@ } } +void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter, + uint8_t &opEncodingMask, + Operation *op) { + // Compute a unique ID for a use: uses are ordered by the numbering of their + // owner operation and, within the same owner, by operand number. Without the + // operand number, two uses of a value by one operation (e.g. `op %x, %x`) + // would compare equal and their relative order after sorting would be + // unspecified. + auto getUseID = [&](OpOperand &use) -> uint64_t { + return (static_cast<uint64_t>(numberingState.getNumber(use.getOwner())) + << 32) | + use.getOperandNumber(); + }; + + /// Loop over the results and store the use-list order per result index. + DenseMap<unsigned, SmallVector<std::pair<unsigned, uint64_t>>> map; + for (auto resultIterator : llvm::enumerate(op->getResults())) { + Value result = resultIterator.value(); + // No need to store a custom use-list order if the result does not have + // multiple uses. + if (result.use_empty() || result.hasOneUse()) + continue; + + // For each result, assemble the list of pairs (use-list-index, use-ID). + // The reader creates uses in increasing use-ID order and prepends each new + // use to the use-list, so a use-list that is strictly decreasing in use-ID + // is reproduced as-is and needs no custom order. + bool alreadyOrdered = true; + uint64_t prevID = getUseID(*result.getUses().begin()); + SmallVector<std::pair<unsigned, uint64_t>> useListPairs({{0, prevID}}); + + for (auto item : llvm::drop_begin(llvm::enumerate(result.getUses()))) { + uint64_t currentID = getUseID(item.value()); + alreadyOrdered &= (prevID > currentID); + useListPairs.push_back({item.index(), currentID}); + prevID = currentID; + } + + // Do not emit if the order is already sorted. + if (alreadyOrdered) + continue; + + map.try_emplace(resultIterator.index(), std::move(useListPairs)); + } + + if (map.empty()) + return; + + opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders; + // Emit the number of results that have a custom use-list order if the number + // of results is greater than one. + if (op->getNumResults() != 1) + emitter.emitVarInt(map.size()); + + for (auto &elem : map) { + unsigned resultIdx = elem.getFirst(); + auto &useListPairs = elem.getSecond(); + // Sort by decreasing use-ID, i.e. in the order in which the reader finds + // the uses in the use-list it builds. Afterwards, entry `i` describes the + // use the reader sees at position `i`, and its `.first` is the position + // that use must be moved to. + std::sort(useListPairs.begin(), useListPairs.end(), + [](const auto &lhs, const auto &rhs) { + return lhs.second > rhs.second; + }); + + // Compute the number of uses that are actually shuffled. If those are less + // than half of the total uses, encoding the index pair `(src, dst)` is more + // space efficient. + size_t shuffledElements = + llvm::count_if(llvm::enumerate(useListPairs), [](auto item) { + return item.index() != item.value().first; + }); + bool indexPairEncoding = shuffledElements < (useListPairs.size() / 2); + + // For single result, we don't need to store the result index. + if (op->getNumResults() != 1) + emitter.emitVarInt(resultIdx); + + if (indexPairEncoding) { + // Emit `(src, dst)` with `src` the position of the use in the use-list + // built by the reader and `dst` its final position. The reader applies + // the pairs as `shuffle[src] = dst`, matching the direct encoding below. + emitter.emitVarIntWithFlag(shuffledElements * 2, indexPairEncoding); + for (auto pair : llvm::enumerate(useListPairs)) { + if (pair.index() != pair.value().first) { + emitter.emitVarInt(pair.index()); + emitter.emitVarInt(pair.value().first); + } + } + } else { + emitter.emitVarIntWithFlag(useListPairs.size(), indexPairEncoding); + for (const auto &pair : useListPairs) + emitter.emitVarInt(pair.first); + } + } +} +
void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { // If the region is empty, we only need to emit the number of blocks (which is // zero). diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -152,6 +152,10 @@ assert(blockIDs.count(block) && "block not numbered"); return blockIDs[block]; } + unsigned getNumber(Operation *op) { + assert(operationIDs.count(op) && "operation not numbered"); + return operationIDs[op]; + } unsigned getNumber(OperationName opName) { assert(opNames.count(opName) && "opName not numbered"); return opNames[opName]->number; @@ -224,7 +228,8 @@ llvm::SpecificBumpPtrAllocator resourceAllocator; llvm::SpecificBumpPtrAllocator typeAllocator; - /// The value ID for each Block and Value. + /// The ID assigned to each Block, Operation and Value during numbering. + DenseMap<Operation *, unsigned> operationIDs; DenseMap blockIDs; DenseMap valueIDs; @@ -236,6 +241,7 @@ /// The next value ID to assign when numbering.
unsigned nextValueID = 0; + unsigned nextOperationID = 0; }; } // namespace detail } // namespace bytecode diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -263,6 +263,7 @@ } void IRNumberingState::number(Operation &op) { + operationIDs.try_emplace(&op, nextOperationID++); // Number the components of an operation that won't be numbered elsewhere // (e.g. we don't number operands, regions, or successors here). number(op.getName()); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -93,6 +93,22 @@ }); } +void Value::setUseListOrder(ArrayRef order) { + assert((size_t)std::distance(getUses().begin(), getUses().end()) == + order.size() && + "order vector expected to have a number of elements equal to the " + "number of uses"); + // Order uses into a vector and set them for every operand type. + SmallVector newUseList(order.size()); + for (const auto &elem : llvm::enumerate(getUses())) + newUseList[order[elem.index()]] = &elem.value(); + + // For every use in the range, replace it with the one in the list. 
+ for (const auto &item : llvm::zip(getUses(), newUseList)) { + std::get<0>(item).set((*std::get<1>(item)).get()); + } +} + //===----------------------------------------------------------------------===// // OpResult //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -352,7 +352,7 @@ } ParserConfig config(&context, /*verifyAfterParse=*/true, - &fallbackResourceMap); + /*preserveUseListOrders=*/false, &fallbackResourceMap); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) { // If parsing failed, clear out any of the current state. @@ -1298,6 +1298,7 @@ // Setup the parser config. ParserConfig parserConfig(&tempContext, /*verifyAfterParse=*/true, + /*preserveUseListOrders=*/false, &fallbackResourceMap); // Try to parse the given source file. diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -135,6 +135,11 @@ cl::desc("Run the verifier after each transformation pass"), cl::location(verifyPassesFlag), cl::init(true)); + static cl::opt preserveUseListOrders( + "preserve-use-list-orders", + cl::desc("Preserve use-list orders when interacting with bytecode"), + cl::location(preserveUseListOrdersFlag), cl::init(false)); + static cl::list passPlugins( "load-pass-plugin", cl::desc("Load passes from plugin library")); /// Set the callback to load a pass plugin. 
@@ -235,6 +240,7 @@ PassReproducerOptions reproOptions; FallbackAsmResourceMap fallbackResourceMap; ParserConfig parseConfig(context, /*verifyAfterParse=*/true, + /*preserveUseListOrders=*/config.shouldPreserveUseListOrders(), &fallbackResourceMap); reproOptions.attachResourceParser(parseConfig); @@ -263,7 +269,8 @@ // Print the output. TimingScope outputTiming = timing.nest("Output"); if (config.shouldEmitBytecode()) { - BytecodeWriterConfig writerConfig(fallbackResourceMap); + BytecodeWriterConfig writerConfig(fallbackResourceMap, + /*shouldEmitUseListOrder=*/config.shouldPreserveUseListOrders()); if (auto v = config.bytecodeVersionToEmit()) writerConfig.setDesiredBytecodeVersion(*v); return writeBytecodeToFile(op.get(), os, writerConfig);