diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h --- a/mlir/include/mlir/Bytecode/BytecodeReader.h +++ b/mlir/include/mlir/Bytecode/BytecodeReader.h @@ -18,6 +18,7 @@ namespace llvm { class MemoryBufferRef; +class SourceMgr; } // namespace llvm namespace mlir { @@ -29,6 +30,12 @@ /// bytecode, into the provided block. LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, const ParserConfig &config); +/// An overload with a source manager whose main file buffer is used for +/// parsing. The lifetime of the source manager may be freely extended during +/// parsing such that the source manager is not destroyed before the parsed IR. +LogicalResult +readBytecodeFile(const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config); } // namespace mlir #endif // MLIR_BYTECODE_BYTECODEREADER_H diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -215,17 +215,19 @@ /// Create a new unmanaged resource directly referencing the provided data. /// `dataIsMutable` indicates if the allocated data can be mutated. By /// default, we treat unmanaged blobs as immutable. 
- static AsmResourceBlob allocateWithAlign(ArrayRef data, size_t align, - bool dataIsMutable = false) { - return AsmResourceBlob(data, align, /*deleter=*/{}, - /*dataIsMutable=*/false); + static AsmResourceBlob + allocateWithAlign(ArrayRef data, size_t align, + AsmResourceBlob::DeleterFn deleter = {}, + bool dataIsMutable = false) { + return AsmResourceBlob(data, align, std::move(deleter), dataIsMutable); } template - static AsmResourceBlob allocateInferAlign(ArrayRef data, - bool dataIsMutable = false) { + static AsmResourceBlob + allocateInferAlign(ArrayRef data, AsmResourceBlob::DeleterFn deleter = {}, + bool dataIsMutable = false) { return allocateWithAlign( ArrayRef((const char *)data.data(), data.size() * sizeof(T)), - alignof(T)); + alignof(T), std::move(deleter), dataIsMutable); } }; diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -93,6 +93,14 @@ LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr); +/// An overload with a source manager that may have references taken during the +/// parsing process, and whose lifetime can be freely extended (such that the +/// source manager is not destroyed before the parsed IR). This is useful, for +/// example, to avoid copying some large resources into the MLIRContext and +/// instead referencing the data directly from the input buffers. +LogicalResult parseSourceFile(const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config, + LocationAttr *sourceFileLoc = nullptr); /// This parses the file specified by the indicated filename and appends parsed /// operations to the given block. 
If the block is non-empty, the operations are @@ -116,6 +124,15 @@ llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr); +/// An overload with a source manager that may have references taken during the +/// parsing process, and whose lifetime can be freely extended (such that the +/// source manager is not destroyed before the parsed IR). This is useful, for +/// example, to avoid copying some large resources into the MLIRContext and +/// instead referencing the data directly from the input buffers. +LogicalResult parseSourceFile(llvm::StringRef filename, + const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config, + LocationAttr *sourceFileLoc = nullptr); /// This parses the IR string and appends parsed operations to the given block. /// If the block is non-empty, the operations are placed before the current @@ -157,6 +174,17 @@ parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) { return detail::parseSourceFile(config, sourceMgr); } +/// An overload with a source manager that may have references taken during the +/// parsing process, and whose lifetime can be freely extended (such that the +/// source manager is not destroyed before the parsed IR). This is useful, for +/// example, to avoid copying some large resources into the MLIRContext and +/// instead referencing the data directly from the input buffers. +template +inline OwningOpRef +parseSourceFile(const std::shared_ptr &sourceMgr, + const ParserConfig &config) { + return detail::parseSourceFile(config, sourceMgr); +} /// This parses the file specified by the indicated filename. If the source IR /// contained a single instance of `ContainerOpT`, it is returned. 
Otherwise, a @@ -186,6 +214,18 @@ const ParserConfig &config) { return detail::parseSourceFile(config, filename, sourceMgr); } +/// An overload with a source manager that may have references taken during the +/// parsing process, and whose lifetime can be freely extended (such that the +/// source manager is not destroyed before the parsed IR). This is useful, for +/// example, to avoid copying some large resources into the MLIRContext and +/// instead referencing the data directly from the input buffers. +template +inline OwningOpRef +parseSourceFile(llvm::StringRef filename, + const std::shared_ptr &sourceMgr, + const ParserConfig &config) { + return detail::parseSourceFile(config, filename, sourceMgr); +} /// This parses the provided string containing MLIR. If the source IR contained /// a single instance of `ContainerOpT`, it is returned. Otherwise, a new diff --git a/mlir/include/mlir/Tools/ParseUtilities.h b/mlir/include/mlir/Tools/ParseUtilities.h --- a/mlir/include/mlir/Tools/ParseUtilities.h +++ b/mlir/include/mlir/Tools/ParseUtilities.h @@ -24,8 +24,8 @@ /// If 'insertImplicitModule' is true a top-level 'builtin.module' op will be /// inserted that contains the parsed IR, unless one exists already. inline OwningOpRef -parseSourceFileForTool(llvm::SourceMgr &sourceMgr, const ParserConfig &config, - bool insertImplicitModule) { +parseSourceFileForTool(const std::shared_ptr &sourceMgr, + const ParserConfig &config, bool insertImplicitModule) { if (insertImplicitModule) { // TODO: Move implicit module logic out of 'parseSourceFile' and into here. 
return parseSourceFile(sourceMgr, config); diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h --- a/mlir/include/mlir/Tools/mlir-translate/Translation.h +++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h @@ -25,7 +25,10 @@ /// should create a new MLIR Operation in the given context and return a /// pointer to it, or a nullptr in case of any error. using TranslateSourceMgrToMLIRFunction = std::function( - llvm::SourceMgr &sourceMgr, MLIRContext *)>; + const std::shared_ptr &sourceMgr, MLIRContext *)>; +using TranslateRawSourceMgrToMLIRFunction = + std::function(llvm::SourceMgr &sourceMgr, + MLIRContext *)>; /// Interface of the function that translates the given string to MLIR. The /// implementation should create a new MLIR Operation in the given context. If @@ -45,7 +48,8 @@ /// all MLIR constructs needed during the process inside the given context. This /// can be used for round-tripping external formats through the MLIR system. using TranslateFunction = std::function; + const std::shared_ptr &sourceMgr, + llvm::raw_ostream &output, MLIRContext *)>; /// This class contains all of the components necessary for performing a /// translation. @@ -64,7 +68,7 @@ Optional getInputAlignment() const { return inputAlignment; } /// Invoke the translation function with the given input and output streams. 
- LogicalResult operator()(llvm::SourceMgr &sourceMgr, + LogicalResult operator()(const std::shared_ptr &sourceMgr, llvm::raw_ostream &output, MLIRContext *context) const { return function(sourceMgr, output, context); @@ -101,6 +105,10 @@ llvm::StringRef name, llvm::StringRef description, const TranslateSourceMgrToMLIRFunction &function, Optional inputAlignment = std::nullopt); + TranslateToMLIRRegistration( + llvm::StringRef name, llvm::StringRef description, + const TranslateRawSourceMgrToMLIRFunction &function, + Optional inputAlignment = std::nullopt); TranslateToMLIRRegistration( llvm::StringRef name, llvm::StringRef description, const TranslateStringRefToMLIRFunction &function, diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -23,6 +23,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" +#include "llvm/Support/SourceMgr.h" #include #define DEBUG_TYPE "mlir-bytecode-reader" @@ -492,11 +493,12 @@ class ResourceSectionReader { public: /// Initialize the resource section reader with the given section data. - LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, - StringSectionReader &stringReader, - ArrayRef sectionData, - ArrayRef offsetSectionData); + LogicalResult + initialize(Location fileLoc, const ParserConfig &config, + MutableArrayRef dialects, + StringSectionReader &stringReader, ArrayRef sectionData, + ArrayRef offsetSectionData, + const std::shared_ptr &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. 
LogicalResult parseResourceHandle(EncodingReader &reader, @@ -512,8 +514,10 @@ class ParsedResourceEntry : public AsmParsedResourceEntry { public: ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, - EncodingReader &reader, StringSectionReader &stringReader) - : key(key), kind(kind), reader(reader), stringReader(stringReader) {} + EncodingReader &reader, StringSectionReader &stringReader, + const std::shared_ptr &bufferOwnerRef) + : key(key), kind(kind), reader(reader), stringReader(stringReader), + bufferOwnerRef(bufferOwnerRef) {} ~ParsedResourceEntry() override = default; StringRef getKey() const final { return key; } @@ -554,11 +558,22 @@ if (failed(reader.parseBlobAndAlignment(data, alignment))) return failure(); + // If we have an extendable reference to the buffer owner, we don't need to + // allocate a new buffer for the data, and can use the data directly. + if (bufferOwnerRef) { + ArrayRef charData(reinterpret_cast(data.data()), + data.size()); + + // Allocate an unmanaged buffer whose deleter captures a copy of the owner. + // For now we just mark this as immutable, but in the future we should + // explore marking this as mutable when desired. + return UnmanagedAsmResourceBlob::allocateWithAlign( + charData, alignment, + [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); + } + // Allocate memory for the blob using the provided allocator and copy the // data into it. - // FIXME: If the current holder of the bytecode can ensure its lifetime - // (e.g. when mmap'd), we should not copy the data. We should use the data - // from the bytecode directly. 
AsmResourceBlob blob = allocator(data.size(), alignment); assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && @@ -572,6 +587,7 @@ AsmResourceEntryKind kind; EncodingReader &reader; StringSectionReader &stringReader; + const std::shared_ptr &bufferOwnerRef; }; } // namespace @@ -580,6 +596,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty, EncodingReader &offsetReader, EncodingReader &resourceReader, StringSectionReader &stringReader, T *handler, + const std::shared_ptr &bufferOwnerRef, function_ref processKeyFn = {}) { uint64_t numResources; if (failed(offsetReader.parseVarInt(numResources))) @@ -611,7 +628,8 @@ // Otherwise, parse the resource value. EncodingReader entryReader(data, fileLoc); - ParsedResourceEntry entry(key, kind, entryReader, stringReader); + ParsedResourceEntry entry(key, kind, entryReader, stringReader, + bufferOwnerRef); if (failed(handler->parseResource(entry))) return failure(); if (!entryReader.empty()) { @@ -622,12 +640,12 @@ return success(); } -LogicalResult -ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, - StringSectionReader &stringReader, - ArrayRef sectionData, - ArrayRef offsetSectionData) { +LogicalResult ResourceSectionReader::initialize( + Location fileLoc, const ParserConfig &config, + MutableArrayRef dialects, + StringSectionReader &stringReader, ArrayRef sectionData, + ArrayRef offsetSectionData, + const std::shared_ptr &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -641,7 +659,7 @@ auto parseGroup = [&](auto *handler, bool allowEmpty = false, function_ref keyFn = {}) { return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, - stringReader, handler, keyFn); + stringReader, handler, bufferOwnerRef, keyFn); }; // Read the external resources from the bytecode. 
@@ -1058,14 +1076,16 @@ /// This class is used to read a bytecode buffer and translate it into MLIR. class BytecodeReader { public: - BytecodeReader(Location fileLoc, const ParserConfig &config) + BytecodeReader(Location fileLoc, const ParserConfig &config, + const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), "builtin.unrealized_conversion_cast", ValueRange(), - NoneType::get(config.getContext())) {} + NoneType::get(config.getContext())), + bufferOwnerRef(bufferOwnerRef) {} /// Read the bytecode defined within `buffer` into the given block. LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); @@ -1222,6 +1242,10 @@ Block openForwardRefOps; /// An operation state used when instantiating forward references. OperationState forwardRefOpState; + + /// The optional owning source manager (held by reference: it must outlive + /// this reader), which when non-null may extend the input buffer lifetime. + const std::shared_ptr &bufferOwnerRef; }; } // namespace @@ -1383,7 +1407,8 @@ // Initialize the resource reader with the resource sections. return resourceReader.initialize(fileLoc, config, dialects, stringReader, - *resourceData, *resourceOffsetData); + *resourceData, *resourceOffsetData, + bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1719,8 +1744,13 @@ return buffer.getBuffer().startswith("ML\xefR"); } -LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, - const ParserConfig &config) { +/// Read the bytecode from the provided memory buffer reference. 
+/// `bufferOwnerRef`, if non-null, is the owning source manager for the buffer, +/// and may be used to extend the lifetime of the buffer. 
+static LogicalResult +readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config, + const std::shared_ptr &bufferOwnerRef) { Location sourceFileLoc = FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), /*line=*/0, /*column=*/0); @@ -1729,6 +1759,18 @@ "input buffer is not an MLIR bytecode file"); } - BytecodeReader reader(sourceFileLoc, config); + BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); return reader.read(buffer, block); } + +LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config) { + return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); +} +LogicalResult +mlir::readBytecodeFile(const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config) { + return readBytecodeFileImpl( + *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, + sourceMgr); +} diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -129,8 +129,8 @@ return nullptr; } - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc()); OwningOpRef module = parseSourceFileForTool(sourceMgr, context, insertImplicitModule); if (!module) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -30,30 +30,60 @@ return readBytecodeFile(*sourceBuf, block, config); return parseAsmSourceFile(sourceMgr, block, config); } +LogicalResult +mlir::parseSourceFile(const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config, + LocationAttr *sourceFileLoc) { + const auto *sourceBuf = + sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); + if (sourceFileLoc) { + *sourceFileLoc = 
FileLineColLoc::get(config.getContext(), + sourceBuf->getBufferIdentifier(), + /*line=*/0, /*column=*/0); + } + if (isBytecode(*sourceBuf)) + return readBytecodeFile(sourceMgr, block, config); + return parseAsmSourceFile(*sourceMgr, block, config); +} LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { - llvm::SourceMgr sourceMgr; + auto sourceMgr = std::make_shared(); return parseSourceFile(filename, sourceMgr, block, config, sourceFileLoc); } -LogicalResult mlir::parseSourceFile(llvm::StringRef filename, - llvm::SourceMgr &sourceMgr, Block *block, - const ParserConfig &config, - LocationAttr *sourceFileLoc) { +static LogicalResult loadSourceFileBuffer(llvm::StringRef filename, + llvm::SourceMgr &sourceMgr, + MLIRContext *ctx) { if (sourceMgr.getNumBuffers() != 0) { // TODO: Extend to support multiple buffers. - return emitError(mlir::UnknownLoc::get(config.getContext()), + return emitError(mlir::UnknownLoc::get(ctx), "only main buffer parsed at the moment"); } auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code error = fileOrErr.getError()) - return emitError(mlir::UnknownLoc::get(config.getContext()), + return emitError(mlir::UnknownLoc::get(ctx), "could not open input file " + filename); // Load the MLIR source file. 
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc()); + return success(); +} + +LogicalResult mlir::parseSourceFile(llvm::StringRef filename, + llvm::SourceMgr &sourceMgr, Block *block, + const ParserConfig &config, + LocationAttr *sourceFileLoc) { + if (failed(loadSourceFileBuffer(filename, sourceMgr, config.getContext()))) + return failure(); + return parseSourceFile(sourceMgr, block, config, sourceFileLoc); +} +LogicalResult mlir::parseSourceFile( + llvm::StringRef filename, const std::shared_ptr &sourceMgr, + Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { + if (failed(loadSourceFileBuffer(filename, *sourceMgr, config.getContext()))) + return failure(); return parseSourceFile(sourceMgr, block, config, sourceFileLoc); } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -46,11 +46,11 @@ /// This typically parses the main source file, runs zero or more optimization /// passes, then prints the output. /// -static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, - bool verifyPasses, SourceMgr &sourceMgr, - MLIRContext *context, - PassPipelineFn passManagerSetupFn, - bool emitBytecode, bool implicitModule) { +static LogicalResult +performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, + const std::shared_ptr &sourceMgr, + MLIRContext *context, PassPipelineFn passManagerSetupFn, + bool emitBytecode, bool implicitModule) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -115,8 +115,8 @@ PassPipelineFn passManagerSetupFn, DialectRegistry &registry, llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. 
- SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation // since we'll inject the thread-pool separately. @@ -135,13 +135,13 @@ // If we are in verify diagnostics mode then we have a lot of work to do, // otherwise just perform the actions without worrying about it. if (!verifyDiagnostics) { - SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, passManagerSetupFn, emitBytecode, implicitModule); } - SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); + SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -41,8 +41,8 @@ return nullptr; } - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc()); return parseSourceFileForTool(sourceMgr, &context, insertImplictModule); } diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -87,18 +87,18 @@ MLIRContext context; context.allowUnregisteredDialects(allowUnregisteredDialects); context.printOpOnDiagnostic(!verifyDiagnostics); - llvm::SourceMgr 
sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); if (!verifyDiagnostics) { - SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); + SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context); return (*translationRequested)(sourceMgr, os, &context); } // In the diagnostic verification flow, we ignore whether the translation // failed (in most cases, it is expected to fail). Instead, we check if the // diagnostics were produced as expected. - SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); + SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context); (void)(*translationRequested)(sourceMgr, os, &context); return sourceMgrHandler.verify(); }; diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp --- a/mlir/lib/Tools/mlir-translate/Translation.cpp +++ b/mlir/lib/Tools/mlir-translate/Translation.cpp @@ -75,8 +75,8 @@ static void registerTranslateToMLIRFunction( StringRef name, StringRef description, Optional inputAlignment, const TranslateSourceMgrToMLIRFunction &function) { - auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, - MLIRContext *context) { + auto wrappedFn = [function](const std::shared_ptr &sourceMgr, + raw_ostream &output, MLIRContext *context) { OwningOpRef op = function(sourceMgr, context); if (!op || failed(verify(*op))) return failure(); @@ -92,6 +92,15 @@ Optional inputAlignment) { registerTranslateToMLIRFunction(name, description, inputAlignment, function); } +TranslateToMLIRRegistration::TranslateToMLIRRegistration( + StringRef name, StringRef description, + const TranslateRawSourceMgrToMLIRFunction &function, + Optional inputAlignment) { + registerTranslateToMLIRFunction( + name, description, inputAlignment, + [function](const std::shared_ptr &sourceMgr, + MLIRContext *ctx) { 
return function(*sourceMgr, ctx); }); +} /// Wraps `function` with a lambda that extracts a StringRef from a source /// manager and registers the wrapper lambda as a to-MLIR conversion. TranslateToMLIRRegistration::TranslateToMLIRRegistration( @@ -100,9 +109,10 @@ Optional inputAlignment) { registerTranslateToMLIRFunction( name, description, inputAlignment, - [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { + [function](const std::shared_ptr &sourceMgr, + MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = - sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); + sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); return function(buffer->getBuffer(), ctx); }); } @@ -117,9 +127,9 @@ const std::function &dialectRegistration) { registerTranslation( name, description, /*inputAlignment=*/std::nullopt, - [function, dialectRegistration](llvm::SourceMgr &sourceMgr, - raw_ostream &output, - MLIRContext *context) { + [function, + dialectRegistration](const std::shared_ptr &sourceMgr, + raw_ostream &output, MLIRContext *context) { DialectRegistry registry; dialectRegistration(registry); context->appendDialectRegistry(registry);