diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -49,6 +49,34 @@ /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// + // Types and Attributes encoding + //===--------------------------------------------------------------------===// + + /// Retrieve the map from dialect name to attribute/type printer callback. + llvm::StringMap> & + getAttrTypePrinterCallbacks() const; + + /// Attach a custom type/attribute bytecode printer callback for a dialect. + /// If one is already attached for `dialectName`, the existing one is kept. + void + attachAttrTypeCallback(StringRef dialectName, + std::unique_ptr callback); + + /// Overload of the above that wraps a callable (e.g. a lambda) into an + /// `AsmAttrTypeBytecodePrinter` through its `fromCallable` helper. 
+ template + std::enable_if_t< + std::is_convertible_v> && + std::is_convertible_v< + CallableT, + std::function>> + attachAttrTypeCallback(StringRef name, CallableT &&emitFn) { + attachAttrTypeCallback(name, AsmAttrTypeBytecodePrinter::fromCallable( + std::forward(emitFn))); + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -24,12 +24,114 @@ namespace mlir { class AsmResourcePrinter; class AsmDialectResourceHandle; +class DialectBytecodeWriter; +class DialectBytecodeReader; +class DialectVersion; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// AsmAttrTypeBytecode Parser/Printer +//===----------------------------------------------------------------------===// + +/// A class to interact with the attributes and types printer when emitting MLIR +/// bytecode. +class AsmAttrTypeBytecodePrinter { +public: + AsmAttrTypeBytecodePrinter() = default; + virtual ~AsmAttrTypeBytecodePrinter() = default; + + virtual LogicalResult write(Type entry, DialectBytecodeWriter &writer) = 0; + virtual LogicalResult write(Attribute entry, + DialectBytecodeWriter &writer) = 0; + + /// Return an Attribute/Type printer implemented via the given callable, whose + /// form should match that of `write` functions above. 
+ template > && + std::is_convertible_v< + CallableT, std::function>, + bool> = true> + static std::unique_ptr + fromCallable(CallableT &&writeFn) { + struct Processor : public AsmAttrTypeBytecodePrinter { + Processor(CallableT &&writeFn) + : AsmAttrTypeBytecodePrinter(), _writeFn(std::forward<CallableT>(writeFn)) {} + LogicalResult write(Type entry, DialectBytecodeWriter &writer) override { + return _writeFn(entry, writer); + } + LogicalResult write(Attribute entry, + DialectBytecodeWriter &writer) override { + return _writeFn(entry, writer); + } + + std::decay_t _writeFn; + }; + return std::make_unique(std::forward(writeFn)); + } +}; + +/// A class to interact with the attributes and types parser when parsing MLIR +/// bytecode. +class AsmAttrTypeBytecodeParser { +public: + AsmAttrTypeBytecodeParser() = default; + virtual ~AsmAttrTypeBytecodeParser() = default; + + virtual void parseType(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Type &entry) = 0; + virtual void + parseAttribute(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Attribute &entry) = 0; + + /// Return an Attribute/Type parser implemented via the given callable, whose + /// form should match that of `parseType` and `parseAttribute` functions + /// above. 
+ template < + typename CallableT, + std::enable_if_t< + std::is_convertible_v< + CallableT, + std::function &, + Type &)>> && + std::is_convertible_v< + CallableT, + std::function &, + Attribute &)>>, + bool> = true> + static std::unique_ptr + fromCallable(CallableT &&parseFn) { + struct Processor : public AsmAttrTypeBytecodeParser { + Processor(CallableT &&parseFn) + : AsmAttrTypeBytecodeParser(), _parseFn(std::forward<CallableT>(parseFn)) {} + void parseType(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Type &entry) override { + return _parseFn(ctx, reader, versionMap, entry); + } + void parseAttribute(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Attribute &entry) override { + return _parseFn(ctx, reader, versionMap, entry); + } + + std::decay_t _parseFn; + }; + return std::make_unique(std::forward(parseFn)); + } +}; + //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -474,6 +576,36 @@ /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns the map of callbacks for the parser. + auto &getAttrTypeBytecodeCallback() const { return attrTypeBytecodeParsers; } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings of a given dialect. + void attachAttrTypeBytecodeCallback( + StringRef dialectName, + std::unique_ptr parser) { + attrTypeBytecodeParsers.try_emplace(dialectName, std::move(parser)); + } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings of a given dialect. 
+ template + std::enable_if_t< + std::is_convertible_v< + CallableT, std::function &, Type &)>> && + std::is_convertible_v< + CallableT, + std::function &, + Attribute &)>>> + attachAttrTypeBytecodeCallback(StringRef name, CallableT &&parserFn) { + attachAttrTypeBytecodeCallback(name, + AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + /// Return the resource parser registered to the given name, or nullptr if no /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { @@ -508,6 +640,10 @@ bool verifyAfterParse; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; + + // A string map between dialect names and attribute/types parser callbacks. + llvm::StringMap> + attrTypeBytecodeParsers; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -790,9 +790,10 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, Location fileLoc, + const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + fileLoc(fileLoc), parserConfig(config) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -877,6 +878,12 @@ /// A location used for error emission. Location fileLoc; + + /// A map to retrieve parsed dialect versions associated to each dialect name. + llvm::StringMap dialectsVersionMap; + + /// Reference to the parser configuration. 
+ const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { @@ -1132,6 +1139,18 @@ return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + + // Fill up the dialect to dialectVersion map for every dialect version + // available. + for (auto &dialect : dialects) { + EncodingReader encReader(dialect.versionBuffer, fileLoc); + DialectReader dialectReader(*this, stringReader, resourceReader, encReader); + if (failed(dialect.load(dialectReader, fileLoc.getContext()))) + return failure(); + if (dialect.loadedVersion.get()) + dialectsVersionMap.insert({dialect.name, dialect.loadedVersion.get()}); + } + return success(); } @@ -1200,6 +1219,24 @@ DialectReader dialectReader(*this, stringReader, resourceReader, reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + // If a callback is available for the current dialect entry, try parsing with + // it. + auto &callbackMap = parserConfig.getAttrTypeBytecodeCallback(); + if (callbackMap.contains(entry.dialect->name)) { + auto &callback = callbackMap.at(entry.dialect->name); + if constexpr (std::is_same_v) + callback->parseType(fileLoc.getContext(), dialectReader, + dialectsVersionMap, entry.entry); + else + callback->parseAttribute(fileLoc.getContext(), dialectReader, + dialectsVersionMap, entry.entry); + + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + } + // Ensure that the dialect implements the bytecode interface. 
if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, @@ -1242,7 +1279,7 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc, config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1509,6 +1546,7 @@ /// The table of IR units referenced within the bytecode file. SmallVector dialects; + llvm::StringMap dialectMap; SmallVector opNames; /// The reader used to process resources within the bytecode. diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,14 +18,9 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" -#include -#include -#include #include -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -46,6 +41,11 @@ /// The producer of the bytecode. StringRef producer; + /// A map from dialect name to printer callback to emit custom type and + /// attribute encodings. + llvm::StringMap> + attrTypePrinterCallbacks; + /// A collection of non-dialect resource printers. 
SmallVector> externalResourcePrinters; }; @@ -59,6 +59,17 @@ } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +llvm::StringMap> & +BytecodeWriterConfig::getAttrTypePrinterCallbacks() const { + return impl->attrTypePrinterCallbacks; +} + +void BytecodeWriterConfig::attachAttrTypeCallback( + StringRef dialectName, + std::unique_ptr callback) { + impl->attrTypePrinterCallbacks.try_emplace(dialectName, std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -757,23 +768,36 @@ auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { + // TODO: We don't currently support custom encoded mutable types and + // attributes. + if (!entryValue.template hasTrait() && + !entryValue.template hasTrait()) { // The writer used when emitting using a custom bytecode encoding. DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); + if (config.attrTypePrinterCallbacks.count(entry.dialect->name)) { + auto &callBackHandler = + config.attrTypePrinterCallbacks.at(entry.dialect->name); + if (succeeded(callBackHandler->write(entryValue, dialectWriter))) + hasCustomEncoding = true; + } - if constexpr (std::is_same_v, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeType(entryValue, dialectWriter)); - } else { - // TODO: We don't currently support custom encoded mutable attributes. 
- hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + if (!hasCustomEncoding) { + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + // The writer used when emitting using a custom bytecode encoding. + DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, + numberingState, stringSection); + if constexpr (std::is_same_v, + Type>) { + hasCustomEncoding = + succeeded(interface->writeType(entryValue, dialectWriter)); + } else { + hasCustomEncoding = + succeeded(interface->writeAttribute(entryValue, dialectWriter)); + } + } } } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -199,9 +199,18 @@ // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait()) { + auto &attrTypePrinterCallbacks = config.getAttrTypePrinterCallbacks(); + if (attrTypePrinterCallbacks.contains(numbering->dialect->name)) { + NumberingDialectWriter writer(*this); + auto &callBackHandler = + attrTypePrinterCallbacks.at(numbering->dialect->name); + if (succeeded(callBackHandler->write(attr, writer))) + return; + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -349,14 +358,26 @@ // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components. 
- if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait()) { + auto &attrTypePrinterCallbacks = config.getAttrTypePrinterCallbacks(); + if (attrTypePrinterCallbacks.contains(numbering->dialect->name)) { + NumberingDialectWriter writer(*this); + auto &callBackHandler = + attrTypePrinterCallbacks.at(numbering->dialect->name); + if (succeeded(callBackHandler->write(type, writer))) + return; + } + + // Otherwise, fall back to the dialect's bytecode interface and perform a + // dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; } } + // If this type will be emitted using the fallback, number the nested dialect // resources. We don't number everything (e.g. 
no nested attributes/types), // because we don't want to encode things we won't decode (the textual format diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s --test-bytecode-callback | FileCheck %s + +func.func @base_test(%arg0 : i32) -> i32 { + %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + return %0 : i32 +} + +// CHECK: Overriding printing for builtin dialect attr/types +// CHECK: Overriding parsing for builtin dialect attr/types diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir --- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir @@ -8,12 +8,12 @@ // Index //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX +// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX // INDEX: invalid Attribute index: 3 //===--------------------------------------------------------------------===// // Trailing Data //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA // TRAILING_DATA: trailing characters found after Attribute assembly format: trailing diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so 
add_mlir_library(MLIRTestIR + TestBytecodeCallbacks.cpp TestBuiltinAttributeInterfaces.cpp TestClone.cpp TestDiagnostics.cpp diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp @@ -0,0 +1,82 @@ +//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; + +namespace { + +/// This is a test pass which uses callbacks to encode attributes and types in a +/// custom fashion. 
+struct TestBytecodeCallbackPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass) + + StringRef getArgument() const final { return "test-bytecode-callback"; } + StringRef getDescription() const final { + return "Test encoding of a dialect type/attributes with a custom callback"; + } + TestBytecodeCallbackPass() = default; + TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {} + + void runOnOperation() override { + Operation *op = getOperation(); + std::string bytecode; + { + BytecodeWriterConfig writeConfig; + writeConfig.attachAttrTypeCallback( + "builtin", + [&](auto entryValue, DialectBytecodeWriter &writer) -> LogicalResult { + llvm::outs() + << "Overriding printing for builtin dialect attr/types\n"; + // We don't override any encoding, hence return failure. + return failure(); + }); + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, writeConfig))) { + op->emitError() << "failed to write bytecode\n"; + signalPassFailure(); + return; + } + } + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); + parseConfig.attachAttrTypeBytecodeCallback( + "builtin", + [&](MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + auto &entry) -> void { + llvm::outs() << "Overriding parsing for builtin dialect attr/types\n"; + // `entry` is an empty type. If not assigned, will fall through the + // normal emission/parsing path. 
+ return; + }); + auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig); + if (!newModuleOp.get()) { + op->emitError() << "failed to read bytecode\n"; + signalPassFailure(); + } + return; + } +}; +} // namespace + +namespace mlir { +void registerTestBytecodeCallbackPasses() { + PassRegistration(); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -42,6 +42,7 @@ void registerRegionTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAffineReifyValueBoundsPass(); +void registerTestBytecodeCallbackPasses(); void registerTestDecomposeAffineOpPass(); void registerTestAffineLoopUnswitchingPass(); void registerTestAllReduceLoweringPass(); @@ -160,6 +161,7 @@ registerTestDecomposeAffineOpPass(); registerTestAffineLoopUnswitchingPass(); registerTestAllReduceLoweringPass(); + registerTestBytecodeCallbackPasses(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); registerTestLoopPermutationPass();