diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -339,11 +339,20 @@
   numSuccessors: varint?,
   successors: varint[],
 
+  numUseListOrders: varint?,
+  useListOrders: uselist[],
+
   regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
 
   // regions are stored in a section if isIsolatedFromAbove
   regions: (region | region_section)[]
 }
+
+uselist {
+  indexInRange: varint?,
+  useListEncoding: varint, // (numIndices << 1) | (isIndexPairEncoding)
+  indices: varint[]
+}
 ```
 
 The encoding of an operation is important because this is generally the most
@@ -377,6 +386,26 @@
 If the operation has successors, the number of successors and the indexes of
 the successor blocks within the parent region are encoded.
 
+##### Use-list orders
+
+The reference use-list order is assumed to be the reverse of the global
+enumeration of all the op operands that one would obtain with a pre-order walk
+of the IR. This order arises naturally when building blocks of operations
+op-by-op. However, some transformations may shuffle the use-lists with respect
+to this reference ordering. If any result of the operation has a use-list
+order that is not sorted with respect to the reference use-list order, an
+encoding is emitted so that the order can be reconstructed after parsing the
+bytecode. The encoding represents an index map from the reference operand
+order to the current use-list order. A bit flag indicates whether the encoding
+uses index pairs. When the bit flag is set to `0`, the element at `i`
+represents the position of use `i` of the reference list in the current
+use-list. When the bit flag is set to `1`, the encoding represents index pairs
+`(i, j)`, indicating that the use at position `i` of the reference list is
+mapped to position `j` in the current use-list. When fewer than half of the
+elements in the current use-list are shuffled with respect to the reference
+use-list, the index-pair encoding is used to reduce the bytecode size.
+
 ##### Regions
 
 If the operation has regions, the number of regions and if the regions are
@@ -410,6 +439,8 @@
 block_arguments {
   numArgs: varint?,
   args: block_argument[]
+  numUseListOrders: varint?,
+  useListOrders: uselist[],
 }
 
 block_argument {
@@ -421,3 +452,6 @@
 A block is encoded with an array of operations and block arguments. The first
 field is an encoding that combines the number of operations in the block, with
 a flag indicating if the block has arguments.
+
+Use-list orders are attached to block arguments in the same way they are
+attached to operation results.
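To make the two encodings concrete: suppose a value has four uses, and the current use-list differs from the reference order only in that uses `1` and `3` are swapped. The direct encoding stores the full map `[0, 3, 2, 1]`, while the index-pair encoding stores only the pairs `(1, 3)` and `(3, 1)`. The following standalone sketch reconstructs the reference-to-current map from either encoding, following the description above (plain illustrative C++, not the MLIR API; the helper name is hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reconstruct the map `order`, where order[i] is the position of reference
// use `i` in the current use-list, from either encoding described above.
std::vector<uint64_t> decodeUseListOrder(bool isIndexPairEncoding,
                                         const std::vector<uint64_t> &indices,
                                         size_t numUses) {
  if (!isIndexPairEncoding)
    return indices; // direct encoding: order[i] is stored verbatim.
  // Index-pair encoding: start from the identity map and patch in the pairs.
  std::vector<uint64_t> order(numUses);
  for (size_t i = 0; i < numUses; ++i)
    order[i] = i;
  for (size_t k = 0; k + 1 < indices.size(); k += 2)
    order[indices[k]] = indices[k + 1];
  return order;
}

int main() {
  // Reference uses 1 and 3 are swapped in the current use-list.
  std::vector<uint64_t> direct = {0, 3, 2, 1};
  std::vector<uint64_t> pairs = {1, 3, 3, 1}; // (1 -> 3), (3 -> 1)
  assert(decodeUseListOrder(false, direct, 4) ==
         decodeUseListOrder(true, pairs, 4));
}
```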
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h
rename from mlir/lib/Bytecode/Encoding.h
rename to mlir/include/mlir/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/include/mlir/Bytecode/Encoding.h
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LIB_MLIR_BYTECODE_ENCODING_H
-#define LIB_MLIR_BYTECODE_ENCODING_H
+#ifndef MLIR_BYTECODE_ENCODING_H
+#define MLIR_BYTECODE_ENCODING_H
 
 #include <cstdint>
 
@@ -27,7 +27,7 @@
   kMinSupportedVersion = 0,
 
   /// The current bytecode version.
-  kVersion = 2,
+  kVersion = 3,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,
@@ -87,10 +87,27 @@
   kHasOperands = 0b00000100,
   kHasSuccessors = 0b00001000,
   kHasInlineRegions = 0b00010000,
+  kHasUseListOrders = 0b00100000,
   // clang-format on
 };
 } // namespace OpEncodingMask
 
+/// Get the unique ID of a value use. We encode the unique ID combining an
+/// owner number and the argument number such that if ownerID(op1) <
+/// ownerID(op2), then useID(op1) < useID(op2). If uses have the same owner,
+/// then argNumber(op1) < argNumber(op2) implies useID(op1) < useID(op2).
+template <typename OperandT>
+static inline uint64_t getUseID(OperandT &val, unsigned ownerID) {
+  uint32_t operandNumberID;
+  if constexpr (std::is_same_v<OperandT, OpOperand>)
+    operandNumberID = val.getOperandNumber();
+  else if constexpr (std::is_same_v<OperandT, BlockArgument>)
+    operandNumberID = val.getArgNumber();
+  else
+    llvm_unreachable("unexpected operand type");
+  return (static_cast<uint64_t>(ownerID) << 32) | operandNumberID;
+}
+
 } // namespace bytecode
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h
--- a/mlir/include/mlir/IR/UseDefLists.h
+++ b/mlir/include/mlir/IR/UseDefLists.h
@@ -44,6 +44,21 @@
   /// of the SSA machinery.
   IROperandBase *getNextOperandUsingThisValue() { return nextUse; }
 
+  /// Initialize the use-def chain by setting the back address to self and
+  /// nextUse to nullptr.
+  void initChainWithUse(IROperandBase **self) {
+    assert(this == *self);
+    back = self;
+    nextUse = nullptr;
+  }
+
+  /// Link the current node to next.
+  void linkTo(IROperandBase *next) {
+    nextUse = next;
+    if (nextUse)
+      nextUse->back = &nextUse;
+  }
+
 protected:
   IROperandBase(Operation *owner) : owner(owner) {}
   IROperandBase(IROperandBase &&other) : owner(other.owner) {
@@ -192,6 +207,30 @@
     use_begin()->set(newValue);
   }
 
+  /// Shuffle the use-list chain according to the provided indices vector,
+  /// which needs to represent a valid shuffle, that is, a vector of unique
+  /// integers in the range [0, numUses - 1]. Callers of this function must
+  /// guarantee the validity of the indices vector.
+  void shuffleUseList(ArrayRef<unsigned> indices) {
+    assert((size_t)std::distance(getUses().begin(), getUses().end()) ==
+               indices.size() &&
+           "indices vector expected to have a number of elements equal to the "
+           "number of uses");
+    SmallVector<detail::IROperandBase *> shuffled(indices.size());
+    detail::IROperandBase *ptr = firstUse;
+    for (size_t idx = 0; idx < indices.size();
+         idx++, ptr = ptr->getNextOperandUsingThisValue())
+      shuffled[indices[idx]] = ptr;
+
+    initFirstUse(shuffled.front());
+    auto *current = firstUse;
+    for (auto &next : llvm::drop_begin(shuffled)) {
+      current->linkTo(next);
+      current = next;
+    }
+    current->linkTo(nullptr);
+  }
+
   //===--------------------------------------------------------------------===//
   // Uses
   //===--------------------------------------------------------------------===//
@@ -234,6 +273,12 @@
   OperandType *getFirstUse() const { return (OperandType *)firstUse; }
 
 private:
+  /// Set use as the first use of the chain.
+  void initFirstUse(detail::IROperandBase *use) {
+    firstUse = use;
+    firstUse->initChainWithUse(&firstUse);
+  }
+
   detail::IROperandBase *firstUse = nullptr;
 
   /// Allow access to `firstUse`.
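The `indices` vector maps current positions to new positions: the use currently at position `idx` is relinked to position `indices[idx]`. A standalone model of that permutation over a plain array (illustrative only; the real method rewires the `IROperandBase` chain in place):

```cpp
#include <cassert>
#include <string>
#include <vector>

int main() {
  // Current use-list order, modeled as plain values instead of IROperandBase
  // pointers.
  std::vector<std::string> uses = {"u0", "u1", "u2"};
  // A valid shuffle: unique integers in [0, numUses - 1]. The use currently
  // at position idx moves to position indices[idx].
  std::vector<unsigned> indices = {2, 0, 1};

  std::vector<std::string> shuffled(uses.size());
  for (unsigned idx = 0; idx < uses.size(); ++idx)
    shuffled[indices[idx]] = uses[idx]; // mirrors `shuffled[indices[idx]] = ptr`
  // u0 moved to position 2, u1 to position 0, u2 to position 1.
  assert((shuffled == std::vector<std::string>{"u1", "u2", "u0"}));
}
```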
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -187,6 +187,11 @@
   /// Returns true if the value is used outside of the given block.
   bool isUsedOutsideOfBlock(Block *block);
 
+  /// Shuffle the use-list order according to the provided indices. It is the
+  /// responsibility of the caller to make sure that the indices map the
+  /// current use-list chain to another valid use-list chain.
+  void shuffleUseList(ArrayRef<unsigned> indices);
+
   //===--------------------------------------------------------------------===//
   // Uses
 
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -7,12 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 // TODO: Support for big-endian architectures.
-// TODO: Properly preserve use lists of values.
 
 #include "mlir/Bytecode/BytecodeReader.h"
-#include "../Encoding.h"
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/Encoding.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/OpImplementation.h"
@@ -29,6 +28,7 @@
 #include "llvm/Support/SourceMgr.h"
 #include <list>
 #include <memory>
+#include <numeric>
 #include <optional>
 
 #define DEBUG_TYPE "mlir-bytecode-reader"
@@ -1281,6 +1281,42 @@
   /// Create a value to use for a forward reference.
   Value createForwardRef();
 
+  //===--------------------------------------------------------------------===//
+  // Use-list order helpers
+
+  /// This struct is a simple storage that contains the information required
+  /// to reorder the use-list of a value with respect to the pre-order
+  /// traversal ordering.
+  struct UseListOrderStorage {
+    UseListOrderStorage(bool isIndexPairEncoding,
+                        SmallVector<unsigned, 4> &&indices)
+        : indices(std::move(indices)),
+          isIndexPairEncoding(isIndexPairEncoding) {}
+
+    /// The vector containing the information required to reorder the
+    /// use-list of a value.
+    SmallVector<unsigned, 4> indices;
+
+    /// Whether `indices` represents pairs of type `(src, dst)` or a direct
+    /// indexing, such as `dst = order[src]`.
+    bool isIndexPairEncoding;
+  };
+
+  /// Parse the use-list order from bytecode for a range of values, if
+  /// available. The range is expected to be either a block argument range or
+  /// an op result range. On success, return a map from the position in the
+  /// range to the use-list order encoding. The function is given the size of
+  /// the range it is processing.
+  using UseListMapT = DenseMap<uint64_t, UseListOrderStorage>;
+  FailureOr<UseListMapT> parseUseListOrderForRange(EncodingReader &reader,
+                                                   uint64_t rangeSize);
+
+  /// Shuffle the use-list chain according to the parsed order.
+  LogicalResult sortUseListOrder(Value value);
+
+  /// Recursively visit all the values defined within topLevelOp and sort the
+  /// use-list orders according to the parsed indices.
+  LogicalResult processUseLists(Operation *topLevelOp);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -1341,17 +1377,27 @@
   /// The reader used to process resources within the bytecode.
   ResourceSectionReader resourceReader;
 
+  /// Worklist of values with custom use-list orders to process before the end
+  /// of parsing.
+  DenseMap<void *, UseListOrderStorage> valueToUseListMap;
+
   /// The table of strings referenced within the bytecode file.
   StringSectionReader stringReader;
 
   /// The current set of available IR value scopes.
   std::vector<ValueScope> valueScopes;
+
+  /// The global pre-order operation ordering.
+  DenseMap<Operation *, unsigned> operationIDs;
+
   /// A block containing the set of operations defined to create forward
   /// references.
   Block forwardRefOps;
+
   /// A block containing previously created, and no longer used, forward
   /// reference operations.
   Block openForwardRefOps;
 
   /// An operation state used when instantiating forward references.
   OperationState forwardRefOpState;
 
@@ -1597,6 +1643,165 @@
       dialectReader, bufferOwnerRef);
 }
 
+//===----------------------------------------------------------------------===//
+// UseListOrder Helpers
+
+FailureOr<BytecodeReader::Impl::UseListMapT>
+BytecodeReader::Impl::parseUseListOrderForRange(EncodingReader &reader,
+                                                uint64_t numResults) {
+  BytecodeReader::Impl::UseListMapT map;
+  uint64_t numValuesToRead = 1;
+  if (numResults > 1 && failed(reader.parseVarInt(numValuesToRead)))
+    return failure();
+
+  for (size_t valueIdx = 0; valueIdx < numValuesToRead; valueIdx++) {
+    uint64_t resultIdx = 0;
+    if (numResults > 1 && failed(reader.parseVarInt(resultIdx)))
+      return failure();
+
+    uint64_t numValues;
+    bool indexPairEncoding;
+    if (failed(reader.parseVarIntWithFlag(numValues, indexPairEncoding)))
+      return failure();
+
+    SmallVector<unsigned, 4> useListOrders;
+    for (size_t idx = 0; idx < numValues; idx++) {
+      uint64_t index;
+      if (failed(reader.parseVarInt(index)))
+        return failure();
+      useListOrders.push_back(index);
+    }
+
+    // Map the result index to its use-list order encoding.
+    map.try_emplace(resultIdx, UseListOrderStorage(indexPairEncoding,
+                                                   std::move(useListOrders)));
+  }
+
+  return map;
+}
+
+/// Sorts each use according to the order specified in the use-list parsed
+/// from the bytecode. If no custom use-list order is found, the order needs
+/// to be consistent with the reverse pre-order walk of the IR. If multiple
+/// uses lie on the same operation, the order follows the reverse operand
+/// number ordering.
+LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
+  // Early return for trivial use-lists.
+  if (value.use_empty() || value.hasOneUse())
+    return success();
+
+  bool hasIncomingOrder =
+      valueToUseListMap.contains(value.getAsOpaquePointer());
+
+  // Compute the current order of the use-list with respect to the global
+  // ordering. Detect if the order is already sorted while doing so.
+  bool alreadySorted = true;
+  auto &firstUse = *value.use_begin();
+  uint64_t prevID =
+      bytecode::getUseID(firstUse, operationIDs.at(firstUse.getOwner()));
+  llvm::SmallVector<std::pair<uint64_t, uint64_t>> currentOrder = {{0, prevID}};
+  for (auto item : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
+    uint64_t currentID = bytecode::getUseID(
+        item.value(), operationIDs.at(item.value().getOwner()));
+    alreadySorted &= prevID > currentID;
+    currentOrder.push_back({item.index(), currentID});
+    prevID = currentID;
+  }
+
+  // If the order is already sorted, and there wasn't a custom order to apply
+  // from the bytecode file, we are done.
+  if (alreadySorted && !hasIncomingOrder)
+    return success();
+
+  // If not already sorted, sort the indices of the current order by
+  // descending useID.
+  if (!alreadySorted)
+    std::sort(
+        currentOrder.begin(), currentOrder.end(),
+        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
+
+  if (!hasIncomingOrder) {
+    // If the bytecode file did not contain any custom use-list order, it
+    // means that the order was descending useID. Hence, shuffle by the first
+    // index of the `currentOrder` pair.
+    SmallVector<unsigned> shuffle = SmallVector<unsigned>(
+        llvm::map_range(currentOrder, [&](auto item) { return item.first; }));
+    value.shuffleUseList(shuffle);
+    return success();
+  }
+
+  // Pull the custom order info from the map.
+  UseListOrderStorage customOrder =
+      valueToUseListMap.at(value.getAsOpaquePointer());
+  SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
+  uint64_t numUses =
+      std::distance(value.getUses().begin(), value.getUses().end());
+
+  // If the encoding was a pair of indices `(src, dst)` for every permutation,
+  // reconstruct the shuffle vector for every use. Initialize the shuffle
+  // vector as the identity, and then apply the mapping encoded in the
+  // indices.
+  if (customOrder.isIndexPairEncoding) {
+    // Return failure if the number of indices cannot represent pairs.
+    if (shuffle.size() & 1)
+      return failure();
+
+    SmallVector<unsigned, 4> newShuffle(numUses);
+    size_t idx = 0;
+    std::iota(newShuffle.begin(), newShuffle.end(), idx);
+    for (idx = 0; idx < shuffle.size(); idx += 2)
+      newShuffle[shuffle[idx]] = shuffle[idx + 1];
+
+    shuffle = std::move(newShuffle);
+  }
+
+  // Make sure that the indices represent a valid mapping. That is, the sum of
+  // all the values needs to be equal to (numUses - 1) * numUses / 2, and no
+  // duplicates are allowed in the list. Since the elements are distinct
+  // non-negative integers, matching both the size and this minimal possible
+  // sum forces the set to be exactly {0, ..., numUses - 1}.
+  DenseSet<unsigned> set;
+  uint64_t accumulator = 0;
+  for (const auto &elem : shuffle) {
+    if (set.contains(elem))
+      return failure();
+    accumulator += elem;
+    set.insert(elem);
+  }
+  if (numUses != shuffle.size() ||
+      accumulator != (((numUses - 1) * numUses) >> 1))
+    return failure();
+
+  // Apply the current ordering map onto the shuffle vector to get the final
+  // use-list sorting indices before shuffling.
+  shuffle = SmallVector<unsigned, 4>(llvm::map_range(
+      currentOrder, [&](auto item) { return shuffle[item.first]; }));
+  value.shuffleUseList(shuffle);
+  return success();
+}
+
+LogicalResult BytecodeReader::Impl::processUseLists(Operation *topLevelOp) {
+  // Precompute operation IDs according to the pre-order walk of the IR. We
+  // can't do this while parsing since the parse order of regions is not
+  // strictly equal to the pre-order walk.
+  unsigned operationID = 0;
+  topLevelOp->walk(
+      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+
+  auto blockWalk = topLevelOp->walk([this](Block *block) {
+    for (auto arg : block->getArguments())
+      if (failed(sortUseListOrder(arg)))
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+
+  auto resultWalk = topLevelOp->walk([this](Operation *op) {
+    for (auto result : op->getResults())
+      if (failed(sortUseListOrder(result)))
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  });
+
+  return failure(blockWalk.wasInterrupted() || resultWalk.wasInterrupted());
+}
+
 //===----------------------------------------------------------------------===//
 // IR Section
 
@@ -1627,6 +1832,11 @@
         "not all forward unresolved forward operand references");
   }
 
+  // Sort use-lists according to the order specified in the bytecode.
+  if (failed(processUseLists(*moduleOp)))
+    return reader.emitError(
+        "parsed use-list orders were invalid and could not be applied");
+
   // Resolve dialect version.
   for (const BytecodeDialect &byteCodeDialect : dialects) {
     // Parsing is complete, give an opportunity to each dialect to visit the
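The validity check in `sortUseListOrder` above rests on a counting argument: `numUses` distinct non-negative integers sum to at least `numUses * (numUses - 1) / 2`, with equality exactly when they are `{0, ..., numUses - 1}`, so the duplicate, size, and sum checks together guarantee a valid permutation. A standalone sketch of the same validation (plain C++, hypothetical helper name, not the MLIR API):

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Returns true iff `shuffle` is a permutation of {0, ..., numUses - 1}.
bool isValidShuffle(const std::vector<uint64_t> &shuffle, uint64_t numUses) {
  std::unordered_set<uint64_t> seen;
  uint64_t accumulator = 0;
  for (uint64_t elem : shuffle) {
    if (!seen.insert(elem).second)
      return false; // duplicate index
    accumulator += elem;
  }
  // Distinct non-negative integers sum to at least numUses*(numUses-1)/2,
  // with equality only for {0, ..., numUses - 1}.
  return shuffle.size() == numUses &&
         accumulator == (numUses - 1) * numUses / 2;
}

int main() {
  assert(isValidShuffle({2, 0, 1}, 3));
  assert(!isValidShuffle({0, 0, 3}, 3)); // duplicate; sum alone would pass
  assert(!isValidShuffle({0, 1, 4}, 3)); // not a permutation: sum too large
}
```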
@@ -1812,6 +2022,17 @@
     }
   }
 
+  /// Parse the use-list orders for the results of the operation. Use-list
+  /// orders are available since version 3 of the bytecode.
+  std::optional<UseListMapT> resultIdxToUseListMap = std::nullopt;
+  if (version > 2 && (opMask & bytecode::OpEncodingMask::kHasUseListOrders)) {
+    size_t numResults = opState.types.size();
+    auto parseResult = parseUseListOrderForRange(reader, numResults);
+    if (failed(parseResult))
+      return failure();
+    resultIdxToUseListMap = std::move(*parseResult);
+  }
+
   /// Parse the regions of the operation.
   if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
     uint64_t numRegions;
@@ -1831,6 +2052,16 @@
   if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
     return failure();
 
+  /// Store an entry in the value-to-use-list map for every result that
+  /// received a custom use-list order from the bytecode file.
+  if (resultIdxToUseListMap.has_value()) {
+    for (size_t idx = 0; idx < op->getNumResults(); idx++) {
+      if (resultIdxToUseListMap->contains(idx)) {
+        valueToUseListMap.try_emplace(op->getResult(idx).getAsOpaquePointer(),
+                                      resultIdxToUseListMap->at(idx));
+      }
+    }
+  }
   return op;
 }
 
@@ -1880,6 +2111,28 @@
   if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
     return failure();
 
+  // Use-list orders are available since version 3 of the bytecode.
+  if (version < 3)
+    return success();
+
+  uint8_t hasUseListOrders = 0;
+  if (hasArgs && failed(reader.parseByte(hasUseListOrders)))
+    return failure();
+
+  if (!hasUseListOrders)
+    return success();
+
+  Block &blk = *readState.curBlock;
+  auto argIdxToUseListMap =
+      parseUseListOrderForRange(reader, blk.getNumArguments());
+  if (failed(argIdxToUseListMap) || argIdxToUseListMap->empty())
+    return failure();
+
+  for (size_t idx = 0; idx < blk.getNumArguments(); idx++)
+    if (argIdxToUseListMap->contains(idx))
+      valueToUseListMap.try_emplace(blk.getArgument(idx).getAsOpaquePointer(),
+                                    argIdxToUseListMap->at(idx));
+
   // We don't parse the operations of the block here, that's done elsewhere.
   return success();
 }
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -7,9 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Bytecode/BytecodeWriter.h"
-#include "../Encoding.h"
 #include "IRNumbering.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
+#include "mlir/Bytecode/Encoding.h"
 #include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/CachedHashString.h"
 #include "llvm/ADT/MapVector.h"
@@ -470,6 +470,12 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Helpers
+
+  void writeUseListOrders(EncodingEmitter &emitter, uint8_t &opEncodingMask,
+                          ValueRange range);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -667,6 +673,14 @@
       emitter.emitVarInt(numberingState.getNumber(arg.getType()));
       emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
     }
+    if (config.bytecodeVersion > 2) {
+      uint64_t maskOffset = emitter.size();
+      uint8_t encodingMask = 0;
+      emitter.emitByte(0);
+      writeUseListOrders(emitter, encodingMask, args);
+      if (encodingMask)
+        emitter.patchByte(maskOffset, encodingMask);
+    }
   }
 
   // Emit the operations within the block.
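The writer only knows whether any block argument needed a custom order after it has walked all of them, so it reserves a placeholder byte up front (`emitByte(0)`) and patches it afterwards (`patchByte(maskOffset, ...)`). A minimal model of this reserve-and-patch pattern over a plain byte vector (illustrative only, not the `EncodingEmitter` API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint8_t> bytes;
  size_t maskOffset = bytes.size(); // remember where the mask byte lives
  bytes.push_back(0);               // placeholder: mask is not known yet
  uint8_t mask = 0;

  bool emittedUseListOrders = true; // discovered while emitting the payload
  if (emittedUseListOrders) {
    mask |= 0b00100000;    // kHasUseListOrders, as defined in Encoding.h
    bytes.push_back(0x2A); // stand-in for the use-list payload bytes
  }
  // Patch the placeholder only if the mask became nonzero.
  if (mask)
    bytes[maskOffset] = mask;
  assert(bytes[0] == 0b00100000);
}
```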
@@ -718,6 +732,11 @@
     emitter.emitVarInt(numberingState.getNumber(successor));
   }
 
+  // Emit the use-list orders to bytecode, so we can reconstruct the same
+  // order when parsing.
+  if (config.bytecodeVersion > 2)
+    writeUseListOrders(emitter, opEncodingMask, ValueRange(op->getResults()));
+
   // Check for regions.
   unsigned numRegions = op->getNumRegions();
   if (numRegions)
@@ -749,6 +768,94 @@
   }
 }
 
+void BytecodeWriter::writeUseListOrders(EncodingEmitter &emitter,
+                                        uint8_t &opEncodingMask,
+                                        ValueRange range) {
+  // Loop over the results and store the use-list order per result index.
+  DenseMap<unsigned, llvm::SmallVector<unsigned>> map;
+  for (auto item : llvm::enumerate(range)) {
+    auto value = item.value();
+    // No need to store a custom use-list order if the result does not have
+    // multiple uses.
+    if (value.use_empty() || value.hasOneUse())
+      continue;
+
+    // For each result, assemble the list of pairs (use-list-index,
+    // global-value-index). While doing so, detect if the global-value-index
+    // is already ordered with respect to the use-list-index.
+    bool alreadyOrdered = true;
+    auto &firstUse = *value.use_begin();
+    uint64_t prevID = bytecode::getUseID(
+        firstUse, numberingState.getNumber(firstUse.getOwner()));
+    llvm::SmallVector<std::pair<unsigned, uint64_t>> useListPairs(
+        {{0, prevID}});
+
+    for (auto use : llvm::drop_begin(llvm::enumerate(value.getUses()))) {
+      uint64_t currentID = bytecode::getUseID(
+          use.value(), numberingState.getNumber(use.value().getOwner()));
+      // The use-list order achieved when building the IR at parse time
+      // always pushes new uses on the front. Hence, if the order by unique
+      // ID is monotonically decreasing, a roundtrip to bytecode preserves
+      // the order.
+      alreadyOrdered &= (prevID > currentID);
+      useListPairs.push_back({use.index(), currentID});
+      prevID = currentID;
+    }
+
+    // Do not emit if the order is already sorted.
+    if (alreadyOrdered)
+      continue;
+
+    // Sort the use indices by the unique use IDs in descending order.
+    std::sort(
+        useListPairs.begin(), useListPairs.end(),
+        [](auto elem1, auto elem2) { return elem1.second > elem2.second; });
+
+    map.try_emplace(item.index(), llvm::map_range(useListPairs, [](auto elem) {
+                      return elem.first;
+                    }));
+  }
+
+  if (map.empty())
+    return;
+
+  opEncodingMask |= bytecode::OpEncodingMask::kHasUseListOrders;
+  // Emit the number of results that have a custom use-list order if the
+  // number of results is greater than one.
+  if (range.size() != 1)
+    emitter.emitVarInt(map.size());
+
+  for (const auto &item : map) {
+    auto resultIdx = item.getFirst();
+    auto useListOrder = item.getSecond();
+
+    // Compute the number of uses that are actually shuffled. If fewer than
+    // half of the total uses are shuffled, encoding the index pairs
+    // `(src, dst)` is more space efficient.
+    size_t shuffledElements =
+        llvm::count_if(llvm::enumerate(useListOrder),
+                       [](auto item) { return item.index() != item.value(); });
+    bool indexPairEncoding = shuffledElements < (useListOrder.size() / 2);
+
+    // For a single result, we don't need to store the result index.
+    if (range.size() != 1)
+      emitter.emitVarInt(resultIdx);
+
+    if (indexPairEncoding) {
+      emitter.emitVarIntWithFlag(shuffledElements * 2, indexPairEncoding);
+      for (auto pair : llvm::enumerate(useListOrder)) {
+        if (pair.index() != pair.value()) {
+          emitter.emitVarInt(pair.value());
+          emitter.emitVarInt(pair.index());
+        }
+      }
+    } else {
+      emitter.emitVarIntWithFlag(useListOrder.size(), indexPairEncoding);
+      for (const auto &index : useListOrder)
+        emitter.emitVarInt(index);
+    }
+  }
+}
+
 void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
   // If the region is empty, we only need to emit the number of blocks (which
   // is zero).
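A quick worked example of the size tradeoff driving `indexPairEncoding`: with six uses of which only two are displaced, the direct encoding costs one varint per use, while the pair encoding costs two varints per displaced use. A standalone sketch (plain C++; the cost model deliberately ignores the shared varint-with-flag header, which both encodings pay):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical order over 6 uses where only entries 2 and 5 moved.
  std::vector<size_t> useListOrder = {0, 1, 5, 3, 4, 2};
  size_t shuffled = 0;
  for (size_t i = 0; i < useListOrder.size(); ++i)
    shuffled += (useListOrder[i] != i);

  size_t directCost = useListOrder.size(); // one varint per use
  size_t pairCost = 2 * shuffled;          // (src, dst) per displaced use
  bool useIndexPairs = shuffled < useListOrder.size() / 2;
  std::cout << "direct: " << directCost << " varints, pairs: " << pairCost
            << " varints, choose pairs: " << useIndexPairs << "\n";
  // Prints: direct: 6 varints, pairs: 4 varints, choose pairs: 1
}
```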
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -152,6 +152,10 @@
     assert(blockIDs.count(block) && "block not numbered");
     return blockIDs[block];
   }
+  unsigned getNumber(Operation *op) {
+    assert(operationIDs.count(op) && "operation not numbered");
+    return operationIDs[op];
+  }
   unsigned getNumber(OperationName opName) {
     assert(opNames.count(opName) && "opName not numbered");
     return opNames[opName]->number;
@@ -224,7 +228,8 @@
   llvm::SpecificBumpPtrAllocator<DialectResourceNumbering> resourceAllocator;
   llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
 
-  /// The value ID for each Block and Value.
+  /// The ID for each Operation, Block and Value.
+  DenseMap<Operation *, unsigned> operationIDs;
   DenseMap<Block *, unsigned> blockIDs;
   DenseMap<Value, unsigned> valueIDs;
 
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,9 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRNumbering.h"
-#include "../Encoding.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
@@ -109,6 +107,12 @@
 }
 
 IRNumberingState::IRNumberingState(Operation *op) {
+  // Compute a global operation ID numbering according to the pre-order walk
+  // of the IR. This is used as the reference to construct use-list orders.
+  unsigned operationID = 0;
+  op->walk(
+      [&](Operation *op) { operationIDs.try_emplace(op, operationID++); });
+
   // Number the root operation.
   number(*op);
 
diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp
--- a/mlir/lib/IR/Value.cpp
+++ b/mlir/lib/IR/Value.cpp
@@ -93,6 +93,11 @@
   });
 }
 
+/// Shuffles the use-list order according to the provided indices.
+void Value::shuffleUseList(ArrayRef<unsigned> indices) {
+  getImpl()->shuffleUseList(indices);
+}
+
 //===----------------------------------------------------------------------===//
 // OpResult
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
 //===--------------------------------------------------------------------===//
 
 // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 2
+// VERSION: bytecode version 127 is newer than the current version 3
 
 //===--------------------------------------------------------------------===//
 // Producer
diff --git a/mlir/test/Bytecode/uselist_orders.mlir b/mlir/test/Bytecode/uselist_orders.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/uselist_orders.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt %s -split-input-file --test-verify-uselistorder -verify-diagnostics
+
+// COM: --test-verify-uselistorder will randomly shuffle the use-list of every
+// value and do a roundtrip to bytecode. An error is returned if the use-list
+// orders are not preserved when doing a roundtrip to bytecode. The test needs
+// to verify diagnostics to be functional.
+
+func.func @base_test(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %6 = "test.addi"(%5, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %4) : (i32, i32) -> i32
+  return %7 : i32
+}
+
+// -----
+
+func.func @test_with_multiple_uses_in_same_op(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%2, %2) : (i32, i32) -> i32
+  %6 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %5) : (i32, i32) -> i32
+  %8 = "test.addi"(%7, %4) : (i32, i32) -> i32
+  %9 = "test.addi"(%8, %4) : (i32, i32) -> i32
+  return %9 : i32
+}
+
+// -----
+
+func.func @test_with_multiple_block_arg_uses(%arg0 : i32) -> i32 {
+  %0 = arith.constant 45 : i32
+  %1 = arith.constant 46 : i32
+  %2 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32
+  %3 = "test.addi"(%2, %arg0) : (i32, i32) -> i32
+  %4 = "test.addi"(%2, %1) : (i32, i32) -> i32
+  %5 = "test.addi"(%2, %2) : (i32, i32) -> i32
+  %6 = "test.addi"(%3, %4) : (i32, i32) -> i32
+  %7 = "test.addi"(%6, %5) : (i32, i32) -> i32
+  %8 = "test.addi"(%7, %4) : (i32, i32) -> i32
+  %9 = "test.addi"(%8, %4) : (i32, i32) -> i32
+  return %9 : i32
+}
+
+// -----
+
+// Test that use-lists in regions without dominance are preserved.
+test.graph_region {
+  %0 = "test.foo"(%1) : (i32) -> i32
+  test.graph_region attributes {a} {
+    %a = "test.a"(%b) : (i32) -> i32
+    %b = "test.b"(%2) : (i32) -> i32
+  }
+  %1 = "test.bar"(%2) : (i32) -> i32
+  %2 = "test.baz"() : () -> i32
+}
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -18,6 +18,7 @@
   TestSymbolUses.cpp
   TestRegions.cpp
   TestTypes.cpp
+  TestUseListOrders.cpp
   TestVisitors.cpp
   TestVisitorsGeneric.cpp
 
diff --git a/mlir/test/lib/IR/TestUseListOrders.cpp b/mlir/test/lib/IR/TestUseListOrders.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestUseListOrders.cpp
@@ -0,0 +1,219 @@
+//===- TestUseListOrders.cpp - Passes to test use-list orders -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/Bytecode/Encoding.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+
+#include <numeric>
+#include <random>
+
+using namespace mlir;
+
+namespace {
+/// This pass tests that:
+/// 1) we can shuffle use-lists correctly;
+/// 2) use-list orders are preserved after a roundtrip to bytecode.
+class TestPreserveUseListOrders
+    : public PassWrapper<TestPreserveUseListOrders, OperationPass<ModuleOp>> {
+public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreserveUseListOrders)
+
+  TestPreserveUseListOrders() = default;
+  TestPreserveUseListOrders(const TestPreserveUseListOrders &pass)
+      : PassWrapper(pass) {}
+  StringRef getArgument() const final { return "test-verify-uselistorder"; }
+  StringRef getDescription() const final {
+    return "Verify that roundtripping the IR to bytecode preserves the order "
+           "of the use-lists";
+  }
+  Option<unsigned> rngSeed{*this, "rng-seed",
+                           llvm::cl::desc("Specify an input random seed"),
+                           llvm::cl::init(1)};
+
+  void runOnOperation() override {
+    // Clone the module so that we can plug this pass into any other pipeline
+    // independently.
+    auto cloneModule = getOperation().clone();
+
+    // 1. Compute the op numbering of the module.
+    computeOpNumbering(cloneModule);
+
+    // 2. Loop over all the values and shuffle the uses. While doing so, check
+    // that each shuffle is correct.
+    if (failed(shuffleUses(cloneModule)))
+      return signalPassFailure();
+
+    // 3. Do a bytecode roundtrip to version 3, which supports use-list order
+    // preservation.
+    auto roundtripModuleOr = doRoundtripToBytecode(cloneModule, 3);
+    // If the bytecode roundtrip failed, try to roundtrip the original module
+    // to version 2, which does not support use-lists. If this also fails, the
+    // original module had an issue unrelated to use-lists.
+    if (failed(roundtripModuleOr)) {
+      auto testModuleOr = doRoundtripToBytecode(getOperation(), 2);
+      if (failed(testModuleOr))
+        return;
+
+      return signalPassFailure();
+    }
+
+    // 4. Recompute the op numbering on the new module. The numbering should
+    // be the same as in (1), but on the new operation pointers.
+    computeOpNumbering(roundtripModuleOr->get());
+
+    // 5. Loop over all the values and verify that the use-list is consistent
+    // with the post-shuffle order of step (2).
+    if (failed(verifyUseListOrders(roundtripModuleOr->get())))
+      return signalPassFailure();
+  }
+
+private:
+  FailureOr<OwningOpRef<Operation *>> doRoundtripToBytecode(Operation *module,
+                                                            uint32_t version) {
+    std::string str;
+    llvm::raw_string_ostream m(str);
+    BytecodeWriterConfig config;
+    config.setDesiredBytecodeVersion(version);
+    if (failed(writeBytecodeToFile(module, m, config)))
+      return failure();
+
+    ParserConfig parseConfig(&getContext(), /*verifyAfterParse=*/true);
+    auto newModuleOp = parseSourceString(StringRef(str), parseConfig);
+    if (!newModuleOp.get())
+      return failure();
+    return newModuleOp;
+  }
+
+  /// Compute an ordered numbering for all the operations in the IR.
+  void computeOpNumbering(Operation *topLevelOp) {
+    uint32_t operationID = 0;
+    opNumbering.clear();
+    topLevelOp->walk(
+        [&](Operation *op) { opNumbering.try_emplace(op, operationID++); });
+  }
+
+  template <class ValueT>
+  SmallVector<uint64_t> getUseIDs(ValueT val) {
+    return SmallVector<uint64_t>(
+        llvm::map_range(val.getUses(), [&](auto &use) {
+          return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+        }));
+  }
+
+  LogicalResult shuffleUses(Operation *topLevelOp) {
+    uint32_t valueID = 0;
+    /// Randomly permute the use-list of each value. It is guaranteed that at
+    /// least one pair of the use-list is permuted.
+    auto doShuffleForRange = [&](ValueRange range) -> LogicalResult {
+      for (auto val : range) {
+        if (val.use_empty() || val.hasOneUse())
+          continue;
+
+        /// Get a valid index permutation for the uses of the value.
+        SmallVector<unsigned> permutation = getRandomPermutation(val);
+
+        /// Store the original order, then verify that the shuffle was applied
+        /// correctly.
+        auto useIDs = getUseIDs(val);
+
+        /// Apply the shuffle to the use-list.
+        val.shuffleUseList(permutation);
+
+        /// Get the new order and verify the shuffle happened correctly.
+        auto permutedIDs = getUseIDs(val);
+        if (permutedIDs.size() != useIDs.size())
+          return failure();
+        for (size_t idx = 0; idx < permutation.size(); idx++)
+          if (useIDs[idx] != permutedIDs[permutation[idx]])
+            return failure();
+
+        referenceUseListOrder.try_emplace(
+            valueID++, llvm::map_range(val.getUses(), [&](auto &use) {
+              return bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+            }));
+      }
+      return success();
+    };
+
+    return walkOverValues(topLevelOp, doShuffleForRange);
+  }
+
+  LogicalResult verifyUseListOrders(Operation *topLevelOp) {
+    uint32_t valueID = 0;
+    /// Check that the use-list for the values in the range matches the one
+    /// stored in the reference.
+    auto doValidationForRange = [&](ValueRange range) -> LogicalResult {
+      for (auto val : range) {
+        if (val.use_empty() || val.hasOneUse())
+          continue;
+        auto referenceOrder = referenceUseListOrder.at(valueID++);
+        for (auto [use, referenceID] :
+             llvm::zip(val.getUses(), referenceOrder)) {
+          uint64_t uniqueID =
+              bytecode::getUseID(use, opNumbering.at(use.getOwner()));
+          if (uniqueID != referenceID) {
+            use.getOwner()->emitError()
+                << "found use-list order mismatch for value: " << val;
+            return failure();
+          }
+        }
+      }
+      return success();
+    };
+
+    return walkOverValues(topLevelOp, doValidationForRange);
+  }
+
+  /// Walk over the blocks and operations of the IR and execute the callable
+  /// over the ranges of block arguments and op results, respectively.
+  template <typename FuncT>
+  LogicalResult walkOverValues(Operation *topLevelOp, FuncT callable) {
+    auto blockWalk = topLevelOp->walk([&](Block *block) {
+      if (failed(callable(block->getArguments())))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+
+    if (blockWalk.wasInterrupted())
+      return failure();
+
+    auto resultsWalk = topLevelOp->walk([&](Operation *op) {
+      if (failed(callable(op->getResults())))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+
+    return failure(resultsWalk.wasInterrupted());
+  }
+
+  /// Create a random permutation of the use-list order chain of the provided
+  /// value.
+  SmallVector<unsigned> getRandomPermutation(Value value) {
+    size_t numUses = std::distance(value.use_begin(), value.use_end());
+    SmallVector<unsigned> permutation(numUses);
+    unsigned zero = 0;
+    std::iota(permutation.begin(), permutation.end(), zero);
+    auto rng = std::default_random_engine(rngSeed);
+    std::shuffle(permutation.begin(), permutation.end(), rng);
+    return permutation;
+  }
+
+  /// Map each value to its use-list order encoded with unique use IDs.
+  DenseMap<uint32_t, SmallVector<uint64_t>> referenceUseListOrder;
+
+  /// Map each operation to its global ID.
+  DenseMap<Operation *, uint32_t> opNumbering;
+};
+} // namespace
+
+namespace mlir {
+void registerTestPreserveUseListOrders() {
+  PassRegistration<TestPreserveUseListOrders>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -53,6 +53,7 @@
 void registerTestPrintDefUsePass();
 void registerTestPrintInvalidPass();
 void registerTestPrintNestingPass();
+void registerTestPreserveUseListOrders();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
 void registerTestSpirvModuleCombinerPass();
@@ -167,6 +168,7 @@
   registerTestPrintDefUsePass();
   registerTestPrintInvalidPass();
   registerTestPrintNestingPass();
+  registerTestPreserveUseListOrders();
   registerTestReducer();
   registerTestSpirvEntryPointABIPass();
   registerTestSpirvModuleCombinerPass();