diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -48,6 +48,41 @@ /// Get the set desired bytecode version to emit. int64_t getDesiredBytecodeVersion() const; + //===--------------------------------------------------------------------===// + // Types and Attributes encoding + //===--------------------------------------------------------------------===// + + /// Retrieve the callbacks. + llvm::SmallVector>> & + getAttributePrinterCallbacks() const; + llvm::SmallVector>> & + getTypePrinterCallbacks() const; + + /// Attach a custom bytecode printer callback to the configuration for the + /// emission of custom type/attributes encodings. + void attachAttributeCallback( + std::unique_ptr> callback); + void attachTypeCallback( + std::unique_ptr> callback); + + /// Attach a custom bytecode printer callback to the configuration for the + /// emission of custom type/attributes encodings. 
+ template + std::enable_if_t>> + attachAttributeCallback(CallableT &&emitFn) { + attachAttributeCallback(AsmAttrTypeBytecodePrinter::fromCallable( + std::forward(emitFn))); + } + template + std::enable_if_t>> + attachTypeCallback(CallableT &&emitFn) { + attachTypeCallback(AsmAttrTypeBytecodePrinter::fromCallable( + std::forward(emitFn))); + } + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -25,12 +25,90 @@ namespace mlir { class AsmResourcePrinter; class AsmDialectResourceHandle; +class DialectBytecodeWriter; +class DialectBytecodeReader; +class DialectVersion; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// AsmAttrTypeBytecode Parser/Printer +//===----------------------------------------------------------------------===// + +/// A class to interact with the attributes and types printer when emitting MLIR +/// bytecode. +template +class AsmAttrTypeBytecodePrinter { +public: + AsmAttrTypeBytecodePrinter() = default; + virtual ~AsmAttrTypeBytecodePrinter() = default; + + virtual LogicalResult write(T entry, DialectBytecodeWriter &writer) = 0; + + /// Return an Attribute/Type printer implemented via the given callable, whose + /// form should match that of the `write` function above. 
+ template >, + bool> = true> + static std::unique_ptr> + fromCallable(CallableT &&writeFn) { + struct Processor : public AsmAttrTypeBytecodePrinter { + Processor(CallableT &&writeFn) + : AsmAttrTypeBytecodePrinter(), _writeFn(std::move(writeFn)) {} + LogicalResult write(T entry, DialectBytecodeWriter &writer) override { + return _writeFn(entry, writer); + } + + std::decay_t _writeFn; + }; + return std::make_unique(std::forward(writeFn)); + } +}; + +/// A class to interact with the attributes and types parser when reading MLIR +/// bytecode. +template +class AsmAttrTypeBytecodeParser { +public: + AsmAttrTypeBytecodeParser() = default; + virtual ~AsmAttrTypeBytecodeParser() = default; + + virtual LogicalResult + parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, T &entry) = 0; + + /// Return an Attribute/Type parser implemented via the given callable, whose + /// form should match that of the `parse` function above. + template < + typename CallableT, + std::enable_if_t< + std::is_convertible_v< + CallableT, std::function &, T &)>>, + bool> = true> + static std::unique_ptr> + fromCallable(CallableT &&parseFn) { + struct Processor : public AsmAttrTypeBytecodeParser { + Processor(CallableT &&parseFn) + : AsmAttrTypeBytecodeParser(), _parseFn(std::move(parseFn)) {} + LogicalResult parse(MLIRContext *ctx, DialectBytecodeReader &reader, + const llvm::StringMap &versionMap, + T &entry) override { + return _parseFn(ctx, reader, versionMap, entry); + } + + std::decay_t _parseFn; + }; + return std::make_unique(std::forward(parseFn)); + } +}; + //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -475,6 +553,45 @@ /// Returns if the parser should verify the IR after parsing. bool shouldVerifyAfterParse() const { return verifyAfterParse; } + /// Returns the callbacks available to the parser.
+ auto &getAttributeBytecodeCallbacks() const { + return attributeBytecodeParsers; + } + auto &getTypeBytecodeCallbacks() const { return typeBytecodeParsers; } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings. + void attachAttributeBytecodeCallback( + std::unique_ptr> parser) { + attributeBytecodeParsers.emplace_back(std::move(parser)); + } + void attachTypeBytecodeCallback( + std::unique_ptr> parser) { + typeBytecodeParsers.emplace_back(std::move(parser)); + } + + /// Attach a custom bytecode parser callback to the configuration for parsing + /// of custom type/attributes encodings. + template + std::enable_if_t &, Attribute &)>>> + attachAttributeBytecodeCallback(CallableT &&parserFn) { + attachAttributeBytecodeCallback( + AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + template + std::enable_if_t &, Type &)>>> + attachTypeBytecodeCallback(CallableT &&parserFn) { + attachTypeBytecodeCallback(AsmAttrTypeBytecodeParser::fromCallable( + std::forward(parserFn))); + } + /// Return the resource parser registered to the given name, or nullptr if no /// parser with `name` is registered. AsmResourceParser *getResourceParser(StringRef name) const { @@ -509,6 +626,10 @@ bool verifyAfterParse; DenseMap> resourceParsers; FallbackAsmResourceMap *fallbackResourceMap; + llvm::SmallVector>> + attributeBytecodeParsers; + llvm::SmallVector>> + typeBytecodeParsers; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.h b/mlir/include/mlir/IR/BuiltinDialectBytecode.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.h @@ -0,0 +1,51 @@ +//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the builtin dialect bytecode implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_BUILTINDIALECTBYTECODE_H +#define MLIR_IR_BUILTINDIALECTBYTECODE_H + +#include "mlir/Bytecode/BytecodeImplementation.h" + +namespace mlir { +class BuiltinDialect; +class Dialect; + +namespace builtin_dialect_detail { +/// This class implements the bytecode interface for the builtin dialect. +struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { + BuiltinDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + //===--------------------------------------------------------------------===// + // Attributes + + Attribute readAttribute(DialectBytecodeReader &reader) const override; + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override; + + //===--------------------------------------------------------------------===// + // Types + + Type readType(DialectBytecodeReader &reader) const override; + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override; +}; + +/// Add the interfaces necessary for encoding the builtin dialect components in +/// bytecode. 
+void addBytecodeInterface(BuiltinDialect *dialect); +} // namespace builtin_dialect_detail +} // namespace mlir + +#endif // MLIR_IR_BUILTINDIALECTBYTECODE_H diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -796,9 +796,10 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, Location fileLoc, + const ParserConfig &config) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + fileLoc(fileLoc), parserConfig(config) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -883,6 +884,12 @@ /// A location used for error emission. Location fileLoc; + + /// Maps dialect names to their parsed dialect versions (when present). + llvm::StringMap dialectsVersionMap; + + /// Reference to the parser configuration. + const ParserConfig &parserConfig; }; class DialectReader : public DialectBytecodeReader { @@ -1142,6 +1149,18 @@ return offsetReader.emitError( "unexpected trailing data in the Attribute/Type offset section"); } + + // Fill up the dialect to dialectVersion map for every dialect version + // available.
+ for (auto &dialect : dialects) { + EncodingReader encReader(dialect.versionBuffer, fileLoc); + DialectReader dialectReader(*this, stringReader, resourceReader, encReader); + if (failed(dialect.load(dialectReader, fileLoc.getContext()))) + return failure(); + if (dialect.loadedVersion.get()) + dialectsVersionMap.insert({dialect.name, dialect.loadedVersion.get()}); + } + return success(); } @@ -1210,6 +1229,37 @@ DialectReader dialectReader(*this, stringReader, resourceReader, reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); + + if constexpr (std::is_same_v) { + // Try parsing with callbacks first if available. + for (const auto &callback : parserConfig.getTypeBytecodeCallbacks()) { + if (failed(callback->parse(fileLoc.getContext(), dialectReader, + dialectsVersionMap, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } else { + // Try parsing with callbacks first if available. + for (const auto &callback : parserConfig.getAttributeBytecodeCallbacks()) { + if (failed(callback->parse(fileLoc.getContext(), dialectReader, + dialectsVersionMap, entry.entry))) + return failure(); + // Early return if parsing was successful. + if (!!entry.entry) + return success(); + + // Reset the reader if we failed to parse, so we can fall through the + // other parsing functions. + reader = EncodingReader(entry.data, reader.getLoc()); + } + } + // Ensure that the dialect implements the bytecode interface. 
if (!entry.dialect->interface) { return reader.emitError("dialect '", entry.dialect->name, @@ -1252,7 +1302,7 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc, config), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -18,15 +18,10 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include "llvm/Support/Endian.h" -#include -#include -#include +#include "llvm/Support/raw_ostream.h" #include -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -47,6 +42,12 @@ /// The producer of the bytecode. StringRef producer; + /// Printer callbacks used to emit custom type and attribute encodings. + llvm::SmallVector>> + attributePrinterCallbacks; + llvm::SmallVector>> + typePrinterCallbacks; + /// A collection of non-dialect resource printers.
SmallVector> externalResourcePrinters; }; @@ -60,6 +61,26 @@ } BytecodeWriterConfig::~BytecodeWriterConfig() = default; +llvm::SmallVector>> & +BytecodeWriterConfig::getAttributePrinterCallbacks() const { + return impl->attributePrinterCallbacks; +} + +llvm::SmallVector>> & +BytecodeWriterConfig::getTypePrinterCallbacks() const { + return impl->typePrinterCallbacks; +} + +void BytecodeWriterConfig::attachAttributeCallback( + std::unique_ptr> callback) { + impl->attributePrinterCallbacks.emplace_back(std::move(callback)); +} + +void BytecodeWriterConfig::attachTypeCallback( + std::unique_ptr> callback) { + impl->typePrinterCallbacks.emplace_back(std::move(callback)); +} + void BytecodeWriterConfig::attachResourcePrinter( std::unique_ptr printer) { impl->externalResourcePrinters.emplace_back(std::move(printer)); @@ -767,23 +788,41 @@ auto emitAttrOrType = [&](auto &entry) { auto entryValue = entry.getValue(); - // First, try to emit this entry using the dialect bytecode interface. bool hasCustomEncoding = false; - if (const BytecodeDialectInterface *interface = entry.dialect->interface) { + // TODO: We don't currently support custom encoded mutable types and + // attributes. + if (!entryValue.template hasTrait() && + !entryValue.template hasTrait()) { // The writer used when emitting using a custom bytecode encoding. DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter, numberingState, stringSection); - if constexpr (std::is_same_v, Type>) { - // TODO: We don't currently support custom encoded mutable types. - hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeType(entryValue, dialectWriter)); + // Stop at the first successful callback: only one encoding per entry. + hasCustomEncoding = llvm::any_of( + config.typePrinterCallbacks, [&](const auto &callback) { + return succeeded(callback->write(entryValue, dialectWriter)); + }); } else { - // TODO: We don't currently support custom encoded mutable attributes.
- hasCustomEncoding = - !entryValue.template hasTrait() && - succeeded(interface->writeAttribute(entryValue, dialectWriter)); + // Stop at the first successful callback: only one encoding per entry. + hasCustomEncoding = llvm::any_of( + config.attributePrinterCallbacks, [&](const auto &callback) { + return succeeded(callback->write(entryValue, dialectWriter)); + }); + } + + // Otherwise, fall back to the dialect's bytecode interface. + if (!hasCustomEncoding) { + if (const BytecodeDialectInterface *interface = + entry.dialect->interface) { + if constexpr (std::is_same_v, + Type>) { + hasCustomEncoding = + succeeded(interface->writeType(entryValue, dialectWriter)); + } else { + hasCustomEncoding = + succeeded(interface->writeAttribute(entryValue, dialectWriter)); + } + } } } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -200,9 +200,16 @@ // If this attribute will be emitted using the bytecode format, perform a // dummy writing to number any nested components. - if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable attributes right now. - if (!attr.hasTrait()) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (!attr.hasTrait()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getAttributePrinterCallbacks()) { + NumberingDialectWriter writer(*this); + if (succeeded(callback->write(attr, writer))) + return; + } + + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeAttribute(attr, writer))) return; @@ -350,9 +357,18 @@ // If this type will be emitted using the bytecode format, perform a dummy // writing to number any nested components.
- if (const auto *interface = numbering->dialect->interface) { - // TODO: We don't allow custom encodings for mutable types right now. - if (!type.hasTrait()) { + // TODO: We don't allow custom encodings for mutable types right now. + if (!type.hasTrait()) { + // Try overriding emission with callbacks. + for (const auto &callback : config.getTypePrinterCallbacks()) { + NumberingDialectWriter writer(*this); + if (succeeded(callback->write(type, writer))) + return; + } + + // Otherwise, if this type will be emitted using the dialect bytecode + // interface, perform a dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { NumberingDialectWriter writer(*this); if (succeeded(interface->writeType(type, writer))) return; diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -12,8 +12,8 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinDialect.h" -#include "BuiltinDialectBytecode.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinDialectBytecode.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" diff --git a/mlir/lib/IR/BuiltinDialectBytecode.h b/mlir/lib/IR/BuiltinDialectBytecode.h deleted file mode 100644 --- a/mlir/lib/IR/BuiltinDialectBytecode.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header defines hooks into the builtin dialect bytecode implementation.
-// -//===----------------------------------------------------------------------===// - -#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H -#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H - -namespace mlir { -class BuiltinDialect; - -namespace builtin_dialect_detail { -/// Add the interfaces necessary for encoding the builtin dialect components in -/// bytecode. -void addBytecodeInterface(BuiltinDialect *dialect); -} // namespace builtin_dialect_detail -} // namespace mlir - -#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "BuiltinDialectBytecode.h" +#include "mlir/IR/BuiltinDialectBytecode.h" #include "AttributeDetail.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/BuiltinAttributes.h" @@ -81,37 +81,38 @@ #include "mlir/IR/BuiltinDialectBytecode.cpp.inc" -/// This class implements the bytecode interface for the builtin dialect. 
-struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { - BuiltinDialectBytecodeInterface(Dialect *dialect) - : BytecodeDialectInterface(dialect) {} +} // namespace - //===--------------------------------------------------------------------===// - // Attributes +namespace mlir { +namespace builtin_dialect_detail { +//===--------------------------------------------------------------------===// +// Attributes - Attribute readAttribute(DialectBytecodeReader &reader) const override { - return ::readAttribute(getContext(), reader); - } +Attribute BuiltinDialectBytecodeInterface::readAttribute( + DialectBytecodeReader &reader) const { + return ::readAttribute(getContext(), reader); +} - LogicalResult writeAttribute(Attribute attr, - DialectBytecodeWriter &writer) const override { - return ::writeAttribute(attr, writer); - } +LogicalResult BuiltinDialectBytecodeInterface::writeAttribute( + Attribute attr, DialectBytecodeWriter &writer) const { + return ::writeAttribute(attr, writer); +} - //===--------------------------------------------------------------------===// - // Types +//===--------------------------------------------------------------------===// +// Types - Type readType(DialectBytecodeReader &reader) const override { - return ::readType(getContext(), reader); - } +Type BuiltinDialectBytecodeInterface::readType( + DialectBytecodeReader &reader) const { + return ::readType(getContext(), reader); +} - LogicalResult writeType(Type type, - DialectBytecodeWriter &writer) const override { - return ::writeType(type, writer); - } -}; -} // namespace +LogicalResult BuiltinDialectBytecodeInterface::writeType( + Type type, DialectBytecodeWriter &writer) const { + return ::writeType(type, writer); +} -void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) { +void addBytecodeInterface(BuiltinDialect *dialect) { dialect->addInterfaces(); } +} // namespace builtin_dialect_detail +} // namespace mlir diff --git 
a/mlir/test/Bytecode/bytecode_callback.mlir b/mlir/test/Bytecode/bytecode_callback.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=1.2" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_1_2 +// RUN: mlir-opt %s --test-bytecode-callback="test-dialect-version=2.0" -verify-diagnostics | FileCheck %s --check-prefix=VERSION_2_0 + +func.func @base_test(%arg0 : i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.cast"(%0) : (i32) -> f32 + return %1 : f32 +} + +// VERSION_1_2: Overriding IntegerType encoding... +// VERSION_1_2: Overriding parsing of IntegerType encoding... + +// VERSION_2_0-NOT: Overriding IntegerType encoding... +// VERSION_2_0-NOT: Overriding parsing of IntegerType encoding... diff --git a/mlir/test/Bytecode/bytecode_callback_full_override.mlir b/mlir/test/Bytecode/bytecode_callback_full_override.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_full_override.mlir @@ -0,0 +1,18 @@ +// RUN: not mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=5" 2>&1 | FileCheck %s + +// CHECK-NOT: failed to read bytecode +func.func @base_test(%arg0 : i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (i32, i32) -> i32 + %1 = "test.cast"(%0) : (i32) -> f32 + return %1 : f32 +} + +// ----- + +// CHECK-LABEL: error: unknown attribute code: 99 +// CHECK: failed to read bytecode +func.func @base_test(%arg0 : !test.i32) -> f32 { + %0 = "test.addi"(%arg0, %arg0) : (!test.i32, !test.i32) -> !test.i32 + %1 = "test.cast"(%0) : (!test.i32) -> f32 + return %1 : f32 +} diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_custom_attribute.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -split-input-file 
--test-bytecode-callback="callback-test=3" | FileCheck %s --check-prefix=TEST_3 +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=4" | FileCheck %s --check-prefix=TEST_4 + +"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> () + +// TEST_3: Overriding TestAttrParamsAttr encoding... +// TEST_3: Parsing builtin DenseIntElementsAttr encoding... +// TEST_3: "test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> () + +// ----- + +"test.versionedC"() <{attribute = dense<[42, 24]> : tensor<2xi32>}> : () -> () + +// TEST_4: Overriding parsing of TestAttrParamsAttr encoding... +// TEST_4: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> () diff --git a/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_custom_type.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=1" | FileCheck %s --check-prefix=TEST_1 +// RUN: mlir-opt %s -split-input-file --test-bytecode-callback="callback-test=2" | FileCheck %s --check-prefix=TEST_2 + +func.func @base_test(%arg0: !test.i32, %arg1: f32) { + return +} + +// TEST_1: Overriding TestI32Type encoding... +// TEST_1: Parsing builtin IntegerType encoding... +// TEST_1: func.func @base_test([[ARG0:%.+]]: i32, [[ARG1:%.+]]: f32) { + +// ----- + +func.func @base_test(%arg0: i32, %arg1: f32) { + return +} + +// TEST_2: Overriding parsing of TestI32Type encoding... 
+// TEST_2: func.func @base_test([[ARG0:%.+]]: !test.i32, [[ARG1:%.+]]: f32) { diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir --- a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir @@ -5,12 +5,12 @@ // Index //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX +// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=INDEX // INDEX: invalid Attribute index: 3 //===--------------------------------------------------------------------===// // Trailing Data //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA // TRAILING_DATA: trailing characters found after Attribute assembly format: trailing diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,9 +14,10 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H -#include "TestTypes.h" #include "TestAttributes.h" #include "TestInterfaces.h" +#include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -57,6 +58,19 @@ #include "TestOpsDialect.h.inc" namespace test { + +//===----------------------------------------------------------------------===// +// TestDialect version utilities 
+//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public mlir::DialectVersion { + TestDialectVersion() = default; + TestDialectVersion(uint32_t _major, uint32_t _minor) + : major(_major), minor(_minor){}; + uint32_t major = 2; + uint32_t minor = 0; +}; + // Define some classes to exercises the Properties feature. struct PropertiesWithCustomPrint { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,7 +10,6 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" -#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -119,15 +118,6 @@ registry.insert(); } -//===----------------------------------------------------------------------===// -// TestDialect version utilities -//===----------------------------------------------------------------------===// - -struct TestDialectVersion : public DialectVersion { - uint32_t major = 2; - uint32_t minor = 0; -}; - //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -152,7 +142,7 @@ }; namespace { -enum test_encoding { k_attr_params = 0 }; +enum test_encoding { k_attr_params = 0, k_test_i32 = 99 }; } // Test support for interacting with the Bytecode reader/writer. 
@@ -161,6 +151,29 @@
   TestBytecodeDialectInterface(Dialect *dialect)
       : BytecodeDialectInterface(dialect) {}
 
+  /// Encode TestI32Type as the single varint `k_test_i32`; every other type
+  /// returns failure().
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const final {
+    if (llvm::isa<TestI32Type>(type)) {
+      writer.writeVarInt(test_encoding::k_test_i32);
+      return success();
+    }
+    return failure();
+  }
+
+  /// Decode a type produced by writeType; returns a null Type if the varint
+  /// cannot be read or is not `k_test_i32`.
+  Type readType(DialectBytecodeReader &reader,
+                const DialectVersion &version_) const final {
+    uint64_t encoding;
+    if (failed(reader.readVarInt(encoding)))
+      return Type();
+    if (encoding == test_encoding::k_test_i32)
+      return TestI32Type::get(getContext());
+    return Type();
+  }
+
   LogicalResult writeAttribute(Attribute attr,
                                DialectBytecodeWriter &writer) const final {
     if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1311,8 +1311,9 @@
 }
 
 def TestAddIOp : TEST_Op<"addi"> {
-  let arguments = (ins I32:$op1, I32:$op2);
-  let results = (outs I32);
+  let arguments = (ins AnyTypeOf<[I32, TestI32]>:$op1,
+                       AnyTypeOf<[I32, TestI32]>:$op2);
+  let results = (outs AnyTypeOf<[I32, TestI32]>);
 }
 
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
@@ -3315,6 +3316,12 @@
   );
 }
 
+def TestVersionedOpC : TEST_Op<"versionedC"> {
+  let arguments = (ins AnyAttrOf<[TestAttrParams,
+                                  I32ElementsAttr]>:$attribute
+  );
+}
+
 //===----------------------------------------------------------------------===//
 // Test Properties
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -369,4 +369,8 @@
   let assemblyFormat = "`<` (`?`) : (struct($a, $b)^)? `>`";
 }
 
+def TestI32 : Test_Type<"TestI32"> {
+  let mnemonic = "i32";
+}
+
 #endif // TEST_TYPEDEFS
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
+  TestBytecodeCallbacks.cpp
   TestBuiltinAttributeInterfaces.cpp
   TestBuiltinDistinctAttributes.cpp
   TestClone.cpp
diff --git a/mlir/test/lib/IR/TestBytecodeCallbacks.cpp b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestBytecodeCallbacks.cpp
@@ -0,0 +1,396 @@
+//===- TestBytecodeCallbacks.cpp - Pass to test bytecode callback hooks --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinDialectBytecode.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OperationSupport.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/raw_ostream.h"
+#include <cassert>
+
+using namespace mlir;
+using namespace llvm;
+
+namespace {
+/// Command-line parser for `major.minor` strings, used by the
+/// `test-dialect-version` pass option.
+class TestDialectVersionParser : public cl::parser<test::TestDialectVersion> {
+public:
+  TestDialectVersionParser(cl::Option &O)
+      : cl::parser<test::TestDialectVersion>(O) {}
+
+  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+             test::TestDialectVersion &v) {
+    // Parse directly into uint32_t so negative or out-of-range components are
+    // rejected instead of silently wrapping.
+    uint32_t major, minor;
+    if (arg.split('.').first.getAsInteger(10, major))
+      return O.error("Invalid argument '" + arg + "'");
+    if (arg.split('.').second.getAsInteger(10, minor))
+      return O.error("Invalid argument '" + arg + "'");
+    v = test::TestDialectVersion(major, minor);
+    // Returns true on error.
+    return false;
+  }
+  static void print(raw_ostream &os, const test::TestDialectVersion &v) {
+    os << v.major << "." << v.minor;
+  }
+};
+
+/// This is a test pass which uses callbacks to encode attributes and types in a
+/// custom fashion.
+struct TestBytecodeCallbackPass
+    : public PassWrapper<TestBytecodeCallbackPass, OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestBytecodeCallbackPass)
+
+  StringRef getArgument() const final { return "test-bytecode-callback"; }
+  StringRef getDescription() const final {
+    return "Test encoding of a dialect type/attributes with a custom callback";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<test::TestDialect>();
+  }
+  TestBytecodeCallbackPass() = default;
+  TestBytecodeCallbackPass(const TestBytecodeCallbackPass &) {}
+
+  void runOnOperation() override {
+    switch (testKind) {
+    case (0):
+      return runTest0(getOperation());
+    case (1):
+      return runTest1(getOperation());
+    case (2):
+      return runTest2(getOperation());
+    case (3):
+      return runTest3(getOperation());
+    case (4):
+      return runTest4(getOperation());
+    case (5):
+      return runTest5(getOperation());
+    default:
+      // `testKind` is user-provided, so report unknown values as an error.
+      getOperation()->emitError(
+          "unhandled test kind for TestBytecodeCallbacks pass");
+      return signalPassFailure();
+    }
+  }
+
+  mlir::Pass::Option<test::TestDialectVersion, TestDialectVersionParser>
+      targetVersion{*this, "test-dialect-version",
+                    llvm::cl::desc(
+                        "Specifies the test dialect version to emit and parse"),
+                    cl::init(test::TestDialectVersion())};
+
+  mlir::Pass::Option<int> testKind{
+      *this, "callback-test",
+      llvm::cl::desc("Specifies the test kind to execute"), cl::init(0)};
+
+private:
+  /// Write `op` to bytecode with `writeConfig`, parse the result back with
+  /// `parseConfig`, and print the parsed operation to stdout. Emits an error
+  /// and signals pass failure if either step fails.
+  void doRoundtripWithConfigs(Operation *op,
+                              const BytecodeWriterConfig &writeConfig,
+                              const ParserConfig &parseConfig) {
+    std::string bytecode;
+    llvm::raw_string_ostream os(bytecode);
+    if (failed(writeBytecodeToFile(op, os, writeConfig))) {
+      op->emitError() << "failed to write bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    auto newModuleOp = parseSourceString(StringRef(bytecode), parseConfig);
+    if (!newModuleOp.get()) {
+      op->emitError() << "failed to read bytecode\n";
+      signalPassFailure();
+      return;
+    }
+    // Print the module to the output stream, so that we can filecheck the
+    // result.
+    newModuleOp->print(llvm::outs());
+    return;
+  }
+
+  // Test0: let's assume that versions older than 2.0 were relying on a special
+  // integer type of the builtin dialect that is now deprecated. Assume
+  // that its encoding was made by two varInts, the first was the ID (999) and
+  // the second contained width and signedness info. We can emit it using a
+  // callback emitting a custom encoding, and parse it back with a custom parser
+  // reading the same encoding. Note that the ID 999 does not correspond to a
+  // valid integer type in the current encodings of builtin types.
+  void runTest0(Operation *op) {
+    test::TestDialectVersion targetEmissionVersion = targetVersion;
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [&](Type entryValue, DialectBytecodeWriter &writer) -> LogicalResult {
+          // Do not override anything if version is 2.0 or newer.
+          if (targetEmissionVersion.major >= 2)
+            return failure();
+
+          // For version less than 2.0, override the encoding of IntegerType.
+          if (auto type = llvm::dyn_cast<IntegerType>(entryValue)) {
+            llvm::outs() << "Overriding IntegerType encoding...\n";
+            writer.writeVarInt(/* IntegerType */ 999);
+            writer.writeVarInt(type.getWidth() << 2 | type.getSignedness());
+            return success();
+          }
+          return failure();
+        });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    parseConfig.attachTypeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Type &entry) -> LogicalResult {
+          // Get test dialect version from the version map.
+          assert(
+              versionMap.contains("test") &&
+              "expected versionMap to contain all the available version info");
+          test::TestDialectVersion &version =
+              static_cast<test::TestDialectVersion &>(*versionMap.at("test"));
+
+          // TODO: once back-deployment is formally supported,
+          // targetEmissionVersion will be encoded in the bytecode file, and
+          // exposed through the versionMap. Right now though this is not yet
+          // supported. For the purpose of the test, just use
+          // `targetEmissionVersion`.
+          (void)version;
+          if (targetEmissionVersion.major >= 2)
+            return success();
+
+          uint64_t encoding;
+          if (failed(reader.readVarInt(encoding)) || encoding != 999)
+            return success();
+          llvm::outs() << "Overriding parsing of IntegerType encoding...\n";
+          uint64_t widthAndSignedness;
+          // Returning success() without setting `entry` falls through to the
+          // rest of the parsing code path.
+          if (failed(reader.readVarInt(widthAndSignedness)))
+            return success();
+          auto width = static_cast<unsigned>(widthAndSignedness >> 2);
+          auto signedness = static_cast<IntegerType::SignednessSemantics>(
+              widthAndSignedness & 0x3);
+          entry = IntegerType::get(ctx, width, signedness);
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test1: When writing bytecode, we override the encoding of TestI32Type with
+  // the encoding of builtin IntegerType. At parsing, we use such encoding to
+  // read the type and assemble the builtin IntegerType.
+  void runTest1(Operation *op) {
+    auto builtin = op->getContext()->getOrLoadDialect<mlir::BuiltinDialect>();
+    auto iface = builtin->getRegisteredInterface<
+        mlir::builtin_dialect_detail::BuiltinDialectBytecodeInterface>();
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachTypeCallback(
+        [&](Type entryValue, DialectBytecodeWriter &writer) -> LogicalResult {
+          // Emit TestI32Type using the builtin dialect encoding.
+          if (llvm::isa<test::TestI32Type>(entryValue)) {
+            llvm::outs() << "Overriding TestI32Type encoding...\n";
+            auto builtinI32Type =
+                IntegerType::get(op->getContext(), 32,
+                                 IntegerType::SignednessSemantics::Signless);
+            if (succeeded(iface->writeType(builtinI32Type, writer)))
+              return success();
+          }
+          return failure();
+        });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    parseConfig.attachTypeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Type &entry) -> LogicalResult {
+          // Override only the case where the return type of the builtin reader
+          // is an i32 and ignore all the other cases, since we want to still
+          // use TestDialect normal codepath to parse the other types.
+          Type builtinType = iface->readType(reader);
+          if (auto integerType =
+                  llvm::dyn_cast_or_null<IntegerType>(builtinType)) {
+            if (integerType.getWidth() == 32 && integerType.isSignless()) {
+              llvm::outs() << "Parsing builtin IntegerType encoding...\n";
+              entry = builtinType;
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test2: When writing bytecode, we write standard builtin IntegerTypes. At
+  // parsing, we use the encoding of IntegerType to intercept all i32. Then,
+  // instead of creating i32s, we assemble TestI32Type and return it.
+  void runTest2(Operation *op) {
+    auto builtin = op->getContext()->getOrLoadDialect<mlir::BuiltinDialect>();
+    auto iface = builtin->getRegisteredInterface<
+        mlir::builtin_dialect_detail::BuiltinDialectBytecodeInterface>();
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/true);
+    parseConfig.attachTypeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Type &entry) -> LogicalResult {
+          Type builtinType = iface->readType(reader);
+          if (auto integerType =
+                  llvm::dyn_cast_or_null<IntegerType>(builtinType)) {
+            if (integerType.getWidth() == 32 && integerType.isSignless()) {
+              llvm::outs() << "Overriding parsing of TestI32Type encoding...\n";
+              entry = test::TestI32Type::get(ctx);
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test3: When writing bytecode, we override the encoding of
+  // TestAttrParamsAttr with the encoding of builtin DenseIntElementsAttr. At
+  // parsing, we use such encoding to read the attribute and assemble the
+  // builtin DenseIntElementsAttr.
+  void runTest3(Operation *op) {
+    auto i32Type = IntegerType::get(op->getContext(), 32,
+                                    IntegerType::SignednessSemantics::Signless);
+    auto builtin = op->getContext()->getOrLoadDialect<mlir::BuiltinDialect>();
+    auto iface = builtin->getRegisteredInterface<
+        mlir::builtin_dialect_detail::BuiltinDialectBytecodeInterface>();
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [&](Attribute entryValue,
+            DialectBytecodeWriter &writer) -> LogicalResult {
+          // Emit TestAttrParamsAttr using the builtin dialect encoding.
+          if (auto testParamAttrs =
+                  llvm::dyn_cast<test::TestAttrParamsAttr>(entryValue)) {
+            llvm::outs() << "Overriding TestAttrParamsAttr encoding...\n";
+            auto denseAttr = DenseIntElementsAttr::get(
+                RankedTensorType::get({2}, i32Type),
+                {testParamAttrs.getV0(), testParamAttrs.getV1()});
+            if (succeeded(iface->writeAttribute(denseAttr, writer)))
+              return success();
+          }
+          return failure();
+        });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    parseConfig.attachAttributeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Attribute &entry) -> LogicalResult {
+          // Override only the case where the builtin reader returns a
+          // DenseIntElementsAttr with shape [2] and i32 elements; all other
+          // cases fall through to the default parsing code path.
+          Attribute builtinAttr = iface->readAttribute(reader);
+          if (auto denseAttr =
+                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
+            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
+                denseAttr.getElementType() == i32Type) {
+              llvm::outs()
+                  << "Parsing builtin DenseIntElementsAttr encoding...\n";
+              entry = denseAttr;
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test4: When writing bytecode, we write standard builtin
+  // DenseIntElementsAttr. At parsing, we use the encoding of
+  // DenseIntElementsAttr to intercept all ElementsAttr that have shaped type of
+  // <2xi32>. Instead of assembling a DenseIntElementsAttr, we assemble
+  // TestAttrParamsAttr and return it.
+  void runTest4(Operation *op) {
+    auto i32Type = IntegerType::get(op->getContext(), 32,
+                                    IntegerType::SignednessSemantics::Signless);
+    auto builtin = op->getContext()->getOrLoadDialect<mlir::BuiltinDialect>();
+    auto iface = builtin->getRegisteredInterface<
+        mlir::builtin_dialect_detail::BuiltinDialectBytecodeInterface>();
+    BytecodeWriterConfig writeConfig;
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    parseConfig.attachAttributeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Attribute &entry) -> LogicalResult {
+          // Override only the case where the builtin reader returns a
+          // DenseIntElementsAttr with shape [2] and i32 elements; all other
+          // cases fall through to the default parsing code path.
+          Attribute builtinAttr = iface->readAttribute(reader);
+          if (auto denseAttr =
+                  llvm::dyn_cast_or_null<DenseIntElementsAttr>(builtinAttr)) {
+            if (denseAttr.getType().getShape() == ArrayRef<int64_t>(2) &&
+                denseAttr.getElementType() == i32Type) {
+              llvm::outs()
+                  << "Overriding parsing of TestAttrParamsAttr encoding...\n";
+              int v0 = denseAttr.getValues<IntegerAttr>()[0].getInt();
+              int v1 = denseAttr.getValues<IntegerAttr>()[1].getInt();
+              entry = test::TestAttrParamsAttr::get(ctx, v0, v1);
+            }
+          }
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+
+  // Test5: When writing bytecode, we want TestDialect to use nothing else than
+  // the builtin types and attributes and take full control of the encoding,
+  // returning failure if any type or attribute is not part of builtin.
+  void runTest5(Operation *op) {
+    auto builtin = op->getContext()->getOrLoadDialect<mlir::BuiltinDialect>();
+    auto iface = builtin->getRegisteredInterface<
+        mlir::builtin_dialect_detail::BuiltinDialectBytecodeInterface>();
+    BytecodeWriterConfig writeConfig;
+    writeConfig.attachAttributeCallback(
+        [&](Attribute attr, DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeAttribute(attr, writer);
+        });
+    writeConfig.attachTypeCallback(
+        [&](Type type, DialectBytecodeWriter &writer) -> LogicalResult {
+          return iface->writeType(type, writer);
+        });
+    ParserConfig parseConfig(op->getContext(), /*verifyAfterParse=*/false);
+    parseConfig.attachAttributeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Attribute &entry) -> LogicalResult {
+          Attribute builtinAttr = iface->readAttribute(reader);
+          if (!builtinAttr)
+            return failure();
+          entry = builtinAttr;
+          return success();
+        });
+    parseConfig.attachTypeBytecodeCallback(
+        [&](MLIRContext *ctx, DialectBytecodeReader &reader,
+            const auto &versionMap,
+            Type &entry) -> LogicalResult {
+          Type builtinType = iface->readType(reader);
+          if (!builtinType) {
+            return failure();
+          }
+          entry = builtinType;
+          return success();
+        });
+    doRoundtripWithConfigs(op, writeConfig, parseConfig);
+    return;
+  }
+};
+} // namespace
+
+namespace mlir {
+void registerTestBytecodeCallbackPasses() {
+  PassRegistration<TestBytecodeCallbackPass>();
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -43,6 +43,7 @@
 void registerRegionTestPasses();
 void registerTestAffineDataCopyPass();
 void registerTestAffineReifyValueBoundsPass();
+void registerTestBytecodeCallbackPasses();
 void registerTestDecomposeAffineOpPass();
 void registerTestAffineLoopUnswitchingPass();
 void registerTestAllReduceLoweringPass();
@@ -164,6 +165,7 @@
   registerTestDecomposeAffineOpPass();
   registerTestAffineLoopUnswitchingPass();
   registerTestAllReduceLoweringPass();
+  registerTestBytecodeCallbackPasses();
   registerTestFunc();
   registerTestGpuMemoryPromotionPass();
   registerTestLoopPermutationPass();