diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -48,6 +48,33 @@ /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// + // Types and Attributes encoding + //===--------------------------------------------------------------------===// + + /// Retrieve the attribute/type printer callbacks attached to this config. + llvm::SmallVector> & + getAttrTypePrinterCallbacks() const; + + /// Attach a custom bytecode printer callback to the configuration for the + /// emission of custom type/attributes encodings. + void + attachAttrTypeCallback(std::unique_ptr callback); + + /// Attach a custom bytecode printer callback to the configuration for the + /// emission of custom type/attributes encodings. + template + std::enable_if_t< + std::is_convertible_v> && + std::is_convertible_v< + CallableT, + std::function>> + attachAttrTypeCallback(CallableT &&emitFn) { + attachAttrTypeCallback(AsmAttrTypeBytecodePrinter::fromCallable( + std::forward(emitFn))); + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -25,12 +25,112 @@ namespace mlir { class AsmResourcePrinter; class AsmDialectResourceHandle; +class DialectBytecodeWriter; +class DialectBytecodeReader; +class DialectVersion; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// AsmAttrTypeBytecode Parser/Printer +//===----------------------------------------------------------------------===// + +/// A class to
interact with the attributes and types printer when emitting MLIR +/// bytecode. +class AsmAttrTypeBytecodePrinter { +public: + AsmAttrTypeBytecodePrinter() = default; + virtual ~AsmAttrTypeBytecodePrinter() = default; + + virtual LogicalResult write(Type entry, DialectBytecodeWriter &writer) = 0; + virtual LogicalResult write(Attribute entry, + DialectBytecodeWriter &writer) = 0; + + /// Return an Attribute/Type printer implemented via the given callable, whose + /// form should match that of `write` functions above. + template > && + std::is_convertible_v< + CallableT, std::function>, + bool> = true> + static std::unique_ptr + fromCallable(CallableT &&writeFn) { + struct Processor : public AsmAttrTypeBytecodePrinter { + Processor(CallableT &&writeFn) + : _writeFn(std::forward<CallableT>(writeFn)) {} + LogicalResult write(Type entry, DialectBytecodeWriter &writer) override { + return _writeFn(entry, writer); + } + LogicalResult write(Attribute entry, + DialectBytecodeWriter &writer) override { + return _writeFn(entry, writer); + } + + std::decay_t _writeFn; + }; + return std::make_unique(std::forward(writeFn)); + } +}; + +/// A class to interact with the attributes and types parser when parsing MLIR +/// bytecode. +class AsmAttrTypeBytecodeParser { +public: + AsmAttrTypeBytecodeParser() = default; + virtual ~AsmAttrTypeBytecodeParser() = default; + + virtual void parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Type &entry) = 0; + virtual void parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Attribute &entry) = 0; + + /// Return an Attribute/Type parser implemented via the given callable, whose + /// form should match that of `parse` functions above.
+ template < + typename CallableT, + std::enable_if_t< + std::is_convertible_v< + CallableT, + std::function &, + Type &)>> && + std::is_convertible_v< + CallableT, + std::function &, + Attribute &)>>, + bool> = true> + static std::unique_ptr + fromCallable(CallableT &&parseFn) { + struct Processor : public AsmAttrTypeBytecodeParser { + Processor(CallableT &&parseFn) + : _parseFn(std::forward<CallableT>(parseFn)) {} + void parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Type &entry) override { + return _parseFn(ctx, reader, versionMap, entry); + } + void parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + Attribute &entry) override { + return _parseFn(ctx, reader, versionMap, entry); + } + + std::decay_t _parseFn; + }; + return std::make_unique(std::forward(parseFn)); + } +}; + //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -475,6 +575,34 @@ /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns the callbacks available to the parser. + auto &getAttrTypeBytecodeCallbacks() const { return attrTypeBytecodeParsers; } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings. + void attachAttrTypeBytecodeCallback( + std::unique_ptr parser) { + attrTypeBytecodeParsers.emplace_back(std::move(parser)); + } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings.
+ template + std::enable_if_t< + std::is_convertible_v< + CallableT, std::function &, Type &)>> && + std::is_convertible_v< + CallableT, + std::function &, + Attribute &)>>> + attachAttrTypeBytecodeCallback(CallableT &&parserFn) { + attachAttrTypeBytecodeCallback(AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + /// Return the resource parser registered to the given name, or nullptr if no /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { @@ -509,6 +637,8 @@ bool verifyAfterParse; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; + llvm::SmallVector> + attrTypeBytecodeParsers; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -796,9 +796,10 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, Location fileLoc, + const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + fileLoc(fileLoc), parserConfig(config) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -883,6 +884,12 @@ /// A location used for error emission. Location fileLoc; + + /// Map from dialect name to the dialect version parsed from the bytecode. + llvm::StringMap dialectsVersionMap; + + /// Reference to the parser configuration. + const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { @@ -1142,6 +1149,18 @@ return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + + // Populate `dialectsVersionMap` with the version of every dialect that + // encoded one, so that parsing callbacks can query it.
+ for (auto &dialect : dialects) { + EncodingReader encReader(dialect.versionBuffer, fileLoc); + DialectReader dialectReader(*this, stringReader, resourceReader, encReader); + if (failed(dialect.load(dialectReader, fileLoc.getContext()))) + return failure(); + if (dialect.loadedVersion.get()) + dialectsVersionMap.insert({dialect.name, dialect.loadedVersion.get()}); + } + return success(); } @@ -1210,6 +1229,21 @@ DialectReader dialectReader(*this, stringReader, resourceReader, reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + // Try parsing with callbacks first if available. + for (const auto &callback : parserConfig.getAttrTypeBytecodeCallbacks()) { + callback->parse(fileLoc.getContext(), dialectReader, dialectsVersionMap, + entry.entry); + + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if the callback did not produce an entry, so that the + // next callback (or the default parsing path) reads from the start. + reader = EncodingReader(entry.data, reader.getLoc()); + } + // Ensure that the dialect implements the bytecode interface. if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, @@ -1252,7 +1286,7 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc, config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()),
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,15 +18,10 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Endian.h" -#include -#include -#include +#include "llvm/Support/raw_ostream.h" #include -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -47,6 +42,10 @@ /// The producer of the bytecode. StringRef producer; + /// Printer callbacks used to emit custom type and attribute encodings. + llvm::SmallVector> + attrTypePrinterCallbacks; + /// A collection of non-dialect resource printers. SmallVector> externalResourcePrinters; }; @@ -60,6 +59,16 @@ } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +llvm::SmallVector> & +BytecodeWriterConfig::getAttrTypePrinterCallbacks() const { + return impl->attrTypePrinterCallbacks; +} + +void BytecodeWriterConfig::attachAttrTypeCallback( + std::unique_ptr callback) { + impl->attrTypePrinterCallbacks.emplace_back(std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -767,23 +776,34 @@ auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { + // TODO: We don't currently support custom encoded mutable types and + // attributes. + if (!entryValue.template hasTrait() && + !entryValue.template hasTrait()) { // The writer used when emitting using a custom bytecode encoding.
DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); + // Use the first callback that succeeds, mirroring IR numbering. + for (const auto &callback : config.attrTypePrinterCallbacks) { + hasCustomEncoding = + succeeded(callback->write(entryValue, dialectWriter)); + if (hasCustomEncoding) + break; + } - if constexpr (std::is_same_v, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeType(entryValue, dialectWriter)); - } else { - // TODO: We don't currently support custom encoded mutable attributes. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + if (!hasCustomEncoding) { + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if constexpr (std::is_same_v, + Type>) { + hasCustomEncoding = + succeeded(interface->writeType(entryValue, dialectWriter)); + } else { + hasCustomEncoding = + succeeded(interface->writeAttribute(entryValue, dialectWriter)); + } + } } } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -200,9 +200,16 @@ // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait()) { + // Try overriding emission with callbacks.
+ for (const auto &callback : config.getAttrTypePrinterCallbacks()) { + NumberingDialectWriter writer(*this); + if (succeeded(callback->write(attr, writer))) + return; + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -350,9 +357,18 @@ // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getAttrTypePrinterCallbacks()) { + NumberingDialectWriter writer(*this); + if (succeeded(callback->write(type, writer))) + return; + } + + // Otherwise, if the dialect bytecode interface can emit this type, perform + // a dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; diff --git a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2 +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0 + +func.func @base_test(%arg0 : i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.cast"(%0) : (i32) -> f32 + return %1 : f32 +} + +// VERSION_1_2: Overriding IntegerType encoding... +// VERSION_1_2: Overriding parsing of IntegerType encoding...
+ +// VERSION_2_0-NOT: Overriding IntegerType encoding... +// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding... diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir @@ -0,0 +1,20 @@ +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1 +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2 + +func.func @base_test(%arg0: !test.i32, %arg1: f32) { + return +} + +// TEST_1: Overriding TestI32Type encoding... +// TEST_1-NOT: Overriding parsing of TestI32Type encoding... +// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) { + +// ----- + +func.func @base_test(%arg0: i32, %arg1: f32) { + return +} + +// TEST_2-NOT: Overriding TestI32Type encoding... +// TEST_2: Overriding parsing of TestI32Type encoding...
+// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) { diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir --- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir @@ -5,12 +5,12 @@ // Index //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX +// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX // INDEX: invalid Attribute index: 3 //===--------------------------------------------------------------------===// // Trailing Data //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA // TRAILING_DATA: trailing characters found after Attribute assembly format: trailing diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,9 +14,10 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H -#include "TestTypes.h" #include "TestAttributes.h" #include "TestInterfaces.h" +#include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -57,6 +58,19 @@ #include "TestOpsDialect.h.inc" namespace test { + +//===----------------------------------------------------------------------===// +// TestDialect version utilities
+//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public mlir::DialectVersion { + TestDialectVersion() = default; + TestDialectVersion(uint32_t _major, uint32_t _minor) + : major(_major), minor(_minor) {} + uint32_t major = 2; + uint32_t minor = 0; +}; + // Define some classes to exercises the Properties feature. struct PropertiesWithCustomPrint { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,7 +10,6 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" -#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -119,15 +118,6 @@ registry.insert(); } -//===----------------------------------------------------------------------===// -// TestDialect version utilities -//===----------------------------------------------------------------------===// - -struct TestDialectVersion : public DialectVersion { - uint32_t major = 2; - uint32_t minor = 0; -}; - //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -369,4 +369,8 @@ let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)?
`>`"; } +def TestI32 : Test_Type<"TestI32"> { + let mnemonic = "i32"; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestBytecodeCallbacks.cpp TestBuiltinAttributeInterfaces.cpp TestBuiltinDistinctAttributes.cpp TestClone.cpp diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp @@ -0,0 +1,296 @@ +//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Bytecode/BytecodeReader.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/raw_ostream.h" +#include + +using namespace mlir; +using namespace llvm; + +namespace { +class TestDialectVersionParser : public cl::parser { +public: + TestDialectVersionParser(cl::Option &O) + : cl::parser(O) {} + + bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg, + test::TestDialectVersion &v) { + unsigned long long major, minor; + if (getAsUnsignedInteger(arg.split(".").first, 10, major)) + return O.error("Invalid argument '" + arg + "'"); + if (getAsUnsignedInteger(arg.split(".").second, 10, minor)) + return O.error("Invalid argument '" + arg + "'"); + v = test::TestDialectVersion(major, minor); + // cl::parser convention: return true on error, false on success.
+ return false; + } + static void print(raw_ostream &os, const test::TestDialectVersion &v) { + os << v.major << "." << v.minor; + } +}; + +/// This is a test pass which uses callbacks to encode attributes and types in a +/// custom fashion. +struct TestBytecodeCallbackPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass) + + StringRef getArgument() const final { return "test-bytecode-callback"; } + StringRef getDescription() const final { + return "Test encoding of a dialect type/attributes with a custom callback"; + } + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + TestBytecodeCallbackPass() = default; + TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {} + + void runOnOperation() override { + switch (testKind) { + case (0): + return runTest0(getOperation()); + case (1): + return runTest1(getOperation()); + case (2): + return runTest2(getOperation()); + default: + llvm_unreachable("unhandled test kind for TestBytecodeCallbacks pass"); + } + } + + mlir::Pass::Option + targetVersion{*this, "test-dialect-version", + llvm::cl::desc( + "Specifies the test dialect version to emit and parse"), + cl::init(test::TestDialectVersion())}; + + mlir::Pass::Option testKind{ + *this, "callback-test", + llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)}; + +private: + // Test0: let's assume that versions older than 2.0 were relying on a special + // integer type of the builtin dialect that is now deprecated. Assume + // that its encoding was made by two varInts, the first was the ID (999) and + // the second contained width and signedness info. We can emit it using a + // callback emitting a custom encoding, and parse it back with a custom parser + // reading the same encoding. Note that the ID 999 does not correspond to a + // valid integer type in the current encodings of builtin types.
+ void runTest0(Operation *op) { + test::TestDialectVersion targetEmissionVersion = targetVersion; + std::string bytecode; + BytecodeWriterConfig writeConfig; + writeConfig.attachAttrTypeCallback([&](auto entryValue, + DialectBytecodeWriter &writer) + -> LogicalResult { + // Do not override anything if version less than 2.0. + if (targetEmissionVersion.major >= 2) + return failure(); + + // For version less than 2.0, override the encoding of IntegerType. + if constexpr (std::is_same_v, Type>) { + if (auto type = llvm::dyn_cast(entryValue)) { + llvm::outs() << "Overriding IntegerType encoding...\n"; + writer.writeVarInt(/* IntegerType */ 999); + writer.writeVarInt(type.getWidth() << 2 | type.getSignedness()); + return success(); + } + } + return failure(); + }); + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, writeConfig))) { + op->emitError() << "failed to write bytecode\n"; + signalPassFailure(); + return; + } + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); + parseConfig.attachAttrTypeBytecodeCallback( + [&](MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + auto &entry) -> void { + // Get test dialect version from the version map. + assert( + versionMap.contains("test") && + "expected versionMap to contain all the available version info"); + test::TestDialectVersion &version = + static_cast(*versionMap.at("test")); + + // TODO: once back-deployment is formally supported, + // targetEmissionVersion will be encoded in the bytecode file, and + // exposed through the versionMap. Right now though this is not yet + // supported. For the purpose of the test, just use + // `targetEmissionVersion`.
+ (void)version; + if (targetEmissionVersion.major >= 2) + return; + + if constexpr (std::is_same_v, Type>) { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || encoding != 999) + return; + llvm::outs() << "Overriding parsing of IntegerType encoding...\n"; + uint64_t _widthAndSignedness, width; + IntegerType::SignednessSemantics signedness; + if (succeeded(reader.readVarInt(_widthAndSignedness)) && + ((width = _widthAndSignedness >> 2), true) && + ((signedness = static_cast( + _widthAndSignedness & 0x3)), + true)) { + entry = IntegerType::get(ctx, width, signedness); + return; + } + // Fall through and do not assign entry to fallback on the + // standard codepath for parsing types and attributes. + } + return; + }); + auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig); + if (!newModuleOp.get()) { + op->emitError() << "failed to read bytecode\n"; + signalPassFailure(); + } + return; + } + + // Test1: When writing bytecode, we override the encoding of TestI32Type with + // the encoding of builtin IntegerType. At parsing, we use such encoding to + // read the type and assemble the builtin IntegerType. At the end of the pass, + // we dump the module to standard output, and so can verify that all !test.i32 + // are now i32. + void runTest1(Operation *op) { + // This test does not depend on the target dialect version. + std::string bytecode; + BytecodeWriterConfig writeConfig; + writeConfig.attachAttrTypeCallback([&](auto entryValue, + DialectBytecodeWriter &writer) + -> LogicalResult { + // Emit TestI32Type with the same encoding as the builtin IntegerType.
+ if constexpr (std::is_same_v, Type>) { + if (llvm::isa(entryValue)) { + llvm::outs() << "Overriding TestI32Type encoding...\n"; + writer.writeVarInt(/* IntegerType */ 0); + writer.writeVarInt((32 << 2) | + IntegerType::SignednessSemantics::Signless); + return success(); + } + } + return failure(); + }); + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, writeConfig))) { + op->emitError() << "failed to write bytecode\n"; + signalPassFailure(); + return; + } + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); + parseConfig.attachAttrTypeBytecodeCallback( + [&](MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + auto &entry) -> void { + if constexpr (std::is_same_v, Type>) { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || encoding != 0) + return; + uint64_t _widthAndSignedness, width; + IntegerType::SignednessSemantics signedness; + if (succeeded(reader.readVarInt(_widthAndSignedness)) && + ((width = _widthAndSignedness >> 2), true) && + ((signedness = static_cast( + _widthAndSignedness & 0x3)), + true)) { + entry = IntegerType::get(ctx, width, signedness); + return; + } + // Fall through and do not assign entry to fallback on the + // standard codepath for parsing types and attributes. + } + return; + }); + auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig); + if (!newModuleOp.get()) { + op->emitError() << "failed to read bytecode\n"; + signalPassFailure(); + return; + } + // Print the module to the output stream, so that we can filecheck the + // result. + newModuleOp->print(llvm::outs()); + } + + // Test2: When writing bytecode, we write standard builtin IntegerTypes. At + // parsing, we use the encoding of IntegerType to intercept all i32. Then, + // instead of creating i32s, we assemble TestI32Type and return it.
At the end + // of the pass, we dump the module to standard output, so we can verify that + // all i32 types are now !test.i32. + void runTest2(Operation *op) { + // This test does not depend on the target dialect version. + std::string bytecode; + BytecodeWriterConfig writeConfig; + llvm::raw_string_ostream os(bytecode); + if (failed(writeBytecodeToFile(op, os, writeConfig))) { + op->emitError() << "failed to write bytecode\n"; + signalPassFailure(); + return; + } + getContext().loadDialect(); + ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true); + parseConfig.attachAttrTypeBytecodeCallback( + [&](MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + auto &entry) -> void { + if constexpr (std::is_same_v, Type>) { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || encoding != 0) + return; + llvm::outs() << "Overriding parsing of TestI32Type encoding...\n"; + uint64_t _widthAndSignedness, width; + IntegerType::SignednessSemantics signedness; + if (succeeded(reader.readVarInt(_widthAndSignedness)) && + ((width = _widthAndSignedness >> 2), true) && + ((signedness = static_cast( + _widthAndSignedness & 0x3)), + true)) { + if (width == 32 && + signedness == IntegerType::SignednessSemantics::Signless) { + entry = test::TestI32Type::get(ctx); + return; + } + } + // Fall through and do not assign entry to fallback on the + // standard codepath for parsing types and attributes. + } + return; + }); + auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig); + if (!newModuleOp.get()) { + op->emitError() << "failed to read bytecode\n"; + signalPassFailure(); + return; + } + // Print the module to the output stream, so that we can filecheck the + // result.
+ newModuleOp->print(llvm::outs()); + } +}; +} // namespace + +namespace mlir { +void registerTestBytecodeCallbackPasses() { + PassRegistration(); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -43,6 +43,7 @@ void registerRegionTestPasses(); void registerTestAffineDataCopyPass(); void registerTestAffineReifyValueBoundsPass(); +void registerTestBytecodeCallbackPasses(); void registerTestDecomposeAffineOpPass(); void registerTestAffineLoopUnswitchingPass(); void registerTestAllReduceLoweringPass(); @@ -164,6 +165,7 @@ registerTestDecomposeAffineOpPass(); registerTestAffineLoopUnswitchingPass(); registerTestAllReduceLoweringPass(); + registerTestBytecodeCallbackPasses(); registerTestFunc(); registerTestGpuMemoryPromotionPass(); registerTestLoopPermutationPass();