diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -23,6 +23,17 @@ #include "llvm/ADT/Twine.h" namespace mlir { +//===--------------------------------------------------------------------===// +// Dialect Version Interface. +//===--------------------------------------------------------------------===// + +/// This class is used to represent the version of a dialect, for the purpose +/// of polymorphic destruction. +class DialectVersion { +public: + virtual ~DialectVersion() = default; +}; + //===----------------------------------------------------------------------===// // DialectBytecodeReader //===----------------------------------------------------------------------===// @@ -37,7 +48,14 @@ virtual ~DialectBytecodeReader() = default; /// Emit an error to the reader. - virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; + virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0; + + /// Retrieve the dialect version by name if available. + virtual FailureOr + getDialectVersion(StringRef dialectName) const = 0; + + /// Retrieve the context associated with the reader. + virtual MLIRContext *getContext() const = 0; /// Read out a list of elements, invoking the provided callback for each /// element. The callback function may be in any of the following forms: @@ -261,17 +279,6 @@ virtual int64_t getBytecodeVersion() const = 0; }; -//===--------------------------------------------------------------------===// -// Dialect Version Interface. -//===--------------------------------------------------------------------===// - -/// This class is used to represent the version of a dialect, for the purpose -/// of polymorphic destruction.
-class DialectVersion { -public: - virtual ~DialectVersion() = default; -}; - //===----------------------------------------------------------------------===// // BytecodeDialectInterface //===----------------------------------------------------------------------===// @@ -286,47 +293,23 @@ //===--------------------------------------------------------------------===// /// Read an attribute belonging to this dialect from the given reader. This - /// method should return null in the case of failure. + /// method should return null in the case of failure. Optionally, the dialect + /// version can be accessed through the reader. virtual Attribute readAttribute(DialectBytecodeReader &reader) const { reader.emitError() << "dialect " << getDialect()->getNamespace() << " does not support reading attributes from bytecode"; return Attribute(); } - /// Read a versioned attribute encoding belonging to this dialect from the - /// given reader. This method should return null in the case of failure, and - /// falls back to the non-versioned reader in case the dialect implements - /// versioning but it does not support versioned custom encodings for the - /// attributes. - virtual Attribute readAttribute(DialectBytecodeReader &reader, - const DialectVersion &version) const { - reader.emitError() - << "dialect " << getDialect()->getNamespace() - << " does not support reading versioned attributes from bytecode"; - return Attribute(); - } - /// Read a type belonging to this dialect from the given reader. This method - /// should return null in the case of failure. + /// should return null in the case of failure. Optionally, the dialect version + /// can be accessed through the reader. virtual Type readType(DialectBytecodeReader &reader) const { reader.emitError() << "dialect " << getDialect()->getNamespace() << " does not support reading types from bytecode"; return Type(); } - /// Read a versioned type encoding belonging to this dialect from the given - /// reader.
This method should return null in the case of failure, and - /// falls back to the non-versioned reader in case the dialect implements - /// versioning but it does not support versioned custom encodings for the - /// types. - virtual Type readType(DialectBytecodeReader &reader, - const DialectVersion &version) const { - reader.emitError() - << "dialect " << getDialect()->getNamespace() - << " does not support reading versioned types from bytecode"; - return Type(); - } - //===--------------------------------------------------------------------===// // Writing //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -48,6 +48,43 @@ /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// + // Types and Attributes encoding + //===--------------------------------------------------------------------===// + + /// Retrieve the callbacks. + llvm::SmallVector>> & + getAttributeWriterCallbacks() const; + llvm::SmallVector>> & + getTypeWriterCallbacks() const; + + /// Attach a custom bytecode printer callback to the configuration for the + /// emission of custom type/attributes encodings. + void attachAttributeCallback( + std::unique_ptr> callback); + void + attachTypeCallback(std::unique_ptr> callback); + + /// Convenience overloads: wrap a callable matching the signature of + /// AsmAttrTypeBytecodeWriter::write and attach it to the configuration.
+ template + std::enable_if_t &, + DialectBytecodeWriter &)>>> + attachAttributeCallback(CallableT &&emitFn) { + attachAttributeCallback(AsmAttrTypeBytecodeWriter::fromCallable( + std::forward(emitFn))); + } + template + std::enable_if_t &, + DialectBytecodeWriter &)>>> + attachTypeCallback(CallableT &&emitFn) { + attachTypeCallback(AsmAttrTypeBytecodeWriter::fromCallable( + std::forward(emitFn))); + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -23,14 +23,104 @@ #include namespace mlir { -class AsmResourcePrinter; class AsmDialectResourceHandle; +class AsmResourcePrinter; +class DialectBytecodeReader; +class DialectBytecodeWriter; +class DialectVersion; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// AsmAttrTypeBytecode Parser/Printer +//===----------------------------------------------------------------------===// + +/// A class to interact with the attributes and types printer when emitting MLIR +/// bytecode. +template +class AsmAttrTypeBytecodeWriter { +public: + AsmAttrTypeBytecodeWriter() = default; + virtual ~AsmAttrTypeBytecodeWriter() = default; + + /// Callback writer API used in IRNumbering, where groups are created and + /// type/attribute components are numbered. At this stage, writer is expected + /// to be a `NumberingDialectWriter`. + virtual LogicalResult write(T entry, std::optional &name, + DialectBytecodeWriter &writer) = 0; + + /// Callback writer API used in BytecodeWriter, where the type/attribute + /// encodings are actually emitted. Here, DialectBytecodeWriter is + /// expected to be an actual writer.
The optional stringref specified by + /// the user is ignored, since the group was already specified when numbering + /// the IR. + LogicalResult write(T entry, DialectBytecodeWriter &writer) { + std::optional dummy; + return write(entry, dummy, writer); + } + + /// Return an Attribute/Type printer implemented via the given callable, whose + /// form should match that of the `write` function above. + template &, + DialectBytecodeWriter &)>>, + bool> = true> + static std::unique_ptr> + fromCallable(CallableT &&writeFn) { + struct Processor : public AsmAttrTypeBytecodeWriter { + Processor(CallableT &&writeFn) + : AsmAttrTypeBytecodeWriter(), writeFn(std::forward<CallableT>(writeFn)) {} + LogicalResult write(T entry, std::optional &name, + DialectBytecodeWriter &writer) override { + return writeFn(entry, name, writer); + } + + std::decay_t writeFn; + }; + return std::make_unique(std::forward(writeFn)); + } +}; + +/// A class to interact with the attributes and types parser when parsing MLIR +/// bytecode. +template +class AsmAttrTypeBytecodeParser { +public: + AsmAttrTypeBytecodeParser() = default; + virtual ~AsmAttrTypeBytecodeParser() = default; + + virtual LogicalResult parse(DialectBytecodeReader &reader, + StringRef dialectName, T &entry) = 0; + + /// Return an Attribute/Type parser implemented via the given callable, whose + /// form should match that of the `parse` function above.
+ template >, + bool> = true> + static std::unique_ptr> + fromCallable(CallableT &&parseFn) { + struct Processor : public AsmAttrTypeBytecodeParser { + Processor(CallableT &&parseFn) + : AsmAttrTypeBytecodeParser(), parseFn(std::forward<CallableT>(parseFn)) {} + LogicalResult parse(DialectBytecodeReader &reader, StringRef dialectName, + T &entry) override { + return parseFn(reader, dialectName, entry); + } + + std::decay_t parseFn; + }; + return std::make_unique(std::forward(parseFn)); + } +}; + //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -475,6 +565,47 @@ /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns the callbacks available to the parser. + ArrayRef>> + getAttributeBytecodeCallbacks() const { + return attributeBytecodeParsers; + } + ArrayRef>> + getTypeBytecodeCallbacks() const { + return typeBytecodeParsers; + } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings. + void attachAttributeBytecodeCallback( + std::unique_ptr> parser) { + attributeBytecodeParsers.emplace_back(std::move(parser)); + } + void attachTypeBytecodeCallback( + std::unique_ptr> parser) { + typeBytecodeParsers.emplace_back(std::move(parser)); + } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings.
+ template + std::enable_if_t>> + attachAttributeBytecodeCallback(CallableT &&parserFn) { + attachAttributeBytecodeCallback( + AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + template + std::enable_if_t>> + attachTypeBytecodeCallback(CallableT &&parserFn) { + attachTypeBytecodeCallback(AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + /// Return the resource parser registered to the given name, or nullptr if no /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { @@ -509,6 +640,10 @@ bool verifyAfterParse; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; + llvm::SmallVector>> + attributeBytecodeParsers; + llvm::SmallVector>> + typeBytecodeParsers; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialectBytecode.h b/mlir/include/mlir/IR/BuiltinDialectBytecode.h rename from mlir/lib/IR/BuiltinDialectBytecode.h rename to mlir/include/mlir/IR/BuiltinDialectBytecode.h --- a/mlir/lib/IR/BuiltinDialectBytecode.h +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.h @@ -10,17 +10,27 @@ // //===----------------------------------------------------------------------===// -#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H -#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H +#ifndef MLIR_IR_BUILTINDIALECTBYTECODE_H +#define MLIR_IR_BUILTINDIALECTBYTECODE_H + +#include "mlir/Bytecode/BytecodeImplementation.h" namespace mlir { class BuiltinDialect; -namespace builtin_dialect_detail { +namespace builtin { +/// Utility read/write functions for types and attributes based on the builtin +/// bytecode encoding.
+Attribute readAttribute(DialectBytecodeReader &reader); +LogicalResult writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer); +Type readType(DialectBytecodeReader &reader); +LogicalResult writeType(Type type, DialectBytecodeWriter &writer); + /// Add the interfaces necessary for encoding the builtin dialect components in /// bytecode. void addBytecodeInterface(BuiltinDialect *dialect); -} // namespace builtin_dialect_detail +} // namespace builtin } // namespace mlir -#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H +#endif // MLIR_IR_BUILTINDIALECTBYTECODE_H diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -451,7 +451,7 @@ /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(DialectReader &reader, MLIRContext *ctx); + LogicalResult load(const DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -505,10 +505,11 @@ /// Parse a single dialect group encoded in the byte stream. static LogicalResult parseDialectGrouping( - EncodingReader &reader, MutableArrayRef dialects, + EncodingReader &reader, + MutableArrayRef> dialects, function_ref entryCallback) { // Parse the dialect and the number of entries in the group. - BytecodeDialect *dialect; + std::unique_ptr *dialect; if (failed(parseEntry(reader, dialects, dialect, "dialect"))) return failure(); uint64_t numEntries; @@ -516,7 +517,7 @@ return failure(); for (uint64_t i = 0; i < numEntries; ++i) - if (failed(entryCallback(dialect))) + if (failed(entryCallback(dialect->get()))) return failure(); return success(); } @@ -532,7 +533,7 @@ /// Initialize the resource section reader with the given section data.
LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, + MutableArrayRef> dialects, StringSectionReader &stringReader, ArrayRef sectionData, ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef); @@ -682,7 +683,7 @@ LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, + MutableArrayRef> dialects, StringSectionReader &stringReader, ArrayRef sectionData, ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef) { @@ -731,19 +732,19 @@ // Read the dialect resources from the bytecode. MLIRContext *ctx = fileLoc->getContext(); while (!offsetReader.empty()) { - BytecodeDialect *dialect; + std::unique_ptr *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(dialectReader, ctx))) + failed((*dialect)->load(dialectReader, ctx))) return failure(); - Dialect *loadedDialect = dialect->getLoadedDialect(); + Dialect *loadedDialect = (*dialect)->getLoadedDialect(); if (!loadedDialect) { return resourceReader.emitError() - << "dialect '" << dialect->name << "' is unknown"; + << "dialect '" << (*dialect)->name << "' is unknown"; } const auto *handler = dyn_cast(loadedDialect); if (!handler) { return resourceReader.emitError() - << "unexpected resources for dialect '" << dialect->name << "'"; + << "unexpected resources for dialect '" << (*dialect)->name << "'"; } // Ensure that each resource is declared before being processed.
@@ -753,7 +754,7 @@ if (failed(handle)) { return resourceReader.emitError() << "unknown 'resource' key '" << key << "' for dialect '" - << dialect->name << "'"; + << (*dialect)->name << "'"; } dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); @@ -796,14 +797,17 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, + const llvm::StringMap &dialectsMap, + Location fileLoc, const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + dialectsMap(dialectsMap), fileLoc(fileLoc), parserConfig(config) {} /// Initialize the attribute and type information within the reader. - LogicalResult initialize(MutableArrayRef dialects, - ArrayRef sectionData, - ArrayRef offsetSectionData); + LogicalResult + initialize(MutableArrayRef> dialects, + ArrayRef sectionData, + ArrayRef offsetSectionData); /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. @@ -877,29 +881,56 @@ /// parsing custom encoded attribute/type entries. ResourceSectionReader &resourceReader; + /// The map of the dialects referenced by the bytecode, used to retrieve + /// dialect information such as the dialect version. + const llvm::StringMap &dialectsMap; + /// The set of attribute and type entries. SmallVector attributes; SmallVector types; /// A location used for error emission. Location fileLoc; + + /// Reference to the parser configuration.
+ const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader) + ResourceSectionReader &resourceReader, + const llvm::StringMap &dialectsMap, + EncodingReader &reader) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader) {} + resourceReader(resourceReader), dialectsMap(dialectsMap), + reader(reader) {} - InFlightDiagnostic emitError(const Twine &msg) override { + InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); } - DialectReader withEncodingReader(EncodingReader &encReader) { + FailureOr + getDialectVersion(StringRef dialectName) const override { + // First check if the dialect is available in the map. + auto dialectEntry = dialectsMap.find(dialectName); + if (dialectEntry == dialectsMap.end()) + return failure(); + // If the dialect was found, try to load it. This will trigger reading the + // dialect version from the version buffer if it wasn't already processed. + // Return failure if either of those two actions could not be completed.
+ if (failed(dialectEntry->getValue()->load(*this, getLoc().getContext())) || + dialectEntry->getValue()->loadedVersion.get() == nullptr) + return failure(); + return dialectEntry->getValue()->loadedVersion.get(); + } + + MLIRContext *getContext() const override { return getLoc().getContext(); } + + DialectReader withEncodingReader(EncodingReader &encReader) const { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader); + dialectsMap, encReader); } Location getLoc() const { return reader.getLoc(); } @@ -1002,6 +1033,7 @@ AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; ResourceSectionReader &resourceReader; + const llvm::StringMap &dialectsMap; EncodingReader &reader; }; @@ -1087,10 +1119,9 @@ }; } // namespace -LogicalResult -AttrTypeReader::initialize(MutableArrayRef dialects, - ArrayRef sectionData, - ArrayRef offsetSectionData) { +LogicalResult AttrTypeReader::initialize( + MutableArrayRef> dialects, + ArrayRef sectionData, ArrayRef offsetSectionData) { EncodingReader offsetReader(offsetSectionData, fileLoc); // Parse the number of attribute and type entries. @@ -1142,6 +1173,7 @@ return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + return success(); } @@ -1207,31 +1239,52 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, + reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + if constexpr (std::is_same_v) { + // Try parsing with callbacks first if available. + for (const auto &callback : parserConfig.getTypeBytecodeCallbacks()) { + if (failed( + callback->parse(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful.
+ if (!!entry.entry) + return success(); + + // The callback declined the entry: rewind the reader so the next + // parsing option starts from the beginning of the entry data. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } else { + // Try parsing with callbacks first if available. + for (const auto &callback : parserConfig.getAttributeBytecodeCallbacks()) { + if (failed( + callback->parse(dialectReader, entry.dialect->name, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // The callback declined the entry: rewind the reader so the next + // parsing option starts from the beginning of the entry data. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } + // Ensure that the dialect implements the bytecode interface. if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. If the dialect is versioned, parse - // using the versioned encoding readers.
- if (entry.dialect->loadedVersion.get()) { - if constexpr (std::is_same_v) - entry.entry = entry.dialect->interface->readType( - dialectReader, *entry.dialect->loadedVersion); - else - entry.entry = entry.dialect->interface->readAttribute( - dialectReader, *entry.dialect->loadedVersion); + if constexpr (std::is_same_v) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } else { - if constexpr (std::is_same_v) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); - } return success(!!entry.entry); } @@ -1252,7 +1305,8 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, dialectsMap, fileLoc, + config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1518,7 +1572,8 @@ StringRef producer; /// The table of IR units referenced within the bytecode file. - SmallVector dialects; + SmallVector> dialects; + llvm::StringMap dialectsMap; SmallVector opNames; /// The reader used to process resources within the bytecode. @@ -1665,7 +1720,8 @@ //===----------------------------------------------------------------------===// // Dialect Section -LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { +LogicalResult BytecodeDialect::load(const DialectReader &reader, + MLIRContext *ctx) { if (dialect) return success(); Dialect *loadedDialect = ctx->getOrLoadDialect(name); @@ -1709,13 +1765,15 @@ // Parse each of the dialects.
for (uint64_t i = 0; i < numDialects; ++i) { + dialects[i] = std::make_unique(); /// Before version kDialectVersioning, there wasn't any versioning available /// for dialects, and the entryIdx represent the string itself. if (version < bytecode::kDialectVersioning) { - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) return failure(); continue; } + // Parse ID representing dialect and version. uint64_t dialectNameIdx; bool versionAvailable; @@ -1723,18 +1781,19 @@ versionAvailable))) return failure(); if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, - dialects[i].name))) + dialects[i]->name))) return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed( - sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + if (failed(sectionReader.parseSection(sectionID, + dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { emitError(fileLoc, "expected dialect version section"); return failure(); } } + dialectsMap[dialects[i]->name] = dialects[i].get(); } // Parse the operation names, which are grouped by dialect. @@ -1782,7 +1841,7 @@ if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + dialectsMap, reader); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1825,7 +1884,7 @@ // Initialize the resource reader with the resource sections.
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + dialectsMap, reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2026,14 +2085,14 @@ // Resolve dialect version. - for (const BytecodeDialect &byteCodeDialect : dialects) { + for (const std::unique_ptr &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. - if (!byteCodeDialect.loadedVersion) + if (!byteCodeDialect->loadedVersion) continue; - if (byteCodeDialect.interface && - failed(byteCodeDialect.interface->upgradeFromVersion( - *moduleOp, *byteCodeDialect.loadedVersion))) + if (byteCodeDialect->interface && + failed(byteCodeDialect->interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect->loadedVersion))) return failure(); } @@ -2186,7 +2245,7 @@ // interface and control the serialization. if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + dialectsMap, reader); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,15 +18,10 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Endian.h" -#include -#include -#include +#include "llvm/Support/raw_ostream.h" #include -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -47,6 +42,12 @@ /// The producer of the bytecode.
StringRef producer; + /// Printer callbacks used to emit custom type and attribute encodings. + llvm::SmallVector>> + attributeWriterCallbacks; + llvm::SmallVector>> + typeWriterCallbacks; + /// A collection of non-dialect resource printers. SmallVector> externalResourcePrinters; }; @@ -60,6 +61,26 @@ } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +llvm::SmallVector>> & +BytecodeWriterConfig::getAttributeWriterCallbacks() const { + return impl->attributeWriterCallbacks; +} + +llvm::SmallVector>> & +BytecodeWriterConfig::getTypeWriterCallbacks() const { + return impl->typeWriterCallbacks; +} + +void BytecodeWriterConfig::attachAttributeCallback( + std::unique_ptr> callback) { + impl->attributeWriterCallbacks.emplace_back(std::move(callback)); +} + +void BytecodeWriterConfig::attachTypeCallback( + std::unique_ptr> callback) { + impl->typeWriterCallbacks.emplace_back(std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -767,32 +788,50 @@ auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. - bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { - // The writer used when emitting using a custom bytecode encoding. + auto emitAttrOrTypeRawImpl = [&]() -> void { + RawEmitterOstream(attrTypeEmitter) << entryValue; + attrTypeEmitter.emitByte(0); + }; + auto emitAttrOrTypeImpl = [&]() -> bool { + // TODO: We don't currently support custom encoded mutable types and + // attributes.
+ if (entryValue.template hasTrait() || + entryValue.template hasTrait()) { + emitAttrOrTypeRawImpl(); + return false; + } + DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); - if constexpr (std::is_same_v, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeType(entryValue, dialectWriter)); + for (const auto &callback : config.typeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeType(entryValue, dialectWriter))) + return true; + } } else { - // TODO: We don't currently support custom encoded mutable attributes. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + for (const auto &callback : config.attributeWriterCallbacks) { + if (succeeded(callback->write(entryValue, dialectWriter))) + return true; + } + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if (succeeded(interface->writeAttribute(entryValue, dialectWriter))) + return true; + } } - } - // If the entry was not emitted using the dialect interface, emit it using - // the textual format. - if (!hasCustomEncoding) { - RawEmitterOstream(attrTypeEmitter) << entryValue; - attrTypeEmitter.emitByte(0); - } + // If the entry was not emitted using a callback or a dialect interface, + // emit it using the textual format. + emitAttrOrTypeRawImpl(); + return false; + }; + + bool hasCustomEncoding = emitAttrOrTypeImpl(); // Record the offset of this entry.
uint64_t curOffset = attrTypeEmitter.size(); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -200,9 +200,22 @@ // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getAttributeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback. + std::optional groupNameOverride; + if (succeeded(callback->write(attr, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -350,9 +363,24 @@ // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getTypeWriterCallbacks()) { + NumberingDialectWriter writer(*this); + // The client has the ability to override the group name through the + // callback.
+ std::optional groupNameOverride; + if (succeeded(callback->write(type, groupNameOverride, writer))) { + if (groupNameOverride.has_value()) + numbering->dialect = &numberDialect(*groupNameOverride); + return; + } + } + + // Otherwise fall back to the dialect interface, performing a dummy write + // to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinDialect.h" -#include "BuiltinDialectBytecode.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinDialectBytecode.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" @@ -23,6 +23,7 @@ #include "mlir/IR/TypeRange.h" using namespace mlir; +using namespace builtin; //===----------------------------------------------------------------------===// // TableGen'erated dialect @@ -122,7 +123,7 @@ auto &blobInterface = addInterface(); addInterface(blobInterface); - builtin_dialect_detail::addBytecodeInterface(this); + addBytecodeInterface(this); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "BuiltinDialectBytecode.h" +#include "mlir/IR/BuiltinDialectBytecode.h" #include "AttributeDetail.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/BuiltinAttributes.h" @@ -18,10 +18,6 @@ using namespace
mlir; -//===----------------------------------------------------------------------===// -// BuiltinDialectBytecodeInterface -//===----------------------------------------------------------------------===// - namespace { //===----------------------------------------------------------------------===// @@ -81,6 +77,10 @@ #include "mlir/IR/BuiltinDialectBytecode.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialectBytecodeInterface +//===----------------------------------------------------------------------===// + /// This class implements the bytecode interface for the builtin dialect. struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { BuiltinDialectBytecodeInterface(Dialect *dialect) @@ -112,6 +112,23 @@ }; } // namespace -void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) { +Attribute builtin::readAttribute(DialectBytecodeReader &reader) { + return ::readAttribute(reader.getContext(), reader); +} + +LogicalResult builtin::writeAttribute(Attribute attribute, + DialectBytecodeWriter &writer) { + return ::writeAttribute(attribute, writer); +} + +Type builtin::readType(DialectBytecodeReader &reader) { + return ::readType(reader.getContext(), reader); +} + +LogicalResult builtin::writeType(Type type, DialectBytecodeWriter &writer) { + return ::writeType(type, writer); +} + +void builtin::addBytecodeInterface(BuiltinDialect *dialect) { dialect->addInterfaces(); } diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2 +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0 + +func.func @base_test(%arg0 : i32) -> f32 { + 
%0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.cast"(%0) : (i32) -> f32 + return %1 : f32 +} + +// VERSION_1_2: Overriding IntegerType encoding... +// VERSION_1_2: Overriding parsing of IntegerType encoding... + +// VERSION_2_0-NOT: Overriding IntegerType encoding... +// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding... diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir @@ -0,0 +1,18 @@ +// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s + +// CHECK-NOT: failed to read bytecode +func.func @base_test(%arg0 : i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.cast"(%0) : (i32) -> f32 + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: error: unknown attribute code: 99 +// CHECK: failed to read bytecode +func.func @base_test(%arg0 : !test.i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32 + %1 = "test.cast"(%0) : (!test.i32) -> f32 + return %1 : f32 +} diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3 +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4 + +"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> () + +// TEST_3: Overriding TestAttrParamsAttr encoding... 
+// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> () + +// ----- + +"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> () + +// TEST_4: Overriding parsing of TestAttrParamsAttr encoding... +// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> () diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1 +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2 + +func.func @base_test(%arg0: !test.i32, %arg1: f32) { + return +} + +// TEST_1: Overriding TestI32Type encoding... +// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) { + +// ----- + +func.func @base_test(%arg0: i32, %arg1: f32) { + return +} + +// TEST_2: Overriding parsing of TestI32Type encoding... 
+// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) { diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir --- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir @@ -5,12 +5,12 @@ // Index //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX +// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX // INDEX: invalid Attribute index: 3 //===--------------------------------------------------------------------===// // Trailing Data //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA // TRAILING_DATA: trailing characters found after Attribute assembly format: trailing diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,9 +14,10 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H -#include "TestTypes.h" #include "TestAttributes.h" #include "TestInterfaces.h" +#include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -57,6 +58,19 @@ #include "TestOpsDialect.h.inc" namespace test { + +//===----------------------------------------------------------------------===// +// TestDialect version utilities 
+//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public mlir::DialectVersion { + TestDialectVersion() = default; + TestDialectVersion(uint32_t _major, uint32_t _minor) + : major(_major), minor(_minor){}; + uint32_t major = 2; + uint32_t minor = 0; +}; + // Define some classes to exercises the Properties feature. struct PropertiesWithCustomPrint { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,7 +10,6 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" -#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -119,15 +118,6 @@ registry.insert(); } -//===----------------------------------------------------------------------===// -// TestDialect version utilities -//===----------------------------------------------------------------------===// - -struct TestDialectVersion : public DialectVersion { - uint32_t major = 2; - uint32_t minor = 0; -}; - //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -152,7 +142,7 @@ }; namespace { -enum test_encoding { k_attr_params = 0 }; +enum test_encoding { k_attr_params = 0, k_test_i32 = 99 }; } // Test support for interacting with the Bytecode reader/writer. 
@@ -161,6 +151,24 @@ TestBytecodeDialectInterface(Dialect *dialect) : BytecodeDialectInterface(dialect) {} + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const final { + if (auto concreteType = llvm::dyn_cast(type)) { + writer.writeVarInt(test_encoding::k_test_i32); + return success(); + } + return failure(); + } + + Type readType(DialectBytecodeReader &reader) const final { + uint64_t encoding; + if (failed(reader.readVarInt(encoding))) + return Type(); + if (encoding == test_encoding::k_test_i32) + return TestI32Type::get(getContext()); + return Type(); + } + LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const final { if (auto concreteAttr = llvm::dyn_cast(attr)) { @@ -172,9 +180,13 @@ return failure(); } - Attribute readAttribute(DialectBytecodeReader &reader, - const DialectVersion &version_) const final { - const auto &version = static_cast(version_); + Attribute readAttribute(DialectBytecodeReader &reader) const final { + auto versionOr = reader.getDialectVersion("test"); + // Assume current version if not available through the reader. + const auto version = + (succeeded(versionOr)) + ? 
*reinterpret_cast(*versionOr) + : TestDialectVersion(); if (version.major < 2) return readAttrOldEncoding(reader); if (version.major == 2 && version.minor == 0) diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1311,8 +1311,9 @@ } def TestAddIOp : TEST_Op<"addi"> { - let arguments = (ins I32:$op1, I32:$op2); - let results = (outs I32); + let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1, + AnyTypeOf<[I32, TestI32]>:$op2); + let results = (outs AnyTypeOf<[I32, TestI32]>); } def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { @@ -3315,6 +3316,12 @@ ); } +def TestVersionedOpC : TEST_Op<"versionedC"> { + let arguments = (ins AnyAttrOf<[TestAttrParams, + I32ElementsAttr]>:$attribute + ); +} + //===----------------------------------------------------------------------===// // Test Properties //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -369,4 +369,8 @@ let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? 
`>`"; } +def TestI32 : Test_Type<"TestI32"> { + let mnemonic = "i32"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestBytecodeCallbacks.cpp TestBuiltinAttributeInterfaces.cpp TestBuiltinDistinctAttributes.cpp TestClone.cpp diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp @@ -0,0 +1,358 @@ +//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinDialectBytecode.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace llvm; + +namespace { +class TestDialectVersionParser : public cl::parser { +public: + TestDialectVersionParser(cl::Option &O) + : cl::parser(O) {} + + bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg, + test::TestDialectVersion &v) { + long long major, minor; + if (getAsSignedInteger(arg.split(".").first, 10, major)) + return O.error("Invalid argument '" + arg); + if (getAsSignedInteger(arg.split(".").second, 10, minor)) + return O.error("Invalid argument '" + arg); + v = 
test::TestDialectVersion(major, minor); + // Returns true on error. + return false; + } + static void print(raw_ostream &os, const test::TestDialectVersion &v) { + os << v.major << "." << v.minor; + }; +}; + +/// This is a test pass which uses callbacks to encode attributes and types in a +/// custom fashion. +struct TestBytecodeCallbackPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass) + + StringRef getArgument() const final { return "test-bytecode-callback"; } + StringRef getDescription() const final { + return "Test encoding of a dialect type/attributes with a custom callback"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + TestBytecodeCallbackPass() = default; + TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {} + + void runOnOperation() override { + switch (testKind) { + case (0): + return runTest0(getOperation()); + case (1): + return runTest1(getOperation()); + case (2): + return runTest2(getOperation()); + case (3): + return runTest3(getOperation()); + case (4): + return runTest4(getOperation()); + case (5): + return runTest5(getOperation()); + default: + llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass"); + } + } + + mlir::Pass::Option + targetVersion{*this, "test-dialect-version", + llvm::cl::desc( + "Specifies the test dialect version to emit and parse"), + cl::init(test::TestDialectVersion())}; + + mlir::Pass::Option testKind{ + *this, "callback-test", + llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)}; + +private: + void doRoundtripWithConfigs(Operation *op, + const BytecodeWriterConfig &writeConfig, + const ParserConfig &parseConfig) { + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, writeConfig))) { + op->emitError() << "failed to write bytecode\n"; + signalPassFailure(); + return; + } + auto newModuleOp = 
parseSourceString(StringRef(bytecode), parseConfig); + if (!newModuleOp.get()) { + op->emitError() << "failed to read bytecode\n"; + signalPassFailure(); + return; + } + // Print the module to the output stream, so that we can filecheck the + // result. + newModuleOp->print(llvm::outs()); + return; + } + + // Test0: let's assume that versions older than 2.0 were relying on a special + // integer attribute of a deprecated dialect called "funky". Assume that its + // encoding was made by two varInts, the first was the ID (999) and the second + // contained width and signedness info. We can emit it using a callback + // writing a custom encoding for the "funky" dialect group, and parse it back + // with a custom parser reading the same encoding in the same dialect group. + // Note that the ID 999 does not correspond to a valid integer type in the + // current encodings of builtin types. + void runTest0(Operation *op) { + auto newCtx = std::make_shared(); + test::TestDialectVersion targetEmissionVersion = targetVersion; + BytecodeWriterConfig writeConfig; + writeConfig.attachTypeCallback( + [&](Type entryValue, std::optional &dialectGroupName, + DialectBytecodeWriter &writer) -> LogicalResult { + // Do not override anything if version less than 2.0. + if (targetEmissionVersion.major >= 2) + return failure(); + + // For version less than 2.0, override the encoding of IntegerType. 
+ if (auto type = llvm::dyn_cast(entryValue)) { + llvm::outs() << "Overriding IntegerType encoding...\n"; + dialectGroupName = StringLiteral("funky"); + writer.writeVarInt(/* IntegerType */ 999); + writer.writeVarInt(type.getWidth() << 2 | type.getSignedness()); + return success(); + } + return failure(); + }); + newCtx->appendDialectRegistry(op->getContext()->getDialectRegistry()); + newCtx->allowUnregisteredDialects(); + ParserConfig parseConfig(newCtx.get(), /*verifyAfterParse=*/true); + parseConfig.attachTypeBytecodeCallback([&](DialectBytecodeReader &reader, + StringRef dialectName, + Type &entry) -> LogicalResult { + // Get test dialect version from the version map. + auto versionOr = reader.getDialectVersion("test"); + assert( + succeeded(versionOr) && + "expected reader to be able to access the version for test dialect"); + const auto *version = + reinterpret_cast(*versionOr); + + // TODO: once back-deployment is formally supported, + // `targetEmissionVersion` will be encoded in the bytecode file, and + // exposed through the versionMap. Right now though this is not yet + // supported. For the purpose of the test, just use + // `targetEmissionVersion`. + (void)version; + if (targetEmissionVersion.major >= 2) + return success(); + + // `dialectName` is the name of the group we have the opportunity to + // override. In this case, override only the dialect group "funky", for + // which does not exist in memory. 
+ if (dialectName != StringLiteral("funky")) + return success(); + + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || encoding != 999) + return success(); + llvm::outs() << "Overriding parsing of IntegerType encoding...\n"; + uint64_t _widthAndSignedness, width; + IntegerType::SignednessSemantics signedness; + if (succeeded(reader.readVarInt(_widthAndSignedness)) && + ((width = _widthAndSignedness >> 2), true) && + ((signedness = static_cast( + _widthAndSignedness & 0x3)), + true)) + entry = IntegerType::get(reader.getContext(), width, signedness); + // Return nullopt to fall through the rest of the parsing code path. + return success(); + }); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } + + // Test1: When writing bytecode, we override the encoding of TestI32Type with + // the encoding of builtin IntegerType. We can natively parse this without + // the use of a callback, relying on the existing builtin reader mechanism. + void runTest1(Operation *op) { + BytecodeWriterConfig writeConfig; + writeConfig.attachTypeCallback( + [&](Type entryValue, std::optional &dialectGroupName, + DialectBytecodeWriter &writer) -> LogicalResult { + // Emit TestIntegerType using the builtin dialect encoding. + if (llvm::isa(entryValue)) { + llvm::outs() << "Overriding TestI32Type encoding...\n"; + auto builtinI32Type = + IntegerType::get(op->getContext(), 32, + IntegerType::SignednessSemantics::Signless); + // Specify that this type will need to be written as part of the + // builtin group. This will override the default dialect group of + // the attribute (test). + dialectGroupName = StringLiteral("builtin"); + if (succeeded(builtin::writeType(builtinI32Type, writer))) + return success(); + } + return failure(); + }); + // We natively parse the attribute as a builtin, so no callback needed. 
+ ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } + + // Test2: When writing bytecode, we write standard builtin IntegerTypes. At + // parsing, we use the encoding of IntegerType to intercept all i32. Then, + // instead of creating i32s, we assemble TestI32Type and return it. + void runTest2(Operation *op) { + BytecodeWriterConfig writeConfig; + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); + parseConfig.attachTypeBytecodeCallback([&](DialectBytecodeReader &reader, + StringRef dialectName, + Type &entry) -> LogicalResult { + if (dialectName != StringLiteral("builtin")) + return success(); + Type builtinAttr = builtin::readType(reader); + if (auto integerType = llvm::dyn_cast_or_null(builtinAttr)) { + if (integerType.getWidth() == 32 && integerType.isSignless()) { + llvm::outs() << "Overriding parsing of TestI32Type encoding...\n"; + entry = test::TestI32Type::get(reader.getContext()); + } + } + return success(); + }); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } + + // Test3: When writing bytecode, we override the encoding of + // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. We + // can natively parse this without the use of a callback, relying on the + // existing builtin reader mechanism. + void runTest3(Operation *op) { + auto i32Type = IntegerType::get(op->getContext(), 32, + IntegerType::SignednessSemantics::Signless); + BytecodeWriterConfig writeConfig; + writeConfig.attachAttributeCallback( + [&](Attribute entryValue, std::optional &dialectGroupName, + DialectBytecodeWriter &writer) -> LogicalResult { + // Emit TestIntegerType using the builtin dialect encoding. + if (auto testParamAttrs = + llvm::dyn_cast(entryValue)) { + llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n"; + // Specify that this attribute will need to be written as part of + // the builtin group. 
This will override the default dialect group + // of the attribute (test). + dialectGroupName = StringLiteral("builtin"); + auto denseAttr = DenseIntElementsAttr::get( + RankedTensorType::get({2}, i32Type), + {testParamAttrs.getV0(), testParamAttrs.getV1()}); + if (succeeded(builtin::writeAttribute(denseAttr, writer))) + return success(); + } + return failure(); + }); + // We natively parse the attribute as a builtin, so no callback needed. + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } + + // Test4: When writing bytecode, we write standard builtin + // DenseIntElementsAttr. At parsing, we use the encoding of + // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of + // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble + // TestAttrParamsAttr and return it. + void runTest4(Operation *op) { + auto i32Type = IntegerType::get(op->getContext(), 32, + IntegerType::SignednessSemantics::Signless); + BytecodeWriterConfig writeConfig; + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); + parseConfig.attachAttributeBytecodeCallback( + [&](DialectBytecodeReader &reader, StringRef dialectName, + Attribute &entry) -> LogicalResult { + // Override only the case where the return type of the builtin reader + // is an i32 and fall through on all the other cases, since we want to + // still use TestDialect normal codepath to parse the other types. 
+ Attribute builtinAttr = builtin::readAttribute(reader); + if (auto denseAttr = + llvm::dyn_cast_or_null(builtinAttr)) { + if (denseAttr.getType().getShape() == ArrayRef(2) && + denseAttr.getElementType() == i32Type) { + llvm::outs() + << "Overriding parsing of TestAttrParamsAttr encoding...\n"; + int v0 = denseAttr.getValues()[0].getInt(); + int v1 = denseAttr.getValues()[1].getInt(); + entry = + test::TestAttrParamsAttr::get(reader.getContext(), v0, v1); + } + } + return success(); + }); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } + + // Test5: When writing bytecode, we want TestDialect to use nothing else than + // the builtin types and attributes and take full control of the encoding, + // returning failure if any type or attribute is not part of builtin. + void runTest5(Operation *op) { + BytecodeWriterConfig writeConfig; + writeConfig.attachAttributeCallback( + [&](Attribute attr, std::optional &dialectGroupName, + DialectBytecodeWriter &writer) -> LogicalResult { + return builtin::writeAttribute(attr, writer); + }); + writeConfig.attachTypeCallback( + [&](Type type, std::optional &dialectGroupName, + DialectBytecodeWriter &writer) -> LogicalResult { + return builtin::writeType(type, writer); + }); + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false); + parseConfig.attachAttributeBytecodeCallback( + [&](DialectBytecodeReader &reader, StringRef dialectName, + Attribute &entry) -> LogicalResult { + Attribute builtinAttr = builtin::readAttribute(reader); + if (!builtinAttr) + return failure(); + entry = builtinAttr; + return success(); + }); + parseConfig.attachTypeBytecodeCallback([&](DialectBytecodeReader &reader, + StringRef dialectName, + Type &entry) -> LogicalResult { + Type builtinType = builtin::readType(reader); + if (!builtinType) { + return failure(); + } + entry = builtinType; + return success(); + }); + doRoundtripWithConfigs(op, writeConfig, parseConfig); + return; + } +}; +} // namespace + +namespace 
mlir { +void registerTestBytecodeCallbackPasses() { + PassRegistration(); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -43,6 +43,7 @@ void registerRegionTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAffineReifyValueBoundsPass(); +void registerTestBytecodeCallbackPasses(); void registerTestDecomposeAffineOpPass(); void registerTestAffineLoopUnswitchingPass(); void registerTestAllReduceLoweringPass(); @@ -164,6 +165,7 @@ registerTestDecomposeAffineOpPass(); registerTestAffineLoopUnswitchingPass(); registerTestAllReduceLoweringPass(); + registerTestBytecodeCallbackPasses(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); registerTestLoopPermutationPass();