diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -314,6 +314,12 @@ The IR section contains the encoded form of operations within the bytecode. +``` +ir_section { + block: block; // Single block without arguments. +} +``` + #### Operation Encoding ``` @@ -334,7 +340,9 @@ successors: varint[], regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove) - regions: region[] + + // regions are stored in a section if isIsolatedFromAbove + regions: (region | region_section)[] } ``` diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h --- a/mlir/include/mlir/Bytecode/BytecodeReader.h +++ b/mlir/include/mlir/Bytecode/BytecodeReader.h @@ -15,6 +15,9 @@ #include "mlir/IR/AsmState.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include +#include namespace llvm { class MemoryBufferRef; @@ -22,6 +25,46 @@ } // namespace llvm namespace mlir { + +/// The BytecodeReader allows loading MLIR bytecode files, while keeping the +/// state explicitly available in order to support lazy loading. +class BytecodeReader { +public: + /// Create a bytecode reader for the given buffer. If `lazyLoad` is true, + /// isolated regions aren't loaded eagerly. + explicit BytecodeReader( + llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoad, + const std::shared_ptr &bufferOwnerRef = {}); + ~BytecodeReader(); + + /// Read the operations defined within the given memory buffer, containing + /// MLIR bytecode, into the provided block. If the reader was created with + /// `lazyLoad` enabled, isolated regions aren't loaded eagerly. + LogicalResult readTopLevel(Block *block); + + /// If the reader was created with `lazyLoad` enabled, this function returns a + /// callback that loads the isolated region for the given operation. A nullptr /// is returned if the operation doesn't have an isolated region to load.
+ std::function getLazyOpMaterializer(Operation *op); + + /// Return the number of ops that haven't been materialized yet. + int64_t getNumOpsToMaterialize() const; + + /// Return the next operation to materialize, or nullptr if none. + Operation *getNextMaterializableOp() const; + + /// Materialize the given operation; this is a no-op if nothing is pending. + LogicalResult materialize(Operation *op); + + /// Materialize all pending operations, including ones discovered on the way. + LogicalResult materializeAll(); + + class Impl; + +private: + std::unique_ptr impl; +}; + /// Returns true if the given buffer starts with the magic bytes that signal /// MLIR bytecode. bool isBytecode(llvm::MemoryBufferRef buffer); @@ -36,6 +79,7 @@ LogicalResult readBytecodeFile(const std::shared_ptr &sourceMgr, Block *block, const ParserConfig &config); + } // namespace mlir #endif // MLIR_BYTECODE_BYTECODEREADER_H diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -27,7 +27,7 @@ kMinSupportedVersion = 0, /// The current bytecode version. - kVersion = 1, + kVersion = 2, /// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB, diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -17,6 +17,9 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" @@ -24,6 +27,8 @@ #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" +#include +#include #include #define DEBUG_TYPE "mlir-bytecode-reader" @@ -1084,23 +1089,77 @@ // Bytecode Reader //===----------------------------------------------------------------------===// -namespace { /// This class is used to read a bytecode buffer and translate it into MLIR. -class BytecodeReader { +class mlir::BytecodeReader::Impl { + struct RegionReadState; + using LazyLoadableOpsInfo = + std::list>; + using LazyLoadableOpsMap = + DenseMap; + public: - BytecodeReader(Location fileLoc, const ParserConfig &config, - const std::shared_ptr &bufferOwnerRef) - : config(config), fileLoc(fileLoc), + Impl(Location fileLoc, const ParserConfig &config, bool lazyLoading, + llvm::MemoryBufferRef buffer, + const std::shared_ptr &bufferOwnerRef) + : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), "builtin.unrealized_conversion_cast", ValueRange(), NoneType::get(config.getContext())), - bufferOwnerRef(bufferOwnerRef) {} + buffer(buffer), bufferOwnerRef(bufferOwnerRef) {} /// Read the bytecode defined within `buffer` into the given block.
- LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); + LogicalResult read(Block *block); + + /// Return a callback that materializes the deferred regions of `op`, or + /// nullptr if `op` has nothing left to materialize. + std::function<LogicalResult(void)> getLazyOpMaterializer(Operation *op) { + if (!lazyLoadableOpsMap.count(op)) + return nullptr; + // Capture `op` rather than a map iterator: the iterator would be + // invalidated when materializing other ops inserts into the DenseMap. + return [this, op]() { return materialize(lazyLoadableOpsMap.find(op)); }; + } + + /// Load the deferred regions of the op referenced by `it`, including nested + /// non-isolated regions (nested isolated regions are deferred again). Does + /// nothing when `it` is the end iterator, i.e. the op was already loaded. + LogicalResult materialize(LazyLoadableOpsMap::iterator it) { + if (it == lazyLoadableOpsMap.end()) + return success(); + valueScopes.emplace_back(); + std::vector<RegionReadState> regionStack; + regionStack.push_back(std::move(it->getSecond()->second)); + lazyLoadableOps.erase(it->getSecond()); + lazyLoadableOpsMap.erase(it); + // parseRegions returns early whenever it pushes a nested region onto the + // stack, so keep going until the stack has been fully drained. + while (!regionStack.empty()) + if (failed(parseRegions(regionStack, regionStack.back()))) + return failure(); + return success(); + } + + /// Return the number of ops that haven't been materialized yet. + int64_t getNumOpsToMaterialize() const { return lazyLoadableOpsMap.size(); } + + /// Return the next operation to materialize, or nullptr if none. + Operation *getNextMaterializableOp() const { + if (lazyLoadableOps.empty()) + return nullptr; + return lazyLoadableOps.front().first; + } + + /// Materialize all operations. + LogicalResult materializeAll() { + while (!lazyLoadableOpsMap.empty()) { + if (failed(materialize(lazyLoadableOpsMap.begin()))) + return failure(); + } + return success(); + } private: /// Return the context for this config. @@ -1143,14 +1202,22 @@ /// This struct represents the current read state of a range of regions. This /// struct is used to enable iterative parsing of regions.
struct RegionReadState { - RegionReadState(Operation *op, bool isIsolatedFromAbove) - : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} - RegionReadState(MutableArrayRef regions, bool isIsolatedFromAbove) - : curRegion(regions.begin()), endRegion(regions.end()), + RegionReadState(Operation *op, EncodingReader *reader, + bool isIsolatedFromAbove) + : RegionReadState(op->getRegions(), reader, isIsolatedFromAbove) {} + RegionReadState(MutableArrayRef regions, EncodingReader *reader, + bool isIsolatedFromAbove) + : curRegion(regions.begin()), endRegion(regions.end()), reader(reader), isIsolatedFromAbove(isIsolatedFromAbove) {} /// The current regions being read. MutableArrayRef::iterator curRegion, endRegion; + /// This is the reader to use for this region, this pointer is pointing to + /// the parent region reader unless the current region is IsolatedFromAbove, + /// in which case the pointer is pointing to the `owningReader` which is a + /// section dedicated to the current region. + EncodingReader *reader; + std::unique_ptr owningReader; /// The number of values defined immediately within this region.
unsigned numValues = 0; @@ -1168,15 +1235,15 @@ }; LogicalResult parseIRSection(ArrayRef sectionData, Block *block); - LogicalResult parseRegions(EncodingReader &reader, - std::vector ®ionStack, + LogicalResult parseRegions(std::vector ®ionStack, RegionReadState &readState); FailureOr parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove); - LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); - LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); + LogicalResult parseRegion(RegionReadState &readState); + LogicalResult parseBlockHeader(EncodingReader &reader, + RegionReadState &readState); LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); //===--------------------------------------------------------------------===// @@ -1226,6 +1293,15 @@ /// A location to use when emitting errors. Location fileLoc; + /// Flag that indicates if lazy loading is enabled. + bool lazyLoading; + + /// Keep track of operations that have been lazy loaded (their regions haven't + /// been materialized), along with the `RegionReadState` used to lazy-load + /// the regions nested under the operation. + LazyLoadableOpsInfo lazyLoadableOps; + LazyLoadableOpsMap lazyLoadableOpsMap; + /// The reader used to process attribute and types within the bytecode. AttrTypeReader attrTypeReader; @@ -1256,13 +1332,15 @@ /// An operation state used when instantiating forward references. OperationState forwardRefOpState; + /// Reference to the input buffer. + llvm::MemoryBufferRef buffer; + /// The optional owning source manager, which when present may be used to /// extend the lifetime of the input buffer.
- const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; + std::shared_ptr<llvm::SourceMgr> bufferOwnerRef; }; -} // namespace -LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { +LogicalResult BytecodeReader::Impl::read(Block *block) { EncodingReader reader(buffer.getBuffer(), fileLoc); // Skip over the bytecode header, this should have already been checked. @@ -1294,7 +1372,7 @@ // Check for duplicate sections, we only expect one instance of each. if (sectionDatas[sectionID]) { return reader.emitError("duplicate top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } sectionDatas[sectionID] = sectionData; } @@ -1303,7 +1381,7 @@ bytecode::Section::ID sectionID = static_cast(i); if (!sectionDatas[i] && !isSectionOptional(sectionID)) { return reader.emitError("missing data for top-level section: ", - toString(sectionID)); + ::toString(sectionID)); } } @@ -1332,7 +1410,7 @@ return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); } -LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { +LogicalResult BytecodeReader::Impl::parseVersion(EncodingReader &reader) { if (failed(reader.parseVarInt(version))) return failure(); @@ -1349,6 +1427,9 @@ " is newer than the current version ", currentVersion); } + // Override any request to lazy-load if the bytecode version is too old. + if (version < 2) + lazyLoading = false; return success(); } @@ -1386,7 +1467,7 @@ } LogicalResult -BytecodeReader::parseDialectSection(ArrayRef sectionData) { +BytecodeReader::Impl::parseDialectSection(ArrayRef sectionData) { EncodingReader sectionReader(sectionData, fileLoc); // Parse the number of dialects in the section.
@@ -1439,7 +1520,8 @@ return success(); } -FailureOr BytecodeReader::parseOpName(EncodingReader &reader) { +FailureOr +BytecodeReader::Impl::parseOpName(EncodingReader &reader) { BytecodeOperationName *opName = nullptr; if (failed(parseEntry(reader, opNames, opName, "operation name"))) return failure(); @@ -1462,7 +1544,7 @@ //===----------------------------------------------------------------------===// // Resource Section -LogicalResult BytecodeReader::parseResourceSection( +LogicalResult BytecodeReader::Impl::parseResourceSection( EncodingReader &reader, std::optional> resourceData, std::optional> resourceOffsetData) { // Ensure both sections are either present or not. @@ -1490,8 +1572,9 @@ //===----------------------------------------------------------------------===// // IR Section -LogicalResult BytecodeReader::parseIRSection(ArrayRef sectionData, - Block *block) { +LogicalResult +BytecodeReader::Impl::parseIRSection(ArrayRef sectionData, + Block *block) { EncodingReader reader(sectionData, fileLoc); // A stack of operation regions currently being read from the bytecode. @@ -1499,17 +1582,17 @@ // Parse the top-level block using a temporary module operation. OwningOpRef moduleOp = ModuleOp::create(fileLoc); - regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); + regionStack.emplace_back(*moduleOp, &reader, /*isIsolatedFromAbove=*/true); regionStack.back().curBlocks.push_back(moduleOp->getBody()); regionStack.back().curBlock = regionStack.back().curRegion->begin(); - if (failed(parseBlock(reader, regionStack.back()))) + if (failed(parseBlockHeader(reader, regionStack.back()))) return failure(); valueScopes.emplace_back(); valueScopes.back().push(regionStack.back()); // Iteratively parse regions until everything has been resolved.
while (!regionStack.empty()) - if (failed(parseRegions(reader, regionStack, regionStack.back()))) + if (failed(parseRegions(regionStack, regionStack.back()))) return failure(); if (!forwardRefOps.empty()) { return reader.emitError( @@ -1540,15 +1623,19 @@ } LogicalResult -BytecodeReader::parseRegions(EncodingReader &reader, - std::vector ®ionStack, - RegionReadState &readState) { +BytecodeReader::Impl::parseRegions(std::vector ®ionStack, + RegionReadState &readState) { // Read the regions of this operation. + // Process regions, blocks, and operations until the end or if a nested + // region is encountered. In this case we push a new state in regionStack and + // return, the processing of the current region will resume afterward. for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { // If the current block hasn't been setup yet, parse the header for this - // region. + // region. The current block is already setup when this function was + // interrupted to recurse down in a nested region and we resume the current + // block after processing the nested region. if (readState.curBlock == Region::iterator()) { - if (failed(parseRegion(reader, readState))) + if (failed(parseRegion(readState))) return failure(); // If the region is empty, there is nothing to more to do. @@ -1557,6 +1644,7 @@ } // Parse the blocks within the region. + EncodingReader &reader = *readState.reader; do { while (readState.numOpsRemaining--) { // Read in the next operation. We don't read its regions directly, we @@ -1567,9 +1655,35 @@ if (failed(op)) return failure(); - // If the op has regions, add it to the stack for processing. + // If the op has regions, add it to the stack for processing and return: + // we stop the processing of the current region and resume it after the + // inner one is completed. Unless lazy loading is enabled and the op is + // isolated from above, in which case parsing its regions is delayed.
if ((*op)->getNumRegions()) { - regionStack.emplace_back(*op, isIsolatedFromAbove); + RegionReadState childState(*op, &reader, isIsolatedFromAbove); + // Isolated regions are encoded as a section in version 2 and above. + if (version >= 2 && isIsolatedFromAbove) { + bytecode::Section::ID sectionID; + ArrayRef sectionData; + if (failed(reader.parseSection(sectionID, sectionData))) + return failure(); + if (sectionID != bytecode::Section::kIR) { + emitError(fileLoc, "expected IR section for region"); + return failure(); + } + childState.owningReader = + std::make_unique(sectionData, fileLoc); + childState.reader = childState.owningReader.get(); + } + + if (lazyLoading && isIsolatedFromAbove) { + lazyLoadableOps.push_back( + std::make_pair(*op, std::move(childState))); + lazyLoadableOpsMap.try_emplace(*op, + std::prev(lazyLoadableOps.end())); + continue; + } + regionStack.push_back(std::move(childState)); // If the op is isolated from above, push a new value scope. if (isIsolatedFromAbove) @@ -1581,7 +1695,7 @@ // Move to the next block of the region. if (++readState.curBlock == readState.curRegion->end()) break; - if (failed(parseBlock(reader, readState))) + if (failed(parseBlockHeader(reader, readState))) return failure(); } while (true); @@ -1592,16 +1706,19 @@ // When the regions have been fully parsed, pop them off of the read stack. If // the regions were isolated from above, we also pop the last value scope. - if (readState.isIsolatedFromAbove) + if (readState.isIsolatedFromAbove) { + assert(!valueScopes.empty()); valueScopes.pop_back(); + } + assert(!regionStack.empty()); regionStack.pop_back(); return success(); } FailureOr -BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, - RegionReadState &readState, - bool &isIsolatedFromAbove) { +BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, + RegionReadState &readState, + bool &isIsolatedFromAbove) { // Parse the name of the operation.
FailureOr opName = parseOpName(reader); if (failed(opName)) @@ -1687,8 +1794,9 @@ return op; } -LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult BytecodeReader::Impl::parseRegion(RegionReadState &readState) { + EncodingReader &reader = *readState.reader; // Own section if isolated. + // Parse the number of blocks in the region. uint64_t numBlocks; if (failed(reader.parseVarInt(numBlocks))) @@ -1718,11 +1826,12 @@ // Parse the entry block of the region. readState.curBlock = readState.curRegion->begin(); - return parseBlock(reader, readState); + return parseBlockHeader(reader, readState); } -LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, - RegionReadState &readState) { +LogicalResult +BytecodeReader::Impl::parseBlockHeader(EncodingReader &reader, + RegionReadState &readState) { bool hasArgs; if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) return failure(); @@ -1735,8 +1844,8 @@ return success(); } -LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, - Block *block) { +LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader, + Block *block) { // Parse the value ID for the first argument, and the number of arguments.
uint64_t numArgs; if (failed(reader.parseVarInt(numArgs))) @@ -1764,7 +1873,7 @@ //===----------------------------------------------------------------------===// // Value Processing -Value BytecodeReader::parseOperand(EncodingReader &reader) { +Value BytecodeReader::Impl::parseOperand(EncodingReader &reader) { std::vector &values = valueScopes.back().values; Value *value = nullptr; if (failed(parseEntry(reader, values, value, "value"))) @@ -1776,8 +1885,8 @@ return *value; } -LogicalResult BytecodeReader::defineValues(EncodingReader &reader, - ValueRange newValues) { +LogicalResult BytecodeReader::Impl::defineValues(EncodingReader &reader, + ValueRange newValues) { ValueScope &valueScope = valueScopes.back(); std::vector &values = valueScope.values; @@ -1812,7 +1921,7 @@ return success(); } -Value BytecodeReader::createForwardRef() { +Value BytecodeReader::Impl::createForwardRef() { // Check for an avaliable existing operation to use. Otherwise, create a new // fake operation to use for the reference.
if (!openForwardRefOps.empty()) { @@ -1828,6 +1937,44 @@ // Entry Points //===----------------------------------------------------------------------===// +BytecodeReader::~BytecodeReader() = default; + +LogicalResult BytecodeReader::readTopLevel(Block *block) { + return impl->read(block); +} +BytecodeReader::BytecodeReader( + llvm::MemoryBufferRef buffer, const ParserConfig &config, bool lazyLoading, + const std::shared_ptr &bufferOwnerRef) { + Location sourceFileLoc = + FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), + /*line=*/0, /*column=*/0); + impl = std::make_unique(sourceFileLoc, config, lazyLoading, buffer, + bufferOwnerRef); +} +std::function +BytecodeReader::getLazyOpMaterializer(Operation *op) { + return impl->getLazyOpMaterializer(op); +} + +int64_t BytecodeReader::getNumOpsToMaterialize() const { + return impl->getNumOpsToMaterialize(); +} + +Operation *BytecodeReader::getNextMaterializableOp() const { + return impl->getNextMaterializableOp(); +} + +LogicalResult BytecodeReader::materialize(Operation *op) { + auto materializer = getLazyOpMaterializer(op); + if (!materializer) + return success(); // Nothing is pending for `op`. + return materializer(); +} + +LogicalResult BytecodeReader::materializeAll() { + return impl->materializeAll(); +} + bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { return buffer.getBuffer().startswith("ML\xefR"); } @@ -1847,8 +1994,9 @@ "input buffer is not an MLIR bytecode file"); } - BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); - return reader.read(buffer, block); + BytecodeReader::Impl reader(sourceFileLoc, config, /*lazyLoading=*/false, + buffer, bufferOwnerRef); + return reader.read(block); } LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -734,8 +734,18 @@ bool
isIsolatedFromAbove = op->hasTrait(); emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove); - for (Region ®ion : op->getRegions()) - writeRegion(emitter, ®ion); + for (Region ®ion : op->getRegions()) { + // If the region is not isolated from above, or we are emitting bytecode + // targeting version <2, we don't use a section. + if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { + writeRegion(emitter, ®ion); + continue; + } + + EncodingEmitter regionEmitter; + writeRegion(regionEmitter, ®ion); + emitter.emitSection(bytecode::Section::kIR, std::move(regionEmitter)); + } } } diff --git a/mlir/test/Bytecode/bytecode-lazy-loading.mlir b/mlir/test/Bytecode/bytecode-lazy-loading.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode-lazy-loading.mlir @@ -0,0 +1,59 @@ +// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading)" %s -o %t | FileCheck %s +// RUN: mlir-opt --pass-pipeline="builtin.module(test-lazy-loading{bytecode-version=1})" %s -o %t | FileCheck %s --check-prefix=OLD-BYTECODE + + +func.func @op_with_passthrough_region_args() { + %0 = arith.constant 10 : index + test.isolated_region %0 { + "test.consumer"(%0) : (index) -> () + } + %result:2 = "test.op"() : () -> (index, index) + test.isolated_region %result#1 { + "test.consumer"(%result#1) : (index) -> () + } + return +} + +// Before version 2, we can't support lazy loading. +// OLD-BYTECODE-NOT: Has 1 ops to materialize +// OLD-BYTECODE-NOT: Materializing +// OLD-BYTECODE: Has 0 ops to materialize + + +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing... +// CHECK: "builtin.module"() ({ +// CHECK-NOT: func +// CHECK: Materializing... +// CHECK: "builtin.module"() ({ +// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK-NOT: arith +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing...
+// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK-NOT: arith +// CHECK: Materializing... +// CHECK: "func.func"() <{function_type = () -> (), sym_name = "op_with_passthrough_region_args"}> ({ +// CHECK: arith +// CHECK: isolated_region +// CHECK-NOT: test.consumer +// CHECK: Has 2 ops to materialize + +// CHECK: Before Materializing... +// CHECK: test.isolated_region +// CHECK-NOT: test.consumer +// CHECK: Materializing... +// CHECK: test.isolated_region +// CHECK: ^bb0(%arg0: index): +// CHECK: test.consumer +// CHECK: Has 1 ops to materialize + +// CHECK: Before Materializing... +// CHECK: test.isolated_region +// CHECK-NOT: test.consumer +// CHECK: Materializing... +// CHECK: test.isolated_region +// CHECK: test.consumer +// CHECK: Has 0 ops to materialize diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 1 +// VERSION: bytecode version 127 is newer than the current version 2 //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -7,6 +7,7 @@ TestFunc.cpp TestInterfaces.cpp TestMatchers.cpp + TestLazyLoading.cpp TestOpaqueLoc.cpp TestOperationEquals.cpp TestPrintDefUse.cpp diff --git a/mlir/test/lib/IR/TestLazyLoading.cpp b/mlir/test/lib/IR/TestLazyLoading.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestLazyLoading.cpp @@ -0,0 +1,85 @@ +//===- TestLazyLoading.cpp - Pass to test operation lazy loading ---------===//
+// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +namespace { + +/// Test pass round-tripping the current op through lazily-loaded bytecode. +struct LazyLoadingPass : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LazyLoadingPass) + + StringRef getArgument() const final { return "test-lazy-loading"; } + StringRef getDescription() const final { return "Test LazyLoading of op"; } + LazyLoadingPass() = default; + LazyLoadingPass(const LazyLoadingPass &) {} + + void runOnOperation() override { + Operation *op = getOperation(); + std::string bytecode; + { + BytecodeWriterConfig config; + if (version >= 0) + config.setDesiredBytecodeVersion(version); + llvm::raw_string_ostream os(bytecode); + BytecodeWriterResult result = writeBytecodeToFile(op, os, config); + if (version >= 0 && result.minVersion != version) { + op->emitError() << "Failed to write bytecode at version " + << (int)version << ", emitted minimum version " + << result.minVersion << " instead."; + signalPassFailure(); + return; + } + } + llvm::MemoryBufferRef buffer(bytecode, "test-lazy-loading"); + Block block; + ParserConfig config(op->getContext(), /*verifyAfterParse=*/false); + BytecodeReader reader(buffer, config, + /*lazyLoad=*/true); + if (failed(reader.readTopLevel(&block))) { + op->emitError() << "Failed to read bytecode"; + return signalPassFailure(); + } +
+ llvm::outs() << "Has " << reader.getNumOpsToMaterialize() + << " ops to materialize\n"; + + // Materialize the pending ops one by one, printing each before and after. + while (Operation *lazyOp = reader.getNextMaterializableOp()) { + llvm::outs() << "\n\nBefore Materializing...\n\n"; + lazyOp->print(llvm::outs()); + llvm::outs() << "\n\nMaterializing...\n\n"; + if (failed(reader.materialize(lazyOp))) { + lazyOp->emitError() << "Failed to materialize"; + signalPassFailure(); + return; + } + lazyOp->print(llvm::outs()); + llvm::outs() << "\n"; + llvm::outs() << "Has " << reader.getNumOpsToMaterialize() + << " ops to materialize\n"; + } + } + Option version{*this, "bytecode-version", + llvm::cl::desc("Specifies the bytecode version to use."), + llvm::cl::init(-1)}; +}; +} // namespace + +namespace mlir { +void registerLazyLoadingTestPasses() { PassRegistration(); } +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -31,6 +31,7 @@ namespace mlir { void registerConvertToTargetEnvPass(); void registerCloneTestPasses(); +void registerLazyLoadingTestPasses(); void registerPassManagerTestPass(); void registerPrintSpirvAvailabilityPass(); void registerLoopLikeInterfaceTestPasses(); @@ -146,6 +147,7 @@ registerConvertToTargetEnvPass(); registerPassManagerTestPass(); registerPrintSpirvAvailabilityPass(); + registerLazyLoadingTestPasses(); registerLoopLikeInterfaceTestPasses(); registerShapeFunctionTestPasses(); registerSideEffectTestPasses();