diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -207,7 +207,26 @@ ##### Dialect Defined Encoding -TODO: This is not yet supported. +In addition to the assembly format fallback, dialects may also provide a custom +encoding for their attributes and types. Custom encodings are very beneficial in +that they are significantly smaller and faster to read and write. + +Dialects can opt in to providing custom encodings by implementing the +`BytecodeDialectInterface`. This interface provides hooks, namely +`readAttribute`/`readType` and `writeAttribute`/`writeType`, that will be used +by the bytecode reader and writer. These hooks are provided a reader and writer +implementation that can be used to encode various constructs in the underlying +bytecode format. A unique feature of this interface is that dialects may choose +to only encode a subset of their attributes and types in a custom bytecode +format, which can simplify adding new or experimental components that aren't +fully baked. + +When implementing the bytecode interface, dialects are responsible for all +aspects of the encoding. This includes the indicator for which kind of attribute +or type is being encoded; the bytecode reader will only know that it has +encountered an attribute or type of a given dialect, as the bytecode doesn't +encode any further information. As such, a common encoding idiom is to use a +leading `varint` code to indicate how the attribute or type was encoded. ### IR Section diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -0,0 +1,220 @@ +//===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines various interfaces and utilities necessary for dialects +// to hook into bytecode serialization. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H +#define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectInterface.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/Twine.h" + +namespace mlir { +//===----------------------------------------------------------------------===// +// DialectBytecodeReader +//===----------------------------------------------------------------------===// + +/// This class defines a virtual interface for reading a bytecode stream, +/// providing hooks into the bytecode reader. As such, this class should only be +/// derived and defined by the main bytecode reader, users (i.e. dialects) +/// should generally only interact with this class via the +/// BytecodeDialectInterface below. +class DialectBytecodeReader { +public: + virtual ~DialectBytecodeReader() = default; + + /// Emit an error to the reader. + virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + /// Read out a list of elements, invoking the provided callback for each + /// element. 
The callback function may be in any of the following forms: + /// * LogicalResult(T &) + /// * FailureOr<T>() + template <typename T, typename CallbackFn> + LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) { + uint64_t size; + if (failed(readVarInt(size))) + return failure(); + result.reserve(size); + + for (uint64_t i = 0; i < size; ++i) { + // Check if the callback uses FailureOr, or populates the result by + // reference. + if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) { + T element = {}; + if (failed(callback(element))) + return failure(); + result.emplace_back(std::move(element)); + } else { + FailureOr<T> element = callback(); + if (failed(element)) + return failure(); + result.emplace_back(std::move(*element)); + } + } + return success(); + } + + /// Read an attribute reference; the typed overload also checks its kind. + virtual LogicalResult readAttribute(Attribute &result) = 0; + template <typename T> + LogicalResult readAttributes(SmallVectorImpl<T> &attrs) { + return readList(attrs, [this](T &attr) { return readAttribute(attr); }); + } + template <typename T> + LogicalResult readAttribute(T &result) { + Attribute baseResult; + if (failed(readAttribute(baseResult))) + return failure(); + if ((result = baseResult.dyn_cast<T>())) + return success(); + return emitError() << "expected attribute of type: " + << llvm::getTypeName<T>() << ", but got: " << baseResult; + } + + /// Read a reference to the given type. + virtual LogicalResult readType(Type &result) = 0; + template <typename T> + LogicalResult readTypes(SmallVectorImpl<T> &types) { + return readList(types, [this](T &type) { return readType(type); }); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + /// Read a variable width integer. + // TODO: Add a signed variant when necessary. + virtual LogicalResult readVarInt(uint64_t &result) = 0; + + /// Read a string from the bytecode. 
+ virtual LogicalResult readString(StringRef &result) = 0; +}; + +//===----------------------------------------------------------------------===// +// DialectBytecodeWriter +//===----------------------------------------------------------------------===// + +/// This class defines a virtual interface for writing to a bytecode stream, +/// providing hooks into the bytecode writer. As such, this class should only be +/// derived and defined by the main bytecode writer, users (i.e. dialects) +/// should generally only interact with this class via the +/// BytecodeDialectInterface below. +class DialectBytecodeWriter { +public: + virtual ~DialectBytecodeWriter() = default; + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + /// Write out a list of elements, invoking the provided callback for each + /// element. + template + void writeList(RangeT &&range, CallbackFn &&callback) { + writeVarInt(llvm::size(range)); + for (auto &element : range) + callback(element); + } + + /// Write a reference to the given attribute. + virtual void writeAttribute(Attribute attr) = 0; + template + void writeAttributes(ArrayRef attrs) { + writeList(attrs, [this](T attr) { writeAttribute(attr); }); + } + + /// Write a reference to the given type. + virtual void writeType(Type type) = 0; + template + void writeTypes(ArrayRef types) { + writeList(types, [this](T type) { writeType(type); }); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + /// Write a variable width integer to the output stream. This should be the + /// preferred method for emitting integers whenever possible. + // TODO: Add a signed variant when necessary. 
+ virtual void writeVarInt(uint64_t value) = 0; + + /// Write a string to the bytecode, which is owned by the caller and is + /// guaranteed to not die before the end of the bytecode process. This should + /// only be called if such a guarantee can be made, such as when the string is + /// owned by an attribute or type. + virtual void writeOwnedString(StringRef str) = 0; +}; + +//===----------------------------------------------------------------------===// +// BytecodeDialectInterface +//===----------------------------------------------------------------------===// + +class BytecodeDialectInterface + : public DialectInterface::Base { +public: + using Base::Base; + + //===--------------------------------------------------------------------===// + // Reading + //===--------------------------------------------------------------------===// + + /// Read an attribute belonging to this dialect from the given reader. This + /// method should return null in the case of failure. + virtual Attribute readAttribute(DialectBytecodeReader &reader) const { + reader.emitError() << "dialect " << getDialect()->getNamespace() + << " does not support reading attributes from bytecode"; + return Attribute(); + } + + /// Read a type belonging to this dialect from the given reader. This method + /// should return null in the case of failure. + virtual Type readType(DialectBytecodeReader &reader) const { + reader.emitError() << "dialect " << getDialect()->getNamespace() + << " does not support reading types from bytecode"; + return Type(); + } + + //===--------------------------------------------------------------------===// + // Writing + //===--------------------------------------------------------------------===// + + /// Write the given attribute, which belongs to this dialect, to the given + /// writer. This method may return failure to indicate that the given + /// attribute could not be encoded, in which case the textual format will be + /// used to encode this attribute instead. 
+ virtual LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const { + return failure(); + } + + /// Write the given type, which belongs to this dialect, to the given writer. + /// This method may return failure to indicate that the given type could not + /// be encoded, in which case the textual format will be used to encode this + /// type instead. + virtual LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const { + return failure(); + } +}; + +} // namespace mlir + +#endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H diff --git a/mlir/include/mlir/IR/DialectInterface.h b/mlir/include/mlir/IR/DialectInterface.h --- a/mlir/include/mlir/IR/DialectInterface.h +++ b/mlir/include/mlir/IR/DialectInterface.h @@ -50,6 +50,9 @@ /// Return the dialect that this interface represents. Dialect *getDialect() const { return dialect; } + /// Return the context that holds the parent dialect of this interface. + MLIRContext *getContext() const; + /// Return the derived interface id. TypeID getID() const { return interfaceID; } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -12,6 +12,7 @@ #include "mlir/Bytecode/BytecodeReader.h" #include "../Encoding.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/OpImplementation.h" @@ -66,7 +67,7 @@ /// Emit an error using the given arguments. template - LogicalResult emitError(Args &&...args) const { + InFlightDiagnostic emitError(Args &&...args) const { return ::emitError(fileLoc).append(std::forward(args)...); } @@ -326,6 +327,11 @@ "-allow-unregistered-dialect with the MLIR tool used."); } dialect = loadedDialect; + + // If the dialect was actually loaded, check to see if it has a bytecode + // interface. 
+ if (loadedDialect) + interface = dyn_cast(loadedDialect); return success(); } @@ -333,6 +339,11 @@ /// load, nullptr if we failed to load, otherwise the loaded dialect. Optional dialect; + /// The bytecode interface of the dialect, or nullptr if the dialect does not + /// implement the bytecode interface. This field should only be checked if the + /// `dialect` field is non-None. + const BytecodeDialectInterface *interface = nullptr; + /// The name of the dialect. StringRef name; }; @@ -397,7 +408,8 @@ using TypeEntry = Entry; public: - AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {} + AttrTypeReader(StringSectionReader &stringReader, Location fileLoc) + : stringReader(stringReader), fileLoc(fileLoc) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -456,6 +468,10 @@ LogicalResult parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType); + /// The string section reader used to resolve string references when parsing + /// custom encoded attribute/type entries. + StringSectionReader &stringReader; + /// The set of attribute and type entries. SmallVector attributes; SmallVector types; @@ -463,6 +479,47 @@ /// A location used for error emission. 
Location fileLoc; }; + +class DialectReader : public DialectBytecodeReader { +public: + DialectReader(AttrTypeReader &attrTypeReader, + StringSectionReader &stringReader, EncodingReader &reader) + : attrTypeReader(attrTypeReader), stringReader(stringReader), + reader(reader) {} + + InFlightDiagnostic emitError(const Twine &msg) override { + return reader.emitError(msg); + } + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + LogicalResult readAttribute(Attribute &result) override { + return attrTypeReader.parseAttribute(reader, result); + } + + LogicalResult readType(Type &result) override { + return attrTypeReader.parseType(reader, result); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + LogicalResult readVarInt(uint64_t &result) override { + return reader.parseVarInt(result); + } + + LogicalResult readString(StringRef &result) override { + return stringReader.parseString(reader, result); + } + +private: + AttrTypeReader &attrTypeReader; + StringSectionReader &stringReader; + EncodingReader &reader; +}; } // namespace LogicalResult @@ -486,7 +543,7 @@ size_t currentIndex = 0, endIndex = range.size(); // Parse an individual entry. 
- auto parseEntryFn = [&](BytecodeDialect *dialect) { + auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { auto &entry = range[currentIndex++]; uint64_t entrySize; @@ -548,8 +605,7 @@ } if (!reader.empty()) { - (void)reader.emitError("unexpected trailing bytes after " + entryType + - " entry"); + reader.emitError("unexpected trailing bytes after " + entryType + " entry"); return T(); } return entry.entry; @@ -584,8 +640,22 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - // FIXME: Add support for reading custom attribute/type encodings. - return reader.emitError("unexpected Attribute encoding"); + if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + return failure(); + + // Ensure that the dialect implements the bytecode interface. + if (!entry.dialect->interface) { + return reader.emitError("dialect '", entry.dialect->name, + "' does not implement the bytecode interface"); + } + + // Ask the dialect to parse the entry. + DialectReader dialectReader(*this, stringReader, reader); + if constexpr (std::is_same_v) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + return success(!!entry.entry); } //===----------------------------------------------------------------------===// @@ -597,7 +667,7 @@ class BytecodeReader { public: BytecodeReader(Location fileLoc, const ParserConfig &config) - : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc), + : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. 
forwardRefOpState(UnknownLoc::get(config.getContext()), diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -9,6 +9,7 @@ #include "mlir/Bytecode/BytecodeWriter.h" #include "../Encoding.h" #include "IRNumbering.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/CachedHashString.h" @@ -358,22 +359,78 @@ //===----------------------------------------------------------------------===// // Attributes and Types +namespace { +class DialectWriter : public DialectBytecodeWriter { +public: + DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, + StringSectionBuilder &stringSection) + : emitter(emitter), numberingState(numberingState), + stringSection(stringSection) {} + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + void writeAttribute(Attribute attr) override { + emitter.emitVarInt(numberingState.getNumber(attr)); + } + void writeType(Type type) override { + emitter.emitVarInt(numberingState.getNumber(type)); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } + + void writeOwnedString(StringRef str) override { + emitter.emitVarInt(stringSection.insert(str)); + } + +private: + EncodingEmitter &emitter; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; +}; +} // namespace + void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { EncodingEmitter attrTypeEmitter; EncodingEmitter offsetEmitter; 
offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes())); offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes())); + // The writer used when emitting using a custom bytecode encoding. + DialectWriter dialectWriter(attrTypeEmitter, numberingState, stringSection); + // A functor used to emit an attribute or type entry. uint64_t prevOffset = 0; auto emitAttrOrType = [&](auto &entry) { - // TODO: Allow dialects to provide more optimal implementations of attribute - // and type encodings. + auto entryValue = entry.getValue(); + + // First, try to emit this entry using the dialect bytecode interface. bool hasCustomEncoding = false; + if (const BytecodeDialectInterface *interface = entry.dialect->interface) { + if constexpr (std::is_same_v, Type>) { + // TODO: We don't currently support custom encoded mutable types. + hasCustomEncoding = + !entryValue.template hasTrait() && + succeeded(interface->writeType(entryValue, dialectWriter)); + } else { + // TODO: We don't currently support custom encoded mutable attributes. + hasCustomEncoding = + !entryValue.template hasTrait() && + succeeded(interface->writeAttribute(entryValue, dialectWriter)); + } + } - // Emit the entry using the textual format. - raw_emitter_ostream(attrTypeEmitter) << entry.getValue(); - attrTypeEmitter.emitByte(0); + // If the entry was not emitted using the dialect interface, emit it using + // the textual format. + if (!hasCustomEncoding) { + raw_emitter_ostream(attrTypeEmitter) << entryValue; + attrTypeEmitter.emitByte(0); + } // Record the offset of this entry. 
uint64_t curOffset = attrTypeEmitter.size(); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h --- a/mlir/lib/Bytecode/Writer/IRNumbering.h +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -18,6 +18,7 @@ #include "llvm/ADT/MapVector.h" namespace mlir { +class BytecodeDialectInterface; class BytecodeWriterConfig; namespace bytecode { @@ -90,8 +91,8 @@ /// The number assigned to the dialect. unsigned number; - /// The loaded dialect, or nullptr if the dialect isn't loaded. - Dialect *dialect = nullptr; + /// The bytecode dialect interface of the dialect if defined. + const BytecodeDialectInterface *interface = nullptr; }; //===----------------------------------------------------------------------===// @@ -147,6 +148,10 @@ } private: + /// This class is used to provide a fake dialect writer for numbering nested + /// attributes and types. + struct NumberingDialectWriter; + /// Number the given IR unit for bytecode emission. void number(Attribute attr); void number(Block &block); diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "IRNumbering.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" @@ -14,6 +15,28 @@ using namespace mlir; using namespace mlir::bytecode::detail; +//===----------------------------------------------------------------------===// +// NumberingDialectWriter +//===----------------------------------------------------------------------===// + +struct IRNumberingState::NumberingDialectWriter : public DialectBytecodeWriter { + NumberingDialectWriter(IRNumberingState &state) : state(state) {} + + void writeAttribute(Attribute attr) override { 
state.number(attr); } + void writeType(Type type) override { state.number(type); } + + /// Stubbed out methods that are not used for numbering. + void writeVarInt(uint64_t) override {} + void writeOwnedString(StringRef) override { + // TODO: It might be nice to prenumber strings and sort by the number of + // references. This could potentially be useful for optimizing things like + // file locations. + } + + /// The parent numbering state that is populated by this writer. + IRNumberingState &state; +}; + //===----------------------------------------------------------------------===// // IR Numbering //===----------------------------------------------------------------------===// @@ -138,10 +161,22 @@ // have a registered dialect when it got created. We don't want to encode this // as the builtin OpaqueAttr, we want to encode it as if the dialect was // actually loaded. - if (OpaqueAttr opaqueAttr = attr.dyn_cast()) + if (OpaqueAttr opaqueAttr = attr.dyn_cast()) { numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); - else - numbering->dialect = &numberDialect(&attr.getDialect()); + return; + } + numbering->dialect = &numberDialect(&attr.getDialect()); + + // If this attribute will be emitted using the bytecode format, perform a + // dummy writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { + // TODO: We don't allow custom encodings for mutable attributes right now. + if (attr.hasTrait()) + return; + + NumberingDialectWriter writer(*this); + (void)interface->writeAttribute(attr, writer); + } } void IRNumberingState::number(Block &block) { @@ -164,7 +199,7 @@ DialectNumbering *&numbering = registeredDialects[dialect]; if (!numbering) { numbering = &numberDialect(dialect->getNamespace()); - numbering->dialect = dialect; + numbering->interface = dyn_cast(dialect); } return *numbering; } @@ -244,8 +279,20 @@ // registered dialect when it got created. 
We don't want to encode this as the // builtin OpaqueType, we want to encode it as if the dialect was actually // loaded. - if (OpaqueType opaqueType = type.dyn_cast()) + if (OpaqueType opaqueType = type.dyn_cast()) { numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); - else - numbering->dialect = &numberDialect(&type.getDialect()); + return; + } + numbering->dialect = &numberDialect(&type.getDialect()); + + // If this type will be emitted using the bytecode format, perform a dummy + // writing to number any nested components. + if (const auto *interface = numbering->dialect->interface) { + // TODO: We don't allow custom encodings for mutable types right now. + if (type.hasTrait()) + return; + + NumberingDialectWriter writer(*this); + (void)interface->writeType(type, writer); + } } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinDialect.h" +#include "BuiltinDialectBytecode.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -117,6 +118,7 @@ auto &blobInterface = addInterface(); addInterface(blobInterface); + builtin_dialect_detail::addBytecodeInterface(this); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialectBytecode.h b/mlir/lib/IR/BuiltinDialectBytecode.h new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/BuiltinDialectBytecode.h @@ -0,0 +1,26 @@ +//===- BuiltinDialectBytecode.h - MLIR Bytecode Implementation --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines hooks into the builtin dialect bytecode implementation. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H +#define LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H + +namespace mlir { +class BuiltinDialect; + +namespace builtin_dialect_detail { +/// Add the interfaces necessary for encoding the builtin dialect components in +/// bytecode. +void addBytecodeInterface(BuiltinDialect *dialect); +} // namespace builtin_dialect_detail +} // namespace mlir + +#endif // LIB_MLIR_IR_BUILTINDIALECTBYTECODE_H diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -0,0 +1,269 @@ +//===- BuiltinDialectBytecode.cpp - Builtin Bytecode Implementation -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "BuiltinDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Encoding +//===----------------------------------------------------------------------===// + +namespace { +namespace builtin_encoding { +/// This enum contains marker codes used to indicate which attribute is +/// currently being decoded, and how it should be decoded. 
The order of these +/// codes should generally be unchanged, as any changes will inevitably break +/// compatibility with older bytecode. +enum AttributeCode { + /// ArrayAttr { + /// elements: Attribute[] + /// } + /// + kArrayAttr = 0, + + /// DictionaryAttr { + /// attrs: <name: StringAttr, value: Attribute>[] + /// } + kDictionaryAttr = 1, + + /// StringAttr { + /// value: string + /// } + kStringAttr = 2, +}; + +/// This enum contains marker codes used to indicate which type is currently +/// being decoded, and how it should be decoded. The order of these codes should +/// generally be unchanged, as any changes will inevitably break compatibility +/// with older bytecode. +enum TypeCode { + /// IntegerType { + /// widthAndSignedness: varint // (width << 2) | (signedness) + /// } + /// + kIntegerType = 0, + + /// IndexType { + /// } + /// + kIndexType = 1, + + /// FunctionType { + /// inputs: Type[], + /// results: Type[] + /// } + /// + kFunctionType = 2, +}; + +} // namespace builtin_encoding +} // namespace + +//===----------------------------------------------------------------------===// +// BuiltinDialectBytecodeInterface +//===----------------------------------------------------------------------===// + +namespace { +/// This class implements the bytecode interface for the builtin dialect. 
+/// Bytecode interface for the builtin dialect. Provides custom encodings for
+/// ArrayAttr, DictionaryAttr and StringAttr, and for IntegerType, IndexType and
+/// FunctionType. Every encoding starts with a varint `builtin_encoding` code
+/// that identifies which attribute or type follows.
+struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
+  BuiltinDialectBytecodeInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Attributes
+
+  /// Read the leading code and dispatch to the matching `read*Attr` helper.
+  /// Returns a null attribute on failure.
+  Attribute readAttribute(DialectBytecodeReader &reader) const override;
+  ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+  DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
+  StringAttr readStringAttr(DialectBytecodeReader &reader) const;
+
+  /// Encode `attr` if it is one of the supported kinds; returns failure for any
+  /// other attribute.
+  LogicalResult writeAttribute(Attribute attr,
+                               DialectBytecodeWriter &writer) const override;
+  void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+  void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
+  void write(StringAttr attr, DialectBytecodeWriter &writer) const;
+
+  //===--------------------------------------------------------------------===//
+  // Types
+
+  /// Read the leading code and dispatch to the matching type reader. Returns a
+  /// null type on failure.
+  Type readType(DialectBytecodeReader &reader) const override;
+  IntegerType readIntegerType(DialectBytecodeReader &reader) const;
+  FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+
+  /// Encode `type` if it is one of the supported kinds; returns failure for any
+  /// other type.
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const override;
+  void write(IntegerType type, DialectBytecodeWriter &writer) const;
+  void write(FunctionType type, DialectBytecodeWriter &writer) const;
+};
+} // namespace
+
+/// Register the bytecode interface on the builtin dialect.
+void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
+  dialect->addInterfaces<BuiltinDialectBytecodeInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Reader
+
+Attribute BuiltinDialectBytecodeInterface::readAttribute(
+    DialectBytecodeReader &reader) const {
+  uint64_t code;
+  if (failed(reader.readVarInt(code)))
+    return Attribute();
+  switch (code) {
+  case builtin_encoding::kArrayAttr:
+    return readArrayAttr(reader);
+  case builtin_encoding::kDictionaryAttr:
+    return readDictionaryAttr(reader);
+  case builtin_encoding::kStringAttr:
+    return readStringAttr(reader);
+  default:
+    // Unknown code: report it through the reader and return a null attribute.
+    reader.emitError() << "unknown builtin attribute code: " << code;
+    return Attribute();
+  }
+}
+
+/// An array is encoded as a list of element attributes.
+ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Attribute> elements;
+  if (failed(reader.readAttributes(elements)))
+    return ArrayAttr();
+  return ArrayAttr::get(getContext(), elements);
+}
+
+/// A dictionary is encoded as a list of (name, value) attribute pairs.
+DictionaryAttr BuiltinDialectBytecodeInterface::readDictionaryAttr(
+    DialectBytecodeReader &reader) const {
+  // Reads one entry: the StringAttr name followed by the value attribute.
+  auto readNamedAttr = [&]() -> FailureOr<NamedAttribute> {
+    StringAttr name;
+    Attribute value;
+    if (failed(reader.readAttribute(name)) ||
+        failed(reader.readAttribute(value)))
+      return failure();
+    return NamedAttribute(name, value);
+  };
+  SmallVector<NamedAttribute> attrs;
+  if (failed(reader.readList(attrs, readNamedAttr)))
+    return DictionaryAttr();
+  return DictionaryAttr::get(getContext(), attrs);
+}
+
+/// A string attribute is encoded as a single string.
+StringAttr BuiltinDialectBytecodeInterface::readStringAttr(
+    DialectBytecodeReader &reader) const {
+  StringRef string;
+  if (failed(reader.readString(string)))
+    return StringAttr();
+  return StringAttr::get(getContext(), string);
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
+    Attribute attr, DialectBytecodeWriter &writer) const {
+  // Dispatch to the `write` overload for each supported attribute kind.
+  return TypeSwitch<Attribute, LogicalResult>(attr)
+      .Case<ArrayAttr, DictionaryAttr, StringAttr>([&](auto attr) {
+        write(attr, writer);
+        return success();
+      })
+      .Default([&](Attribute) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    ArrayAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kArrayAttr);
+  writer.writeAttributes(attr.getValue());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    DictionaryAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kDictionaryAttr);
+  // Name first, then value: the order readDictionaryAttr reads them back in.
+  writer.writeList(attr.getValue(), [&](NamedAttribute attr) {
+    writer.writeAttribute(attr.getName());
+    writer.writeAttribute(attr.getValue());
+  });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    StringAttr attr, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kStringAttr);
+  writer.writeOwnedString(attr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Reader
+
+Type BuiltinDialectBytecodeInterface::readType(
+    DialectBytecodeReader &reader) const {
+  uint64_t code;
+  if (failed(reader.readVarInt(code)))
+    return Type();
+  switch (code) {
+  case builtin_encoding::kIntegerType:
+    return readIntegerType(reader);
+  case builtin_encoding::kIndexType:
+    // IndexType has no payload beyond its code.
+    return IndexType::get(getContext());
+
+  case builtin_encoding::kFunctionType:
+    return readFunctionType(reader);
+  default:
+    // Unknown code: report it through the reader and return a null type.
+    reader.emitError() << "unknown builtin type code: " << code;
+    return Type();
+  }
+}
+
+/// An integer type is a single varint: the width in the bits above the low
+/// two, and the signedness in the low two bits.
+IntegerType BuiltinDialectBytecodeInterface::readIntegerType(
+    DialectBytecodeReader &reader) const {
+  uint64_t encoding;
+  if (failed(reader.readVarInt(encoding)))
+    return IntegerType();
+  return IntegerType::get(
+      getContext(), encoding >> 2,
+      static_cast<IntegerType::SignednessSemantics>(encoding & 0x3));
+}
+
+/// A function type is encoded as its input type list followed by its result
+/// type list.
+FunctionType BuiltinDialectBytecodeInterface::readFunctionType(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Type> inputs, results;
+  if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results)))
+    return FunctionType();
+  return FunctionType::get(getContext(), inputs, results);
+}
+
+//===----------------------------------------------------------------------===//
+// Types: Writer
+
+LogicalResult BuiltinDialectBytecodeInterface::writeType(
+    Type type, DialectBytecodeWriter &writer) const {
+  // Dispatch to the `write` overload for each supported type kind; IndexType
+  // is written inline because it consists of its code only.
+  return TypeSwitch<Type, LogicalResult>(type)
+      .Case<IntegerType, FunctionType>([&](auto type) {
+        write(type, writer);
+        return success();
+      })
+      .Case([&](IndexType) {
+        return writer.writeVarInt(builtin_encoding::kIndexType), success();
+      })
+      .Default([&](Type) { return failure(); });
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    IntegerType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kIntegerType);
+  // Pack width and signedness into one varint; see readIntegerType.
+  writer.writeVarInt((type.getWidth() << 2) | type.getSignedness());
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    FunctionType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kFunctionType);
+  writer.writeTypes(type.getInputs());
+  writer.writeTypes(type.getResults());
+}
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -8,6 +8,7 @@
   BuiltinAttributeInterfaces.cpp
   BuiltinAttributes.cpp
   BuiltinDialect.cpp
+  BuiltinDialectBytecode.cpp
   BuiltinTypes.cpp
   BuiltinTypeInterfaces.cpp
   Diagnostics.cpp
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -113,6 +113,10 @@
 
 DialectInterface::~DialectInterface() = default;
 
+MLIRContext *DialectInterface::getContext() const {
+  return dialect->getContext();
+}
+
 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, TypeID interfaceKind) {
   for (auto *dialect : ctx->getLoadedDialects()) {
diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestArray
+module @TestArray attributes {
+  // CHECK: bytecode.array = [unit]
+  bytecode.array = [unit]
+} {}
+
+// CHECK-LABEL: @TestString
+module @TestString attributes {
+  // CHECK: bytecode.string = "hello"
+  bytecode.string = "hello"
+} {}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestInteger
+module @TestInteger attributes {
+  // CHECK: bytecode.int = i1024,
+  // CHECK: bytecode.int1 = si32,
+  // CHECK: bytecode.int2 = ui512
+  bytecode.int = i1024,
+  bytecode.int1 = si32,
+  bytecode.int2 = ui512
+} {}
+
+// CHECK-LABEL: @TestIndex
+module @TestIndex attributes {
+  // CHECK: bytecode.index = index
+  bytecode.index = index
+} {}
+
+// CHECK-LABEL: @TestFunc
+module @TestFunc attributes {
+  // CHECK: bytecode.func = () -> (),
+  // CHECK: bytecode.func1 = (i1) -> i32
+  bytecode.func = () -> (),
+  bytecode.func1 = (i1) -> (i32)
+} {}