diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -314,6 +314,12 @@ The IR section contains the encoded form of operations within the bytecode. +``` +ir_section { + block: block; // Single block without arguments. +} +``` + #### Operation Encoding ``` @@ -334,7 +340,9 @@ successors: varint[], regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove) - regions: region[] + + // regions are stored in a section if isIsolatedFromAbove + regions: (region | region_section)[] } ``` diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h --- a/mlir/include/mlir/Bytecode/BytecodeReader.h +++ b/mlir/include/mlir/Bytecode/BytecodeReader.h @@ -15,6 +15,9 @@ #include "mlir/IR/AsmState.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include +#include namespace llvm { class MemoryBufferRef; @@ -22,6 +25,59 @@ } // namespace llvm namespace mlir { + +/// The BytecodeReader allows loading MLIR bytecode files, while keeping the +/// state explicitly available in order to support lazy loading. +/// The `finalize` method must be called before destruction. +class BytecodeReader { +public: + /// Create a bytecode reader for the given buffer. If `lazyLoad` is true, + /// isolated regions aren't loaded eagerly. + explicit BytecodeReader( + llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, + const std::shared_ptr &bufferOwnerRef = {}); + ~BytecodeReader(); + + /// Read the operations defined within the given memory buffer, containing + /// MLIR bytecode, into the provided block. If the reader was created with + /// `lazyLoad` enabled, isolated regions aren't loaded eagerly. + /// The lazyOps callback is invoked for every op that can be lazy-loaded. + /// This lets the client decide if the op should be materialized + /// immediately or delayed.
+ LogicalResult readTopLevel( + Block *block, llvm::function_ref lazyOps = + [](Operation *) { return false; }); + + /// Return the number of ops that haven't been materialized yet. + int64_t getNumOpsToMaterialize() const; + + /// Return true if the provided op is materializable. + bool isMaterializable(Operation *op); + + /// Materialize the provided operation. The provided operation must be + /// materializable. + /// The lazyOps callback is invoked for every op that can be lazy-loaded. + /// This lets the client decide if the op should be materialized immediately + /// or delayed. + /// !! Using this materialize within an IR walk() can be confusing: make sure + /// to use a PreOrder traversal !! + LogicalResult materialize( + Operation *op, llvm::function_ref lazyOpsCallback = + [](Operation *) { return false; }); + + /// Finalize the lazy-loading by calling back with every op that hasn't been + /// materialized to let the client decide if the op should be deleted or + /// materialized. The op is materialized if the callback returns true, deleted + /// otherwise. The implementation of the callback must be thread-safe. + LogicalResult finalize(function_ref shouldMaterialize = + [](Operation *) { return true; }); + + class Impl; + +private: + std::unique_ptr impl; +}; + /// Returns true if the given buffer starts with the magic bytes that signal /// MLIR bytecode. bool isBytecode(llvm::MemoryBufferRef buffer); @@ -36,6 +92,7 @@ LogicalResult readBytecodeFile(const std::shared_ptr &sourceMgr, Block *block, const ParserConfig &config); + } // namespace mlir #endif // MLIR_BYTECODE_BYTECODEREADER_H diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -27,7 +27,7 @@ kMinSupportedVersion = 0, /// The current bytecode version. - kVersion = 1, + kVersion = 2, /// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB, diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -17,6 +17,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -24,6 +27,8 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include +#include #include #define DEBUG_TYPE "mlir-bytecode-reader" @@ -1092,25 +1097,93 @@ // Bytecode Reader //===----------------------------------------------------------------------===// -namespace { /// This class is used to read a bytecode buffer and translate it into MLIR. -class BytecodeReader { +class mlir::BytecodeReader::Impl { + struct RegionReadState; + using LazyLoadableOpsInfo = + std::list>; + using LazyLoadableOpsMap = + DenseMap; + public: - BytecodeReader(Location fileLoc, const ParserConfig &config, - const std::shared_ptr &bufferOwnerRef) - : config(config), fileLoc(fileLoc), + Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, + llvm::MemoryBufferRef buffer, + const std::shared_ptr &bufferOwnerRef) + : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), "builtin.unrealized_conversion_cast", ValueRange(), NoneType::get(config.getContext())), - bufferOwnerRef(bufferOwnerRef) {} + buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} /// Read the bytecode defined within `buffer` into the given block. 
- LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); + LogicalResult read(Block *block, + llvm::function_ref lazyOps); + + /// Return the number of ops that haven't been materialized yet. + int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } + + bool isMaterializable(Operation *op) { return lazyLoadableOpsMap.count(op); } + + /// Materialize the provided operation, invoke the lazyOpsCallback on every + /// newly found lazy operation. + LogicalResult + materialize(Operation *op, + llvm::function_ref lazyOpsCallback) { + this->lazyOpsCallback = lazyOpsCallback; + auto resetlazyOpsCallback = + llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); + auto it = lazyLoadableOpsMap.find(op); + assert(it != lazyLoadableOpsMap.end() && + "materialize called on non-materializable op"); + return materialize(it); + } + + /// Materialize all operations. + LogicalResult materializeAll() { + while (!lazyLoadableOpsMap.empty()) { + if (failed(materialize(lazyLoadableOpsMap.begin()))) + return failure(); + } + return success(); + } + + /// Finalize the lazy-loading by calling back with every op that hasn't been + /// materialized to let the client decide if the op should be deleted or + /// materialized. The op is materialized if the callback returns true, deleted + /// otherwise. 
+ LogicalResult finalize(function_ref shouldMaterialize) { + while (!lazyLoadableOps.empty()) { + Operation *op = lazyLoadableOps.begin()->first; + if (shouldMaterialize(op)) { + if (failed(materialize(lazyLoadableOpsMap.find(op)))) + return failure(); + continue; + } + op->dropAllReferences(); + op->erase(); + lazyLoadableOps.pop_front(); + lazyLoadableOpsMap.erase(op); + } + return success(); + } private: + LogicalResult materialize(LazyLoadableOpsMap::iterator it) { + assert(it != lazyLoadableOpsMap.end() && + "materialize called on non-materializable op"); + valueScopes.emplace_back(); + std::vector regionStack; + regionStack.push_back(std::move(it->getSecond()->second)); + lazyLoadableOps.erase(it->getSecond()); + lazyLoadableOpsMap.erase(it); + auto result = parseRegions(regionStack, regionStack.back()); + assert(regionStack.empty()); + return result; + } + /// Return the context for this config. MLIRContext *getContext() const { return config.getContext(); } @@ -1151,14 +1224,22 @@ /// This struct represents the current read state of a range of regions. This /// struct is used to enable iterative parsing of regions. struct RegionReadState { - RegionReadState(Operation *op, bool isIsolatedFromAbove) - : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} - RegionReadState(MutableArrayRef regions, bool isIsolatedFromAbove) - : curRegion(regions.begin()), endRegion(regions.end()), + RegionReadState(Operation *op, EncodingReader *reader, + bool isIsolatedFromAbove) + : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} + RegionReadState(MutableArrayRef regions, EncodingReader *reader, + bool isIsolatedFromAbove) + : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), isIsolatedFromAbove(isIsolatedFromAbove) {} /// The current regions being read. 
MutableArrayRef::iterator curRegion, endRegion; + /// This is the reader to use for this region, this pointer is pointing to + /// the parent region reader unless the current region is IsolatedFromAbove, + /// in which case the pointer is pointing to the `owningReader` which is a + /// section dedicated to the current region. + EncodingReader *reader; + std::unique_ptr owningReader; /// The number of values defined immediately within this region. unsigned numValues = 0; @@ -1176,15 +1257,15 @@ }; LogicalResult parseIRSection(ArrayRef sectionData, Block *block); - LogicalResult parseRegions(EncodingReader &reader, - std::vector &regionStack, + LogicalResult parseRegions(std::vector &regionStack, RegionReadState &readState); FailureOr parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove); - LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); - LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); + LogicalResult parseRegion(RegionReadState &readState); + LogicalResult parseBlockHeader(EncodingReader &reader, + RegionReadState &readState); LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); //===--------------------------------------------------------------------===// @@ -1234,6 +1315,16 @@ /// A location to use when emitting errors. Location fileLoc; + /// Flag that indicates if lazyloading is enabled. + bool lazyLoading; + + /// Keep track of operations that have been lazy loaded (their regions haven't + /// been materialized), along with the `RegionReadState` that allows to + /// lazy-load the regions nested under the operation. + LazyLoadableOpsInfo lazyLoadableOps; + LazyLoadableOpsMap lazyLoadableOpsMap; + llvm::function_ref lazyOpsCallback; + /// The reader used to process attribute and types within the bytecode. AttrTypeReader attrTypeReader; @@ -1264,14 +1355,20 @@ /// An operation state used when instantiating forward references.
OperationState forwardRefOpState; + /// Reference to the input buffer. + llvm::MemoryBufferRef buffer; + /// The optional owning source manager, which when present may be used to /// extend the lifetime of the input buffer. const std::shared_ptr &bufferOwnerRef; }; -} // namespace -LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { +LogicalResult BytecodeReader::Impl::read( + Block *block, llvm::function_ref lazyOpsCallback) { EncodingReader reader(buffer.getBuffer(), fileLoc); + this->lazyOpsCallback = lazyOpsCallback; + auto resetlazyOpsCallback = + llvm::make_scope_exit([&] { this->lazyOpsCallback = nullptr; }); // Skip over the bytecode header, this should have already been checked. if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) @@ -1302,7 +1399,7 @@ // Check for duplicate sections, we only expect one instance of each. if (sectionDatas[sectionID]) { return reader.emitError("duplicate top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } sectionDatas[sectionID] = sectionData; } @@ -1311,7 +1408,7 @@ bytecode::Section::ID sectionID = static_cast(i); if (!sectionDatas[i] && !isSectionOptional(sectionID)) { return reader.emitError("missing data for top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } } @@ -1340,7 +1437,7 @@ return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); } -LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { +LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { if (failed(reader.parseVarInt(version))) return failure(); @@ -1357,6 +1454,9 @@ " is newer than the current version ", currentVersion); } + // Override any request to lazy-load if the bytecode version is too old. 
+ if (version < 2) + lazyLoading = false; return success(); } @@ -1396,7 +1496,7 @@ } LogicalResult -BytecodeReader::parseDialectSection(ArrayRef sectionData) { +BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { EncodingReader sectionReader(sectionData, fileLoc); // Parse the number of dialects in the section. @@ -1449,7 +1549,8 @@ return success(); } -FailureOr BytecodeReader::parseOpName(EncodingReader &reader) { +FailureOr +BytecodeReader::Impl::parseOpName(EncodingReader &reader) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); @@ -1471,7 +1572,7 @@ //===----------------------------------------------------------------------===// // Resource Section -LogicalResult BytecodeReader::parseResourceSection( +LogicalResult BytecodeReader::Impl::parseResourceSection( EncodingReader &reader, std::optional> resourceData, std::optional> resourceOffsetData) { // Ensure both sections are either present or not. @@ -1499,8 +1600,9 @@ //===----------------------------------------------------------------------===// // IR Section -LogicalResult BytecodeReader::parseIRSection(ArrayRef sectionData, - Block *block) { +LogicalResult +BytecodeReader::Impl::parseIRSection(ArrayRef sectionData, + Block *block) { EncodingReader reader(sectionData, fileLoc); // A stack of operation regions currently being read from the bytecode. @@ -1508,17 +1610,17 @@ // Parse the top-level block using a temporary module operation. 
OwningOpRef moduleOp = ModuleOp::create(fileLoc); - regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); + regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); regionStack.back().curBlocks.push_back(moduleOp->getBody()); regionStack.back().curBlock = regionStack.back().curRegion->begin(); - if (failed(parseBlock(reader, regionStack.back()))) + if (failed(parseBlockHeader(reader, regionStack.back()))) return failure(); valueScopes.emplace_back(); valueScopes.back().push(regionStack.back()); // Iteratively parse regions until everything has been resolved. while (!regionStack.empty()) - if (failed(parseRegions(reader, regionStack, regionStack.back()))) + if (failed(parseRegions(regionStack, regionStack.back()))) return failure(); if (!forwardRefOps.empty()) { return reader.emitError( @@ -1549,15 +1651,18 @@ } LogicalResult -BytecodeReader::parseRegions(EncodingReader &reader, - std::vector &regionStack, - RegionReadState &readState) { - // Read the regions of this operation. +BytecodeReader::Impl::parseRegions(std::vector &regionStack, + RegionReadState &readState) { + // Process regions, blocks, and operations until the end or if a nested + // region is encountered. In this case we push a new state in regionStack and + // return, the processing of the current region will resume afterward. for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { // If the current block hasn't been setup yet, parse the header for this - // region. + // region. The current block is already setup when this function was + // interrupted to recurse down in a nested region and we resume the current + // block after processing the nested region. if (readState.curBlock == Region::iterator()) { - if (failed(parseRegion(reader, readState))) + if (failed(parseRegion(readState))) return failure(); // If the region is empty, there is nothing to more to do. @@ -1566,6 +1671,7 @@ } // Parse the blocks within the region.
+ EncodingReader &reader = *readState.reader; do { while (readState.numOpsRemaining--) { // Read in the next operation. We don't read its regions directly, we @@ -1576,9 +1682,38 @@ if (failed(op)) return failure(); - // If the op has regions, add it to the stack for processing. + // If the op has regions, add it to the stack for processing and return: + // we stop the processing of the current region and resume it after the + // inner one is completed. Unless LazyLoading is activated in which case + // nested region parsing is delayed. if ((*op)->getNumRegions()) { - regionStack.emplace_back(*op, isIsolatedFromAbove); + RegionReadState childState(*op, &reader, isIsolatedFromAbove); + + // Isolated regions are encoded as a section in version 2 and above. + if (version >= 2 && isIsolatedFromAbove) { + bytecode::Section::ID sectionID; + ArrayRef sectionData; + if (failed(reader.parseSection(sectionID, sectionData))) + return failure(); + if (sectionID != bytecode::Section::kIR) + return emitError(fileLoc, "expected IR section for region"); + childState.owningReader = + std::make_unique(sectionData, fileLoc); + childState.reader = childState.owningReader.get(); + } + + if (lazyLoading) { + // If the user has a callback set, they have the opportunity + // to control lazyloading as we go. + if (!lazyOpsCallback || !lazyOpsCallback(*op)) { + lazyLoadableOps.push_back( + std::make_pair(*op, std::move(childState))); + lazyLoadableOpsMap.try_emplace(*op, + std::prev(lazyLoadableOps.end())); + continue; + } + } + regionStack.push_back(std::move(childState)); // If the op is isolated from above, push a new value scope. if (isIsolatedFromAbove) @@ -1590,7 +1725,7 @@ // Move to the next block of the region. 
if (++readState.curBlock == readState.curRegion->end()) break; - if (failed(parseBlock(reader, readState))) + if (failed(parseBlockHeader(reader, readState))) return failure(); } while (true); @@ -1601,16 +1736,19 @@ // When the regions have been fully parsed, pop them off of the read stack. If // the regions were isolated from above, we also pop the last value scope. - if (readState.isIsolatedFromAbove) + if (readState.isIsolatedFromAbove) { + assert(!valueScopes.empty() && "Expect a valueScope after reading region"); valueScopes.pop_back(); + } + assert(!regionStack.empty() && "Expect a regionStack after reading region"); regionStack.pop_back(); return success(); } FailureOr -BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, - RegionReadState &readState, - bool &isIsolatedFromAbove) { +BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, + RegionReadState &readState, + bool &isIsolatedFromAbove) { // Parse the name of the operation. FailureOr opName = parseOpName(reader); if (failed(opName)) @@ -1696,8 +1834,9 @@ return op; } -LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { + EncodingReader &reader = *readState.reader; + // Parse the number of blocks in the region. uint64_t numBlocks; if (failed(reader.parseVarInt(numBlocks))) @@ -1727,11 +1866,12 @@ // Parse the entry block of the region. 
readState.curBlock = readState.curRegion->begin(); - return parseBlock(reader, readState); + return parseBlockHeader(reader, readState); } -LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult +BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, + RegionReadState &readState) { bool hasArgs; if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) return failure(); @@ -1744,8 +1884,8 @@ return success(); } -LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, - Block *block) { +LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, + Block *block) { // Parse the value ID for the first argument, and the number of arguments. uint64_t numArgs; if (failed(reader.parseVarInt(numArgs))) @@ -1773,7 +1913,7 @@ //===----------------------------------------------------------------------===// // Value Processing -Value BytecodeReader::parseOperand(EncodingReader &reader) { +Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { std::vector &values = valueScopes.back().values; Value *value = nullptr; if (failed(parseEntry(reader, values, value, "value"))) @@ -1785,8 +1925,8 @@ return *value; } -LogicalResult BytecodeReader::defineValues(EncodingReader &reader, - ValueRange newValues) { +LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, + ValueRange newValues) { ValueScope &valueScope = valueScopes.back(); std::vector &values = valueScope.values; @@ -1821,7 +1961,7 @@ return success(); } -Value BytecodeReader::createForwardRef() { +Value BytecodeReader::Impl::createForwardRef() { // Check for an avaliable existing operation to use. Otherwise, create a new // fake operation to use for the reference. 
if (!openForwardRefOps.empty()) { @@ -1837,6 +1977,41 @@ // Entry Points //===----------------------------------------------------------------------===// +BytecodeReader::~BytecodeReader() { assert(getNumOpsToMaterialize() == 0); } + +BytecodeReader::BytecodeReader( + llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, + const std::shared_ptr &bufferOwnerRef) { + Location sourceFileLoc = + FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), + /*line=*/0, /*column=*/0); + impl = std::make_unique(sourceFileLoc, config, lazyLoading, buffer, + bufferOwnerRef); +} + +LogicalResult BytecodeReader::readTopLevel( + Block *block, llvm::function_ref lazyOpsCallback) { + return impl->read(block, lazyOpsCallback); +} + +int64_t BytecodeReader::getNumOpsToMaterialize() const { + return impl->getNumOpsToMaterialize(); +} + +bool BytecodeReader::isMaterializable(Operation *op) { + return impl->isMaterializable(op); +} + +LogicalResult BytecodeReader::materialize( + Operation *op, llvm::function_ref lazyOpsCallback) { + return impl->materialize(op, lazyOpsCallback); +} + +LogicalResult +BytecodeReader::finalize(function_ref shouldMaterialize) { + return impl->finalize(shouldMaterialize); +} + bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { return buffer.getBuffer().startswith("ML\xefR"); } @@ -1856,8 +2031,9 @@ "input buffer is not an MLIR bytecode file"); } - BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); - return reader.read(buffer, block); + BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, + buffer, bufferOwnerRef); + return reader.read(block, /*lazyOpsCallback=*/nullptr); } LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -734,8 +734,18 @@ bool 
isIsolatedFromAbove = op->hasTrait(); emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove); - for (Region &region : op->getRegions()) - writeRegion(emitter, &region); + for (Region &region : op->getRegions()) { + // If the region is not isolated from above, or we are emitting bytecode + // targeting version <2, we don't use a section. + if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { + writeRegion(emitter, &region); + continue; + } + + EncodingEmitter regionEmitter; + writeRegion(regionEmitter, &region); + emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); + } } } diff --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir @@ -0,0 +1,59 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading)" %s -o %t | FileCheck %s +// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading{bytecode-version=1})" %s -o %t | FileCheck %s --check-prefix=OLD-BYTECODE + + +func.func @op_with_passthrough_region_args() { + %0 = arith.constant 10 : index + test.isolated_region %0 { + "test.consumer"(%0) : (index) -> () + } + %result:2 = "test.op"() : () -> (index, index) + test.isolated_region %result#1 { + "test.consumer"(%result#1) : (index) -> () + } + return +} + +// Before version 2, we can't support lazy loading. +// OLD-BYTECODE-NOT: Has 1 ops to materialize +// OLD-BYTECODE-NOT: Materializing +// OLD-BYTECODE: Has 0 ops to materialize + + +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing... +// CHECK: "builtin.module"() ({ +// CHECK-NOT: func +// CHECK: Materializing... +// CHECK: "builtin.module"() ({ +// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK-NOT: arith +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK-NOT: arith +// CHECK: Materializing... +// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK: arith +// CHECK: isolated_region +// CHECK-NOT: test.consumer +// CHECK: Has 2 ops to materialize + +// CHECK: Before Materializing... +// CHECK: test.isolated_region +// CHECK-NOT: test.consumer +// CHECK: Materializing... +// CHECK: test.isolated_region +// CHECK: ^bb0(%arg0: index): +// CHECK: test.consumer +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing... +// CHECK: test.isolated_region +// CHECK-NOT: test.consumer +// CHECK: Materializing... +// CHECK: test.isolated_region +// CHECK: test.consumer +// CHECK: Has 0 ops to materialize diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 1 +// VERSION: bytecode version 127 is newer than the current version 2 //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -7,6 +7,7 @@ TestFunc.cpp TestInterfaces.cpp TestMatchers.cpp + TestLazyLoading.cpp TestOpaqueLoc.cpp TestOperationEquals.cpp TestPrintDefUse.cpp diff --git a/mlir/test/lib/IR/TestLazyLoading.cpp b/mlir/test/lib/IR/TestLazyLoading.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestLazyLoading.cpp @@ -0,0 +1,93 @@ +//===- 
TestLazyLoading.cpp - Pass to test operation lazy loading ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; + +namespace { + +/// This is a test pass which LazyLoads the current operation recursively. +struct LazyLoadingPass : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LazyLoadingPass) + + StringRef getArgument() const final { return "test-lazy-loading"; } + StringRef getDescription() const final { return "Test LazyLoading of op"; } + LazyLoadingPass() = default; + LazyLoadingPass(const LazyLoadingPass &) {} + + void runOnOperation() override { + Operation *op = getOperation(); + std::string bytecode; + { + BytecodeWriterConfig config; + if (version >= 0) + config.setDesiredBytecodeVersion(version); + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, config))) { + op->emitError() << "failed to write bytecode at version " + << (int)version; + signalPassFailure(); + return; + } + } + llvm::MemoryBufferRef buffer(bytecode, "test-lazy-loading"); + Block block; + ParserConfig config(op->getContext(), /*verifyAfterParse=*/false); + BytecodeReader reader(buffer, config, + /*lazyLoad=*/true); + std::list toLoadOps; + if (failed(reader.readTopLevel(&block, [&](Operation *op) { + toLoadOps.push_back(op); + return false; + }))) { + op->emitError() << "failed to read bytecode"; + return; + } + + llvm::outs() << "Has " << 
reader.getNumOpsToMaterialize() + << " ops to materialize\n"; + + // Recursively print the operations, before and after lazy loading. + while (!toLoadOps.empty()) { + Operation *toLoad = toLoadOps.front(); + toLoadOps.pop_front(); + llvm::outs() << "\n\nBefore Materializing...\n\n"; + toLoad->print(llvm::outs()); + llvm::outs() << "\n\nMaterializing...\n\n"; + if (failed(reader.materialize(toLoad, [&](Operation *op) { + toLoadOps.push_back(op); + return false; + }))) { + toLoad->emitError() << "failed to materialize"; + signalPassFailure(); + return; + } + toLoad->print(llvm::outs()); + llvm::outs() << "\n"; + llvm::outs() << "Has " << reader.getNumOpsToMaterialize() + << " ops to materialize\n"; + } + } + Option version{*this, "bytecode-version", + llvm::cl::desc("Specifies the bytecode version to use."), + llvm::cl::init(-1)}; +}; +} // namespace + +namespace mlir { +void registerLazyLoadingTestPasses() { PassRegistration(); } +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -31,6 +31,7 @@ namespace mlir { void registerConvertToTargetEnvPass(); void registerCloneTestPasses(); +void registerLazyLoadingTestPasses(); void registerPassManagerTestPass(); void registerPrintSpirvAvailabilityPass(); void registerLoopLikeInterfaceTestPasses(); @@ -146,6 +147,7 @@ registerConvertToTargetEnvPass(); registerPassManagerTestPass(); registerPrintSpirvAvailabilityPass(); + registerLazyLoadingTestPasses(); registerLoopLikeInterfaceTestPasses(); registerShapeFunctionTestPasses(); registerSideEffectTestPasses();