diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -287,6 +287,37 @@
   }
 };
 
+/// Reads a resource handle into `value`; fails if it cannot be read or cast.
+template
+static LogicalResult readResourceHandle(DialectBytecodeReader &reader,
+                                        FailureOr &value, Ts &&...params) {
+  FailureOr handle = reader.readResourceHandle();
+  if (failed(handle))
+    return failure();
+  if (auto *result = dyn_cast(&*handle)) {
+    value = std::move(*result);
+    return success();
+  }
+  return failure();
+}
+
+/// Helper method that injects the context only if needed; this helps unify
+/// some of the attribute construction methods.
+template
+auto get(MLIRContext *context, Ts &&...params) {
+  // Prefer a direct `get` method if one exists.
+  if constexpr (llvm::is_detected::value) {
+    (void)context;
+    return T::get(std::forward(params)...);
+  } else if constexpr (llvm::is_detected::value) {
+    return T::get(context, std::forward(params)...);
+  } else {
+    // Otherwise, pass to the base get.
+    return T::Base::get(context, std::forward(params)...);
+  }
+}
+
 } // namespace mlir
 
 #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -0,0 +1,564 @@
+//===-- BuiltinDialectBytecode.td - Builtin bytecode defs --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the Builtin bytecode reader/writer definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BUILTIN_BYTECODE
+#define BUILTIN_BYTECODE
+
+include "mlir/IR/BytecodeBase.td"
+
+def LocationAttr : AttributeKind;
+
+def Location : CompositeBytecode {
+  dag members = (attr
+    WithGetter<"(LocationAttr)$_attrType", WithType<"LocationAttr", LocationAttr>>:$value
+  );
+  let cBuilder = "Location($_args)";
+}
+
+def String :
+  WithParser <"succeeded($_reader.readString($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeOwnedString($_getter)",
+  WithGetter <"$_attrType",
+  WithType <"StringRef">>>>>;
+
+// enum AttributeCode {
+//   /// ArrayAttr {
+//   ///   elements: Attribute[]
+//   /// }
+//   ///
+//   kArrayAttr = 0,
+//
+def ArrayAttr : DialectAttribute<(attr
+  Array:$value
+)>;
+
+let cType = "StringAttr" in {
+// /// StringAttr {
+// ///   value: string
+// /// }
+// kStringAttr = 2,
+def StringAttr : DialectAttribute<(attr
+  String:$value
+)> {
+  let printerPredicate = "$_val.getType().isa()";
+}
+
+// /// StringAttrWithType {
+// ///   value: string,
+// ///   type: Type
+// /// }
+// /// A variant of StringAttr with a type.
+// kStringAttrWithType = 3,
+def StringAttrWithType : DialectAttribute<(attr
+  String:$value,
+  Type:$type
+)> { let printerPredicate = "!$_val.getType().isa()"; }
+}
+
+// /// DictionaryAttr {
+// ///   attrs: (name: StringAttr, value: Attribute)[]
+// /// }
+// kDictionaryAttr = 1,
+def NamedAttribute : CompositeBytecode {
+  dag members = (attr
+    StringAttr:$name,
+    Attribute:$value
+  );
+  let cBuilder = "NamedAttribute($_args)";
+}
+def DictionaryAttr : DialectAttribute<(attr
+  Array:$value
+)>;
+
+// /// FlatSymbolRefAttr {
+// ///   rootReference: StringAttr
+// /// }
+// /// A variant of SymbolRefAttr with no leaf references.
+// kFlatSymbolRefAttr = 4,
+def FlatSymbolRefAttr: DialectAttribute<(attr
+  StringAttr:$rootReference
+)>;
+
+// /// SymbolRefAttr {
+// ///   rootReference: StringAttr,
+// ///   leafReferences: FlatSymbolRefAttr[]
+// /// }
+// kSymbolRefAttr = 5,
+def SymbolRefAttr: DialectAttribute<(attr
+  StringAttr:$rootReference,
+  Array:$nestedReferences
+)>;
+
+// /// TypeAttr {
+// ///   value: Type
+// /// }
+// kTypeAttr = 6,
+def TypeAttr: DialectAttribute<(attr
+  Type:$value
+)>;
+
+// /// UnitAttr {
+// /// }
+// kUnitAttr = 7,
+def UnitAttr: DialectAttribute<(attr)>;
+
+// /// IntegerAttr {
+// ///   type: Type,
+// ///   value: APInt
+// /// }
+// kIntegerAttr = 8,
+def IntegerAttr: DialectAttribute<(attr
+  Type:$type,
+  KnownWidthAPInt<"type">:$value
+)> {
+  let cBuilder = "get<$_resultType>(context, type, *value)";
+}
+
+//
+// /// FloatAttr {
+// ///   type: FloatType,
+// ///   value: APFloat
+// /// }
+// kFloatAttr = 9,
+defvar FloatType = Type;
+def FloatAttr : DialectAttribute<(attr
+  FloatType:$type,
+  KnownSemanticsAPFloat<"type">:$value
+)> {
+  let cBuilder = "get<$_resultType>(context, type, *value)";
+}
+
+// /// CallSiteLoc {
+// ///   callee: LocationAttr,
+// ///   caller: LocationAttr
+// /// }
+// kCallSiteLoc = 10,
+def CallSiteLoc : DialectAttribute<(attr
+  LocationAttr:$callee,
+  LocationAttr:$caller
+)>;
+
+// /// FileLineColLoc {
+// ///   filename: StringAttr,
+// ///   line: varint,
+// ///   column: varint
+// /// }
+// kFileLineColLoc = 11,
+def FileLineColLoc : DialectAttribute<(attr
+  StringAttr:$filename,
+  VarInt:$line,
+  VarInt:$column
+)>;
+
+let cType = "FusedLoc",
+    cBuilder = "cast(get(context, $_args))" in {
+// /// FusedLoc {
+// ///   locations: Location[]
+// /// }
+// kFusedLoc = 12,
+def FusedLoc : DialectAttribute<(attr
+  Array:$locations
+)> {
+  let printerPredicate = "!$_val.getMetadata()";
+}
+
+// /// FusedLocWithMetadata {
+// ///   locations: LocationAttr[],
+// ///   metadata: Attribute
+// /// }
+// /// A variant of FusedLoc with metadata.
+// kFusedLocWithMetadata = 13,
+def FusedLocWithMetadata : DialectAttribute<(attr
+  Array:$locations,
+  Attribute:$metadata
+)> {
+  let printerPredicate = "$_val.getMetadata()";
+}
+}
+
+// /// NameLoc {
+// ///   name: StringAttr,
+// ///   childLoc: LocationAttr
+// /// }
+// kNameLoc = 14,
+def NameLoc : DialectAttribute<(attr
+  StringAttr:$name,
+  LocationAttr:$childLoc
+)>;
+
+// /// UnknownLoc {
+// /// }
+// kUnknownLoc = 15,
+def UnknownLoc : DialectAttribute<(attr)>;
+
+// /// DenseResourceElementsAttr {
+// ///   type: Type,
+// ///   handle: ResourceHandle
+// /// }
+// kDenseResourceElementsAttr = 16,
+def DenseResourceElementsAttr : DialectAttribute<(attr
+  Type:$type,
+  ResourceHandle<"DenseResourceElementsHandle">:$rawHandle
+)> {
+  // Note: the parsed handle is a FailureOr and must be dereferenced.
+  let cBuilder = "get<$_resultType>(context, type, *rawHandle)";
+}
+
+let cType = "RankedTensorType" in {
+// /// RankedTensorType {
+// ///   shape: svarint[],
+// ///   elementType: Type
+// /// }
+// ///
+// kRankedTensorType = 13,
+def RankedTensorType : DialectType<(type
+  Array:$shape,
+  Type:$elementType
+)> {
+  let printerPredicate = "!$_val.getEncoding()";
+}
+
+// /// RankedTensorTypeWithEncoding {
+// ///   encoding: Attribute,
+// ///   shape: svarint[],
+// ///   elementType: Type
+// /// }
+// /// Variant of RankedTensorType with an encoding.
+// kRankedTensorTypeWithEncoding = 14,
+def RankedTensorTypeWithEncoding : DialectType<(type
+  Attribute:$encoding,
+  Array:$shape,
+  Type:$elementType
+)> {
+  let printerPredicate = "$_val.getEncoding()";
+  // Note: order of serialization does not match order of builder.
+  let cBuilder = "get<$_resultType>(context, shape, elementType, encoding)";
+}
+}
+
+// /// DenseArrayAttr {
+// ///   elementType: Type,
+// ///   size: varint,
+// ///   data: blob
+// /// }
+// kDenseArrayAttr = 17,
+def DenseArrayAttr : DialectAttribute<(attr
+  Type:$elementType,
+  VarInt:$size,
+  Blob:$rawData
+)>;
+
+// /// DenseIntOrFPElementsAttr {
+// ///   type: ShapedType,
+// ///   data: blob
+// /// }
+// kDenseIntOrFPElementsAttr = 18,
+defvar ShapedType = Type;
+def DenseElementsAttr : WithType<"DenseElementsAttr", Attribute>;
+def DenseIntOrFPElementsAttr : DialectAttribute<(attr
+  ShapedType:$type,
+  Blob:$rawData
+)> {
+  let cBuilder = "cast<$_resultType>($_resultType::getFromRawBuffer($_args))";
+}
+
+// /// DenseStringElementsAttr {
+// ///   type: ShapedType,
+// ///   isSplat: varint,
+// ///   data: string[]
+// /// }
+// kDenseStringElementsAttr = 19,
+def DenseStringElementsAttr : DialectAttribute<(attr
+  ShapedType:$type,
+  WithGetter<"$_attrType.isSplat()", VarInt>:$_isSplat,
+  WithBuilder<"$_args",
+    WithType<"SmallVector",
+    WithParser <"succeeded(readPotentiallySplatString($_reader, type, _isSplat, $_var))",
+    WithPrinter<"writePotentiallySplatString($_writer, $_name)">>>>:$rawStringData
+)>;
+
+// /// SparseElementsAttr {
+// ///   type: ShapedType,
+// ///   indices: DenseIntElementsAttr,
+// ///   values: DenseElementsAttr
+// /// }
+// kSparseElementsAttr = 20,
+def DenseIntElementsAttr : WithType<"DenseIntElementsAttr", Attribute>;
+def SparseElementsAttr : DialectAttribute<(attr
+  ShapedType:$type,
+  DenseIntElementsAttr:$indices,
+  DenseElementsAttr:$values
+)>;
+
+// Types
+// -----
+
+// enum TypeCode {
+//   /// IntegerType {
+//   ///   widthAndSignedness: varint // (width << 2) | (signedness)
+//   /// }
+//   ///
+//   kIntegerType = 0,
+def IntegerType : DialectType<(type
+  // Packed as (width << 2) | signedness; split by the LocalVars below.
+  WithParser<"succeeded($_reader.readVarInt($_var))",
+    WithBuilder<"$_args",
+    WithPrinter<"$_writer.writeVarInt($_name.getWidth() << 2 | $_name.getSignedness())",
+    WithType <"uint64_t">>>>:$_widthAndSignedness,
+  // Split up parsed varint for create method.
+  LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width,
+  LocalVar<"IntegerType::SignednessSemantics",
+           "static_cast(_widthAndSignedness & 0x3)">:$signedness
+)>;
+
+//
+// /// IndexType {
+// /// }
+// ///
+// kIndexType = 1,
+def IndexType : DialectType<(type)>;
+
+// /// FunctionType {
+// ///   inputs: Type[],
+// ///   results: Type[]
+// /// }
+// ///
+// kFunctionType = 2,
+def FunctionType : DialectType<(type
+  Array:$inputs,
+  Array:$results
+)>;
+
+// /// BFloat16Type {
+// /// }
+// ///
+// kBFloat16Type = 3,
+def BFloat16Type : DialectType<(type)>;
+
+// /// Float16Type {
+// /// }
+// ///
+// kFloat16Type = 4,
+def Float16Type : DialectType<(type)>;
+
+// /// Float32Type {
+// /// }
+// ///
+// kFloat32Type = 5,
+def Float32Type : DialectType<(type)>;
+
+// /// Float64Type {
+// /// }
+// ///
+// kFloat64Type = 6,
+def Float64Type : DialectType<(type)>;
+
+// /// Float80Type {
+// /// }
+// ///
+// kFloat80Type = 7,
+def Float80Type : DialectType<(type)>;
+
+// /// Float128Type {
+// /// }
+// ///
+// kFloat128Type = 8,
+def Float128Type : DialectType<(type)>;
+
+// /// ComplexType {
+// ///   elementType: Type
+// /// }
+// ///
+// kComplexType = 9,
+def ComplexType : DialectType<(type
+  Type:$elementType
+)>;
+
+let cType = "MemRefType" in {
+// /// MemRefType {
+// ///   shape: svarint[],
+// ///   elementType: Type,
+// ///   layout: Attribute
+// /// }
+// ///
+// kMemRefType = 10,
+def MemRefType : DialectType<(type
+  Array:$shape,
+  Type:$elementType,
+  Attribute:$layout
+)> {
+  let printerPredicate = "!$_val.getMemorySpace()";
+}
+
+// /// MemRefTypeWithMemSpace {
+// ///   memorySpace: Attribute,
+// ///   shape: svarint[],
+// ///   elementType: Type,
+// ///   layout: Attribute
+// /// }
+// /// Variant of MemRefType with non-default memory space.
+// kMemRefTypeWithMemSpace = 11,
+def MemRefTypeWithMemSpace : DialectType<(type
+  Attribute:$memorySpace,
+  Array:$shape,
+  Type:$elementType,
+  Attribute:$layout
+)> {
+  let printerPredicate = "!!$_val.getMemorySpace()";
+  // Note: order of serialization does not match order of builder.
+  let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)";
+}
+}
+
+// /// NoneType {
+// /// }
+// ///
+// kNoneType = 12,
+def NoneType : DialectType<(type)>;
+
+// /// TupleType {
+// ///   elementTypes: Type[]
+// /// }
+// kTupleType = 15,
+def TupleType : DialectType<(type
+  Array:$types
+)>;
+
+let cType = "UnrankedMemRefType" in {
+// /// UnrankedMemRefType {
+// ///   elementType: Type
+// /// }
+// ///
+// kUnrankedMemRefType = 16,
+def UnrankedMemRefType : DialectType<(type
+  Type:$elementType
+)> {
+  let printerPredicate = "!$_val.getMemorySpace()";
+  let cBuilder = "get<$_resultType>(context, elementType, Attribute())";
+}
+
+// /// UnrankedMemRefTypeWithMemSpace {
+// ///   memorySpace: Attribute,
+// ///   elementType: Type
+// /// }
+// /// Variant of UnrankedMemRefType with non-default memory space.
+// kUnrankedMemRefTypeWithMemSpace = 17,
+def UnrankedMemRefTypeWithMemSpace : DialectType<(type
+  Attribute:$memorySpace,
+  Type:$elementType
+)> {
+  let printerPredicate = "$_val.getMemorySpace()";
+  // Note: order of serialization does not match order of builder.
+  let cBuilder = "get<$_resultType>(context, elementType, memorySpace)";
+}
+}
+
+// /// UnrankedTensorType {
+// ///   elementType: Type
+// /// }
+// ///
+// kUnrankedTensorType = 18,
+def UnrankedTensorType : DialectType<(type
+  Type:$elementType
+)>;
+
+let cType = "VectorType" in {
+// /// VectorType {
+// ///   shape: svarint[],
+// ///   elementType: Type
+// /// }
+// ///
+// kVectorType = 19,
+def VectorType : DialectType<(type
+  Array:$shape,
+  Type:$elementType
+)> {
+  let printerPredicate = "!$_val.getNumScalableDims()";
+}
+
+// /// VectorTypeWithScalableDims {
+// ///   numScalableDims: varint,
+// ///   shape: svarint[],
+// ///   elementType: Type
+// /// }
+// /// Variant of VectorType with scalable dimensions.
+// kVectorTypeWithScalableDims = 20,
+def VectorTypeWithScalableDims : DialectType<(type
+  VarInt:$numScalableDims,
+  Array:$shape,
+  Type:$elementType
+)> {
+  let printerPredicate = "$_val.getNumScalableDims()";
+  // Note: order of serialization does not match order of builder.
+  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
+}
+}
+
+/// The entries of the `elems` lists below are in marker-code order (matching
+/// the `kXxx = N` notes above), i.e. the code that indicates which attribute
+/// or type is being decoded. This order should generally be unchanged, as any
+/// changes will inevitably break compatibility with older bytecode.
+
+def BuiltinDialectAttributes : DialectAttributes<"Builtin"> {
+  let elems = [
+    ArrayAttr,
+    DictionaryAttr,
+    StringAttr,
+    StringAttrWithType,
+    FlatSymbolRefAttr,
+    SymbolRefAttr,
+    TypeAttr,
+    UnitAttr,
+    IntegerAttr,
+    FloatAttr,
+    CallSiteLoc,
+    FileLineColLoc,
+    FusedLoc,
+    FusedLocWithMetadata,
+    NameLoc,
+    UnknownLoc,
+    DenseResourceElementsAttr,
+    DenseArrayAttr,
+    DenseIntOrFPElementsAttr,
+    DenseStringElementsAttr,
+    SparseElementsAttr
+  ];
+}
+
+def BuiltinDialectTypes : DialectTypes<"Builtin"> {
+  let elems = [
+    IntegerType,
+    IndexType,
+    FunctionType,
+    BFloat16Type,
+    Float16Type,
+    Float32Type,
+    Float64Type,
+    Float80Type,
+    Float128Type,
+    ComplexType,
+    MemRefType,
+    MemRefTypeWithMemSpace,
+    NoneType,
+    RankedTensorType,
+    RankedTensorTypeWithEncoding,
+    TupleType,
+    UnrankedMemRefType,
+    UnrankedMemRefTypeWithMemSpace,
+    UnrankedTensorType,
+    VectorType,
+    VectorTypeWithScalableDims
+  ];
+}
+
+#endif // BUILTIN_BYTECODE
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -0,0 +1,159 @@
+//===-- BytecodeBase.td - Base bytecode R/W defs -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the base bytecode reader/writer definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef BYTECODE_BASE
+#define BYTECODE_BASE
+
+class Bytecode {
+  // Template for parsing.
+  //   $_reader == dialect bytecode reader
+  //   $_resultType == result type of parsed instance
+  //   $_var == variable being parsed
+  // If parser is not specified, then the parse of members is used.
+  string cParser = parse;
+
+  // Template for building from parsed.
+  //   $_resultType == result type of parsed instance
+  //   $_args == args/members comma separated
+  string cBuilder = build;
+
+  // Template for printing.
+  //   $_writer == dialect bytecode writer
+  //   $_name == parent attribute/type name
+  //   $_getter == getter
+  string cPrinter = print;
+
+  // Template for getter from in memory form.
+  //   $_attrType == attribute/type
+  //   $_member == member instance
+  //   $_getMember == get + UpperCamelFromSnake($_member)
+  string cGetter = get;
+
+  // Type built.
+  // Note: if cType is empty, then name of def is used.
+  string cType = t;
+
+  // Predicate guarding the print method, as an Attribute/Type could have
+  // multiple print methods (encodings); specify predicates to be orthogonal
+  // and cover the entire "print space" to avoid order dependence.
+  // If empty then method is unconditional.
+  // $_val == value (dyn_casted to cType) that the predicate is applied to.
+  string printerPredicate = "";
+}
+
+class WithParser> :
+  Bytecode;
+class WithBuilder> :
+  Bytecode;
+class WithPrinter> :
+  Bytecode;
+class WithType> :
+  Bytecode;
+class WithGetter> :
+  Bytecode;
+
+class CompositeBytecode : WithType;
+
+class AttributeKind :
+  WithParser <"succeeded($_reader.readAttribute($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeAttribute($_getter)">>>;
+def Attribute : AttributeKind;
+class TypeKind :
+  WithParser <"succeeded($_reader.readType($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeType($_getter)">>>;
+def Type : TypeKind;
+def VarInt :
+  WithParser <"succeeded($_reader.readVarInt($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeVarInt($_getter)",
+  WithType <"uint64_t">>>>;
+def SignedVarInt :
+  WithParser <"succeeded($_reader.readSignedVarInt($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeSignedVarInt($_getter)",
+  WithGetter<"$_attrType",
+  WithType <"int64_t">>>>>;
+def Blob :
+  WithParser <"succeeded($_reader.readBlob($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeOwnedBlob($_getter)",
+  WithType <"ArrayRef">>>>;
+
+class KnownWidthAPInt :
+  WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeAPIntWithKnownWidth($_getter)",
+  WithType <"FailureOr">>>>;
+class KnownSemanticsAPFloat :
+  WithParser <"succeeded(readAPFloatWithKnownSemantics($_reader, " # s # ", $_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeAPFloatWithKnownSemantics($_getter)",
+  WithType <"FailureOr">>>>;
+class ResourceHandle :
+  WithParser <"succeeded(readResourceHandle<" # s # ">($_reader, $_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeResourceHandle($_getter)",
+  WithType <"FailureOr<" # s # ">">>>>;
+
+// Local variable assigned from expression `d`; neither parsed nor printed.
+class LocalVar :
+  WithParser <"(($_var = " # d # "), true)",
+  WithBuilder<"$_args",
+  WithPrinter<"",
+  WithType >>>;
+
+// Array instances.
+class Array {
+  Bytecode elemT = t;
+
+  string cBuilder = "$_args";
+}
+
+// Define dialect attribute or type.
+class DialectAttrOrType {
+  // Any members starting with an underscore are not fed to the create
+  // function but are treated as purely local variables.
+  dag members = d;
+
+  // When needing to specify a custom return type.
+  string cType = "";
+
+  // Any post-processing that needs to be done.
+  code postProcess = "";
+}
+
+class DialectAttribute : DialectAttrOrType, AttributeKind {
+  let cParser = "succeeded($_reader.readAttribute<$_resultType>($_var))";
+  let cBuilder = "get<$_resultType>(context, $_args)";
+}
+class DialectType : DialectAttrOrType, TypeKind {
+  let cParser = "succeeded($_reader.readType<$_resultType>($_var))";
+  let cBuilder = "get<$_resultType>(context, $_args)";
+}
+
+class DialectAttributes {
+  string dialect = d;
+  list elems;
+}
+
+class DialectTypes {
+  string dialect = d;
+  list elems;
+}
+
+def attr;
+def type;
+
+#endif // BYTECODE_BASE
+
diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt
--- a/mlir/include/mlir/IR/CMakeLists.txt
+++ b/mlir/include/mlir/IR/CMakeLists.txt
@@ -17,6 +17,10 @@
 mlir_tablegen(BuiltinDialect.cpp.inc -gen-dialect-defs)
 add_public_tablegen_target(MLIRBuiltinDialectIncGen)
 
+set(LLVM_TARGET_DEFINITIONS BuiltinDialectBytecode.td)
+mlir_tablegen(BuiltinDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Builtin")
+add_public_tablegen_target(MLIRBuiltinDialectBytecodeIncGen)
+
 set(LLVM_TARGET_DEFINITIONS BuiltinLocationAttributes.td)
 mlir_tablegen(BuiltinLocationAttributes.h.inc -gen-attrdef-decls)
 mlir_tablegen(BuiltinLocationAttributes.cpp.inc -gen-attrdef-defs)
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -16,547 +16,62 @@
 
 using namespace mlir;
 
-//===----------------------------------------------------------------------===//
-// Encoding
-//===----------------------------------------------------------------------===//
-
-namespace {
-namespace builtin_encoding {
-/// This enum contains marker codes used to indicate which attribute is
-/// currently being decoded, and how it should be decoded. The order of these
-/// codes should generally be unchanged, as any changes will inevitably break
-/// compatibility with older bytecode.
-enum AttributeCode {
-  /// ArrayAttr {
-  ///   elements: Attribute[]
-  /// }
-  ///
-  kArrayAttr = 0,
-
-  /// DictionaryAttr {
-  ///   attrs: []
-  /// }
-  kDictionaryAttr = 1,
-
-  /// StringAttr {
-  ///   value: string
-  /// }
-  kStringAttr = 2,
-
-  /// StringAttrWithType {
-  ///   value: string,
-  ///   type: Type
-  /// }
-  /// A variant of StringAttr with a type.
-  kStringAttrWithType = 3,
-
-  /// FlatSymbolRefAttr {
-  ///   rootReference: StringAttr
-  /// }
-  /// A variant of SymbolRefAttr with no leaf references.
-  kFlatSymbolRefAttr = 4,
-
-  /// SymbolRefAttr {
-  ///   rootReference: StringAttr,
-  ///   leafReferences: FlatSymbolRefAttr[]
-  /// }
-  kSymbolRefAttr = 5,
-
-  /// TypeAttr {
-  ///   value: Type
-  /// }
-  kTypeAttr = 6,
-
-  /// UnitAttr {
-  /// }
-  kUnitAttr = 7,
-
-  /// IntegerAttr {
-  ///   type: Type
-  ///   value: APInt,
-  /// }
-  kIntegerAttr = 8,
-
-  /// FloatAttr {
-  ///   type: FloatType
-  ///   value: APFloat
-  /// }
-  kFloatAttr = 9,
-
-  /// CallSiteLoc {
-  ///   callee: LocationAttr,
-  ///   caller: LocationAttr
-  /// }
-  kCallSiteLoc = 10,
-
-  /// FileLineColLoc {
-  ///   file: StringAttr,
-  ///   line: varint,
-  ///   column: varint
-  /// }
-  kFileLineColLoc = 11,
-
-  /// FusedLoc {
-  ///   locations: LocationAttr[]
-  /// }
-  kFusedLoc = 12,
-
-  /// FusedLocWithMetadata {
-  ///   locations: LocationAttr[],
-  ///   metadata: Attribute
-  /// }
-  /// A variant of FusedLoc with metadata.
-  kFusedLocWithMetadata = 13,
-
-  /// NameLoc {
-  ///   name: StringAttr,
-  ///   childLoc: LocationAttr
-  /// }
-  kNameLoc = 14,
-
-  /// UnknownLoc {
-  /// }
-  kUnknownLoc = 15,
-
-  /// DenseResourceElementsAttr {
-  ///   type: Type,
-  ///   handle: ResourceHandle
-  /// }
-  kDenseResourceElementsAttr = 16,
-
-  /// DenseArrayAttr {
-  ///   type: RankedTensorType,
-  ///   data: blob
-  /// }
-  kDenseArrayAttr = 17,
-
-  /// DenseIntOrFPElementsAttr {
-  ///   type: ShapedType,
-  ///   data: blob
-  /// }
-  kDenseIntOrFPElementsAttr = 18,
-
-  /// DenseStringElementsAttr {
-  ///   type: ShapedType,
-  ///   isSplat: varint,
-  ///   data: string[]
-  /// }
-  kDenseStringElementsAttr = 19,
-
-  /// SparseElementsAttr {
-  ///   type: ShapedType,
-  ///   indices: DenseIntElementsAttr,
-  ///   values: DenseElementsAttr
-  /// }
-  kSparseElementsAttr = 20,
-};
-
-/// This enum contains marker codes used to indicate which type is currently
-/// being decoded, and how it should be decoded. The order of these codes should
-/// generally be unchanged, as any changes will inevitably break compatibility
-/// with older bytecode.
-enum TypeCode {
-  /// IntegerType {
-  ///   widthAndSignedness: varint // (width << 2) | (signedness)
-  /// }
-  ///
-  kIntegerType = 0,
-
-  /// IndexType {
-  /// }
-  ///
-  kIndexType = 1,
-
-  /// FunctionType {
-  ///   inputs: Type[],
-  ///   results: Type[]
-  /// }
-  ///
-  kFunctionType = 2,
-
-  /// BFloat16Type {
-  /// }
-  ///
-  kBFloat16Type = 3,
-
-  /// Float16Type {
-  /// }
-  ///
-  kFloat16Type = 4,
-
-  /// Float32Type {
-  /// }
-  ///
-  kFloat32Type = 5,
-
-  /// Float64Type {
-  /// }
-  ///
-  kFloat64Type = 6,
-
-  /// Float80Type {
-  /// }
-  ///
-  kFloat80Type = 7,
-
-  /// Float128Type {
-  /// }
-  ///
-  kFloat128Type = 8,
-
-  /// ComplexType {
-  ///   elementType: Type
-  /// }
-  ///
-  kComplexType = 9,
-
-  /// MemRefType {
-  ///   shape: svarint[],
-  ///   elementType: Type,
-  ///   layout: Attribute
-  /// }
-  ///
-  kMemRefType = 10,
-
-  /// MemRefTypeWithMemSpace {
-  ///   memorySpace: Attribute,
-  ///   shape: svarint[],
-  ///   elementType: Type,
-  ///   layout: Attribute
-  /// }
-  /// Variant of MemRefType with non-default memory space.
-  kMemRefTypeWithMemSpace = 11,
-
-  /// NoneType {
-  /// }
-  ///
-  kNoneType = 12,
-
-  /// RankedTensorType {
-  ///   shape: svarint[],
-  ///   elementType: Type,
-  /// }
-  ///
-  kRankedTensorType = 13,
-
-  /// RankedTensorTypeWithEncoding {
-  ///   encoding: Attribute,
-  ///   shape: svarint[],
-  ///   elementType: Type
-  /// }
-  /// Variant of RankedTensorType with an encoding.
-  kRankedTensorTypeWithEncoding = 14,
-
-  /// TupleType {
-  ///   elementTypes: Type[]
-  /// }
-  kTupleType = 15,
-
-  /// UnrankedMemRefType {
-  ///   shape: svarint[]
-  /// }
-  ///
-  kUnrankedMemRefType = 16,
-
-  /// UnrankedMemRefTypeWithMemSpace {
-  ///   memorySpace: Attribute,
-  ///   shape: svarint[]
-  /// }
-  /// Variant of UnrankedMemRefType with non-default memory space.
-  kUnrankedMemRefTypeWithMemSpace = 17,
-
-  /// UnrankedTensorType {
-  ///   elementType: Type
-  /// }
-  ///
-  kUnrankedTensorType = 18,
-
-  /// VectorType {
-  ///   shape: svarint[],
-  ///   elementType: Type
-  /// }
-  ///
-  kVectorType = 19,
-
-  /// VectorTypeWithScalableDims {
-  ///   numScalableDims: varint,
-  ///   shape: svarint[],
-  ///   elementType: Type
-  /// }
-  /// Variant of VectorType with scalable dimensions.
-  kVectorTypeWithScalableDims = 20,
-};
-
-} // namespace builtin_encoding
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // BuiltinDialectBytecodeInterface
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// This class implements the bytecode interface for the builtin dialect.
-struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface {
-  BuiltinDialectBytecodeInterface(Dialect *dialect)
-      : BytecodeDialectInterface(dialect) {}
-
-  //===--------------------------------------------------------------------===//
-  // Attributes
-
-  Attribute readAttribute(DialectBytecodeReader &reader) const override;
-  ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
-  DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const;
-  DenseElementsAttr
-  readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const;
-  DenseStringElementsAttr
-  readDenseStringElementsAttr(DialectBytecodeReader &reader) const;
-  DenseResourceElementsAttr
-  readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
-  DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
-  FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
-  IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
-  SparseElementsAttr
-  readSparseElementsAttr(DialectBytecodeReader &reader) const;
-  StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const;
-  SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader,
-                                  bool hasNestedRefs) const;
-  TypeAttr readTypeAttr(DialectBytecodeReader &reader) const;
-
-  LocationAttr readCallSiteLoc(DialectBytecodeReader &reader) const;
-  LocationAttr readFileLineColLoc(DialectBytecodeReader &reader) const;
-  LocationAttr readFusedLoc(DialectBytecodeReader &reader,
-                            bool hasMetadata) const;
-  LocationAttr readNameLoc(DialectBytecodeReader &reader) const;
-
-  LogicalResult writeAttribute(Attribute attr,
-                               DialectBytecodeWriter &writer) const override;
-  void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
-  void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const;
-  void write(DenseIntOrFPElementsAttr attr,
-             DialectBytecodeWriter &writer) const;
-  void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const;
-  void write(DenseResourceElementsAttr attr,
-             DialectBytecodeWriter &writer) const;
-  void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
-  void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
-  void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
-  void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const;
-  void write(StringAttr attr, DialectBytecodeWriter &writer) const;
-  void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const;
-  void write(TypeAttr attr, DialectBytecodeWriter &writer) const;
-
-  void write(CallSiteLoc attr, DialectBytecodeWriter &writer) const;
-  void write(FileLineColLoc attr, DialectBytecodeWriter &writer) const;
-  void write(FusedLoc attr, DialectBytecodeWriter &writer) const;
-  void write(NameLoc attr, DialectBytecodeWriter &writer) const;
-  LogicalResult write(OpaqueLoc attr, DialectBytecodeWriter &writer) const;
-
-  //===--------------------------------------------------------------------===//
-  // Types
-
-  Type readType(DialectBytecodeReader &reader) const override;
-  ComplexType readComplexType(DialectBytecodeReader &reader) const;
-  IntegerType readIntegerType(DialectBytecodeReader &reader) const;
-  FunctionType readFunctionType(DialectBytecodeReader &reader) const;
-  MemRefType readMemRefType(DialectBytecodeReader &reader,
-                            bool hasMemSpace) const;
-  RankedTensorType readRankedTensorType(DialectBytecodeReader &reader,
-                                        bool hasEncoding) const;
-  TupleType readTupleType(DialectBytecodeReader &reader) const;
-  UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader,
-                                            bool hasMemSpace) const;
-  UnrankedTensorType
-  readUnrankedTensorType(DialectBytecodeReader &reader) const;
-  VectorType readVectorType(DialectBytecodeReader &reader,
-                            bool hasScalableDims) const;
-
-  LogicalResult writeType(Type type,
-                          DialectBytecodeWriter &writer) const override;
-  void write(ComplexType type, DialectBytecodeWriter &writer) const;
-  void write(IntegerType type, DialectBytecodeWriter &writer) const;
-  void write(FunctionType type, DialectBytecodeWriter &writer) const;
-  void write(MemRefType type, DialectBytecodeWriter &writer) const;
-  void write(RankedTensorType type, DialectBytecodeWriter &writer) const;
-  void write(TupleType type, DialectBytecodeWriter &writer) const;
-  void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const;
-  void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const;
-  void write(VectorType type, DialectBytecodeWriter &writer) const;
-};
-} // namespace
-
-void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) {
-  dialect->addInterfaces();
-}
-
-//===----------------------------------------------------------------------===//
-// Attributes
-//===----------------------------------------------------------------------===//
-
-Attribute BuiltinDialectBytecodeInterface::readAttribute(
-    DialectBytecodeReader &reader) const {
-  uint64_t code;
-  if (failed(reader.readVarInt(code)))
-    return Attribute();
-  switch (code) {
-  case builtin_encoding::kArrayAttr:
-    return readArrayAttr(reader);
-  case builtin_encoding::kDictionaryAttr:
-    return readDictionaryAttr(reader);
-  case builtin_encoding::kStringAttr:
-    return readStringAttr(reader, /*hasType=*/false);
-  case builtin_encoding::kStringAttrWithType:
-    return readStringAttr(reader, /*hasType=*/true);
-  case builtin_encoding::kFlatSymbolRefAttr:
-    return readSymbolRefAttr(reader, /*hasNestedRefs=*/false);
-  case builtin_encoding::kSymbolRefAttr:
-    return readSymbolRefAttr(reader, /*hasNestedRefs=*/true);
-  case builtin_encoding::kTypeAttr:
-    return readTypeAttr(reader);
-  case builtin_encoding::kUnitAttr:
-    return UnitAttr::get(getContext());
-  case builtin_encoding::kIntegerAttr:
-    return readIntegerAttr(reader);
-  case builtin_encoding::kFloatAttr:
-    return readFloatAttr(reader);
-  case builtin_encoding::kCallSiteLoc:
-    return readCallSiteLoc(reader);
-  case builtin_encoding::kFileLineColLoc:
-    return readFileLineColLoc(reader);
-  case builtin_encoding::kFusedLoc:
-    return readFusedLoc(reader, /*hasMetadata=*/false);
-  case builtin_encoding::kFusedLocWithMetadata:
-    return readFusedLoc(reader, /*hasMetadata=*/true);
-  case builtin_encoding::kNameLoc:
-    return readNameLoc(reader);
-  case builtin_encoding::kUnknownLoc:
-    return UnknownLoc::get(getContext());
-  case builtin_encoding::kDenseResourceElementsAttr:
-    return readDenseResourceElementsAttr(reader);
-  case builtin_encoding::kDenseArrayAttr:
-    return readDenseArrayAttr(reader);
-  case builtin_encoding::kDenseIntOrFPElementsAttr:
-    return readDenseIntOrFPElementsAttr(reader);
-  case builtin_encoding::kDenseStringElementsAttr:
-    return readDenseStringElementsAttr(reader);
-  case builtin_encoding::kSparseElementsAttr:
-    return readSparseElementsAttr(reader);
-  default:
-    reader.emitError() << "unknown builtin attribute code: " << code;
-    return Attribute();
-  }
-}
-
-LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
-    Attribute attr, DialectBytecodeWriter &writer) const {
-  return TypeSwitch(attr)
-      .Case([&](auto attr) {
-        write(attr, writer);
-        return success();
-      })
-      .Case([&](auto attr) {
-        write(attr, writer);
-        return success();
-      })
-      .Case([&](OpaqueLoc attr) { return write(attr, writer); })
-      .Case([&](UnitAttr) {
-        writer.writeVarInt(builtin_encoding::kUnitAttr);
-        return success();
-      })
-      .Case([&](UnknownLoc) {
-        writer.writeVarInt(builtin_encoding::kUnknownLoc);
-        return success();
-      })
-      .Default([&](Attribute) { return failure(); });
-}
-
-//===----------------------------------------------------------------------===//
-// ArrayAttr
-
-ArrayAttr BuiltinDialectBytecodeInterface::readArrayAttr(
-    DialectBytecodeReader &reader) const {
-  SmallVector elements;
-  if (failed(reader.readAttributes(elements)))
-    return ArrayAttr();
-  return ArrayAttr::get(getContext(), elements);
-}
-
-void BuiltinDialectBytecodeInterface::write(
-    ArrayAttr attr, DialectBytecodeWriter &writer) const {
-  writer.writeVarInt(builtin_encoding::kArrayAttr);
-  writer.writeAttributes(attr.getValue());
-}
 
 //===----------------------------------------------------------------------===//
-// DenseArrayAttr
+// Utility functions
 
-DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr(
-    DialectBytecodeReader &reader) const {
-  Type elementType;
-  uint64_t size;
-  ArrayRef blob;
-  if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) ||
-      failed(reader.readBlob(blob)))
-    return DenseArrayAttr();
-  return DenseArrayAttr::get(elementType, size, blob);
-}
+// TODO: Move these to a separate file.
 
-void BuiltinDialectBytecodeInterface::write(
-    DenseArrayAttr attr, DialectBytecodeWriter &writer) const {
-  writer.writeVarInt(builtin_encoding::kDenseArrayAttr);
-  writer.writeType(attr.getElementType());
-  writer.writeVarInt(attr.getSize());
-  writer.writeOwnedBlob(attr.getRawData());
+// Returns the integer/index bitwidth; otherwise emits an error and returns 0.
+static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) { + if (auto intType = dyn_cast(type)) { + return intType.getWidth(); + } else if (type.isa()) { + return IndexType::kInternalStorageBitWidth; + } + reader.emitError() + << "expected integer or index type for IntegerAttr, but got: " << type; + return 0; } -//===----------------------------------------------------------------------===// -// DenseIntOrFPElementsAttr - -DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - ArrayRef blob; - if (failed(reader.readType(type)) || failed(reader.readBlob(blob))) - return DenseIntOrFPElementsAttr(); - return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob); +static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader, + Type type, FailureOr &val) { + unsigned bitWidth = getIntegerBitWidth(reader, type); + if (bitWidth == 0) + return failure(); + val = reader.readAPIntWithKnownWidth(bitWidth); + return val; } -void BuiltinDialectBytecodeInterface::write( - DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr); - writer.writeType(attr.getType()); - writer.writeOwnedBlob(attr.getRawData()); +static LogicalResult +readAPFloatWithKnownSemantics(DialectBytecodeReader &reader, Type type, + FailureOr &val) { + auto ftype = dyn_cast(type); + if (!ftype) + return failure(); + val = reader.readAPFloatWithKnownSemantics(ftype.getFloatSemantics()); + return success(); } -//===----------------------------------------------------------------------===// -// DenseStringElementsAttr - -DenseStringElementsAttr -BuiltinDialectBytecodeInterface::readDenseStringElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - uint64_t isSplat; - if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat))) - return DenseStringElementsAttr(); - - SmallVector 
values(isSplat ? 1 : type.getNumElements()); - for (StringRef &value : values) +LogicalResult +readPotentiallySplatString(DialectBytecodeReader &reader, ShapedType type, + bool isSplat, + SmallVectorImpl &rawStringData) { + rawStringData.resize(isSplat ? 1 : type.getNumElements()); + for (StringRef &value : rawStringData) if (failed(reader.readString(value))) - return DenseStringElementsAttr(); - return DenseStringElementsAttr::get(type, values); + return failure(); + return success(); } -void BuiltinDialectBytecodeInterface::write( - DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr); - writer.writeType(attr.getType()); - +void writePotentiallySplatString(DialectBytecodeWriter &writer, + DenseStringElementsAttr attr) { bool isSplat = attr.isSplat(); - writer.writeVarInt(isSplat); - - // If the attribute is a splat, only write out the single value. if (isSplat) return writer.writeOwnedString(attr.getRawStringData().front()); @@ -564,614 +79,39 @@ writer.writeOwnedString(str); } -//===----------------------------------------------------------------------===// -// DenseResourceElementsAttr - -DenseResourceElementsAttr -BuiltinDialectBytecodeInterface::readDenseResourceElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - if (failed(reader.readType(type))) - return DenseResourceElementsAttr(); - - FailureOr handle = - reader.readResourceHandle(); - if (failed(handle)) - return DenseResourceElementsAttr(); - - return DenseResourceElementsAttr::get(type, *handle); -} - -void BuiltinDialectBytecodeInterface::write( - DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDenseResourceElementsAttr); - writer.writeType(attr.getType()); - writer.writeResourceHandle(attr.getRawHandle()); -} - -//===----------------------------------------------------------------------===// -// DictionaryAttr - -DictionaryAttr 
BuiltinDialectBytecodeInterface::readDictionaryAttr( - DialectBytecodeReader &reader) const { - auto readNamedAttr = [&]() -> FailureOr { - StringAttr name; - Attribute value; - if (failed(reader.readAttribute(name)) || - failed(reader.readAttribute(value))) - return failure(); - return NamedAttribute(name, value); - }; - SmallVector attrs; - if (failed(reader.readList(attrs, readNamedAttr))) - return DictionaryAttr(); - return DictionaryAttr::get(getContext(), attrs); -} - -void BuiltinDialectBytecodeInterface::write( - DictionaryAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kDictionaryAttr); - writer.writeList(attr.getValue(), [&](NamedAttribute attr) { - writer.writeAttribute(attr.getName()); - writer.writeAttribute(attr.getValue()); - }); -} - -//===----------------------------------------------------------------------===// -// FloatAttr - -FloatAttr BuiltinDialectBytecodeInterface::readFloatAttr( - DialectBytecodeReader &reader) const { - FloatType type; - if (failed(reader.readType(type))) - return FloatAttr(); - FailureOr value = - reader.readAPFloatWithKnownSemantics(type.getFloatSemantics()); - if (failed(value)) - return FloatAttr(); - return FloatAttr::get(type, *value); -} - -void BuiltinDialectBytecodeInterface::write( - FloatAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFloatAttr); - writer.writeType(attr.getType()); - writer.writeAPFloatWithKnownSemantics(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// IntegerAttr - -IntegerAttr BuiltinDialectBytecodeInterface::readIntegerAttr( - DialectBytecodeReader &reader) const { - Type type; - if (failed(reader.readType(type))) - return IntegerAttr(); - - // Extract the value storage width from the type. 
- unsigned bitWidth; - if (auto intType = type.dyn_cast()) { - bitWidth = intType.getWidth(); - } else if (type.isa()) { - bitWidth = IndexType::kInternalStorageBitWidth; - } else { - reader.emitError() - << "expected integer or index type for IntegerAttr, but got: " << type; - return IntegerAttr(); - } - - FailureOr value = reader.readAPIntWithKnownWidth(bitWidth); - if (failed(value)) - return IntegerAttr(); - return IntegerAttr::get(type, *value); -} - -void BuiltinDialectBytecodeInterface::write( - IntegerAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kIntegerAttr); - writer.writeType(attr.getType()); - writer.writeAPIntWithKnownWidth(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// SparseElementsAttr - -SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr( - DialectBytecodeReader &reader) const { - ShapedType type; - DenseIntElementsAttr indices; - DenseElementsAttr values; - if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) || - failed(reader.readAttribute(values))) - return SparseElementsAttr(); - return SparseElementsAttr::get(type, indices, values); -} - -void BuiltinDialectBytecodeInterface::write( - SparseElementsAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kSparseElementsAttr); - writer.writeType(attr.getType()); - writer.writeAttribute(attr.getIndices()); - writer.writeAttribute(attr.getValues()); -} - -//===----------------------------------------------------------------------===// -// StringAttr - -StringAttr -BuiltinDialectBytecodeInterface::readStringAttr(DialectBytecodeReader &reader, - bool hasType) const { - StringRef string; - if (failed(reader.readString(string))) - return StringAttr(); - - // Read the type if present. 
- Type type; - if (!hasType) - type = NoneType::get(getContext()); - else if (failed(reader.readType(type))) - return StringAttr(); - return StringAttr::get(string, type); -} - -void BuiltinDialectBytecodeInterface::write( - StringAttr attr, DialectBytecodeWriter &writer) const { - // We only encode the type if it isn't NoneType, which is significantly less - // common. - Type type = attr.getType(); - if (!type.isa()) { - writer.writeVarInt(builtin_encoding::kStringAttrWithType); - writer.writeOwnedString(attr.getValue()); - writer.writeType(type); - return; - } - writer.writeVarInt(builtin_encoding::kStringAttr); - writer.writeOwnedString(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// SymbolRefAttr +#include "mlir/IR/BuiltinDialectBytecode.cpp.inc" -SymbolRefAttr BuiltinDialectBytecodeInterface::readSymbolRefAttr( - DialectBytecodeReader &reader, bool hasNestedRefs) const { - StringAttr rootReference; - if (failed(reader.readAttribute(rootReference))) - return SymbolRefAttr(); - SmallVector nestedReferences; - if (hasNestedRefs && failed(reader.readAttributes(nestedReferences))) - return SymbolRefAttr(); - return SymbolRefAttr::get(rootReference, nestedReferences); -} - -void BuiltinDialectBytecodeInterface::write( - SymbolRefAttr attr, DialectBytecodeWriter &writer) const { - ArrayRef nestedRefs = attr.getNestedReferences(); - writer.writeVarInt(nestedRefs.empty() ? 
builtin_encoding::kFlatSymbolRefAttr - : builtin_encoding::kSymbolRefAttr); - - writer.writeAttribute(attr.getRootReference()); - if (!nestedRefs.empty()) - writer.writeAttributes(nestedRefs); -} - -//===----------------------------------------------------------------------===// -// TypeAttr - -TypeAttr BuiltinDialectBytecodeInterface::readTypeAttr( - DialectBytecodeReader &reader) const { - Type type; - if (failed(reader.readType(type))) - return TypeAttr(); - return TypeAttr::get(type); -} - -void BuiltinDialectBytecodeInterface::write( - TypeAttr attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kTypeAttr); - writer.writeType(attr.getValue()); -} - -//===----------------------------------------------------------------------===// -// CallSiteLoc - -LocationAttr BuiltinDialectBytecodeInterface::readCallSiteLoc( - DialectBytecodeReader &reader) const { - LocationAttr callee, caller; - if (failed(reader.readAttribute(callee)) || - failed(reader.readAttribute(caller))) - return LocationAttr(); - return CallSiteLoc::get(callee, caller); -} - -void BuiltinDialectBytecodeInterface::write( - CallSiteLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kCallSiteLoc); - writer.writeAttribute(attr.getCallee()); - writer.writeAttribute(attr.getCaller()); -} - -//===----------------------------------------------------------------------===// -// FileLineColLoc - -LocationAttr BuiltinDialectBytecodeInterface::readFileLineColLoc( - DialectBytecodeReader &reader) const { - StringAttr filename; - uint64_t line, column; - if (failed(reader.readAttribute(filename)) || - failed(reader.readVarInt(line)) || failed(reader.readVarInt(column))) - return LocationAttr(); - return FileLineColLoc::get(filename, line, column); -} - -void BuiltinDialectBytecodeInterface::write( - FileLineColLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFileLineColLoc); - 
writer.writeAttribute(attr.getFilename()); - writer.writeVarInt(attr.getLine()); - writer.writeVarInt(attr.getColumn()); -} - -//===----------------------------------------------------------------------===// -// FusedLoc - -LocationAttr -BuiltinDialectBytecodeInterface::readFusedLoc(DialectBytecodeReader &reader, - bool hasMetadata) const { - // Parse the child locations. - auto readLoc = [&]() -> FailureOr { - LocationAttr locAttr; - if (failed(reader.readAttribute(locAttr))) - return failure(); - return Location(locAttr); - }; - SmallVector locations; - if (failed(reader.readList(locations, readLoc))) - return LocationAttr(); - - // Parse the metadata if present. - Attribute metadata; - if (hasMetadata && failed(reader.readAttribute(metadata))) - return LocationAttr(); - - return FusedLoc::get(locations, metadata, getContext()); -} - -void BuiltinDialectBytecodeInterface::write( - FusedLoc attr, DialectBytecodeWriter &writer) const { - if (Attribute metadata = attr.getMetadata()) { - writer.writeVarInt(builtin_encoding::kFusedLocWithMetadata); - writer.writeAttributes(attr.getLocations()); - writer.writeAttribute(metadata); - } else { - writer.writeVarInt(builtin_encoding::kFusedLoc); - writer.writeAttributes(attr.getLocations()); - } -} - -//===----------------------------------------------------------------------===// -// NameLoc - -LocationAttr BuiltinDialectBytecodeInterface::readNameLoc( - DialectBytecodeReader &reader) const { - StringAttr name; - LocationAttr childLoc; - if (failed(reader.readAttribute(name)) || - failed(reader.readAttribute(childLoc))) - return LocationAttr(); - return NameLoc::get(name, childLoc); -} - -void BuiltinDialectBytecodeInterface::write( - NameLoc attr, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kNameLoc); - writer.writeAttribute(attr.getName()); - writer.writeAttribute(attr.getChildLoc()); -} - -//===----------------------------------------------------------------------===// -// OpaqueLoc - 
-LogicalResult -BuiltinDialectBytecodeInterface::write(OpaqueLoc attr, - DialectBytecodeWriter &writer) const { - // We can't encode an OpaqueLoc directly given that it is in-memory only, so - // encode the fallback instead. - return writeAttribute(attr.getFallbackLocation(), writer); -} - -//===----------------------------------------------------------------------===// -// Types -//===----------------------------------------------------------------------===// +/// This class implements the bytecode interface for the builtin dialect. +struct BuiltinDialectBytecodeInterface : public BytecodeDialectInterface { + BuiltinDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} -Type BuiltinDialectBytecodeInterface::readType( - DialectBytecodeReader &reader) const { - uint64_t code; - if (failed(reader.readVarInt(code))) - return Type(); - switch (code) { - case builtin_encoding::kIntegerType: - return readIntegerType(reader); - case builtin_encoding::kIndexType: - return IndexType::get(getContext()); - case builtin_encoding::kFunctionType: - return readFunctionType(reader); - case builtin_encoding::kBFloat16Type: - return BFloat16Type::get(getContext()); - case builtin_encoding::kFloat16Type: - return Float16Type::get(getContext()); - case builtin_encoding::kFloat32Type: - return Float32Type::get(getContext()); - case builtin_encoding::kFloat64Type: - return Float64Type::get(getContext()); - case builtin_encoding::kFloat80Type: - return Float80Type::get(getContext()); - case builtin_encoding::kFloat128Type: - return Float128Type::get(getContext()); - case builtin_encoding::kComplexType: - return readComplexType(reader); - case builtin_encoding::kMemRefType: - return readMemRefType(reader, /*hasMemSpace=*/false); - case builtin_encoding::kMemRefTypeWithMemSpace: - return readMemRefType(reader, /*hasMemSpace=*/true); - case builtin_encoding::kNoneType: - return NoneType::get(getContext()); - case builtin_encoding::kRankedTensorType: - return 
readRankedTensorType(reader, /*hasEncoding=*/false); - case builtin_encoding::kRankedTensorTypeWithEncoding: - return readRankedTensorType(reader, /*hasEncoding=*/true); - case builtin_encoding::kTupleType: - return readTupleType(reader); - case builtin_encoding::kUnrankedMemRefType: - return readUnrankedMemRefType(reader, /*hasMemSpace=*/false); - case builtin_encoding::kUnrankedMemRefTypeWithMemSpace: - return readUnrankedMemRefType(reader, /*hasMemSpace=*/true); - case builtin_encoding::kUnrankedTensorType: - return readUnrankedTensorType(reader); - case builtin_encoding::kVectorType: - return readVectorType(reader, /*hasScalableDims=*/false); - case builtin_encoding::kVectorTypeWithScalableDims: - return readVectorType(reader, /*hasScalableDims=*/true); + //===--------------------------------------------------------------------===// + // Attributes - default: - reader.emitError() << "unknown builtin type code: " << code; - return Type(); + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); } -} - -LogicalResult BuiltinDialectBytecodeInterface::writeType( - Type type, DialectBytecodeWriter &writer) const { - return TypeSwitch(type) - .Case([&](auto type) { - write(type, writer); - return success(); - }) - .Case([&](IndexType) { - return writer.writeVarInt(builtin_encoding::kIndexType), success(); - }) - .Case([&](BFloat16Type) { - return writer.writeVarInt(builtin_encoding::kBFloat16Type), success(); - }) - .Case([&](Float16Type) { - return writer.writeVarInt(builtin_encoding::kFloat16Type), success(); - }) - .Case([&](Float32Type) { - return writer.writeVarInt(builtin_encoding::kFloat32Type), success(); - }) - .Case([&](Float64Type) { - return writer.writeVarInt(builtin_encoding::kFloat64Type), success(); - }) - .Case([&](Float80Type) { - return writer.writeVarInt(builtin_encoding::kFloat80Type), success(); - }) - .Case([&](Float128Type) { - return 
writer.writeVarInt(builtin_encoding::kFloat128Type), success(); - }) - .Case([&](NoneType) { - return writer.writeVarInt(builtin_encoding::kNoneType), success(); - }) - .Default([&](Type) { return failure(); }); -} - -//===----------------------------------------------------------------------===// -// ComplexType - -ComplexType BuiltinDialectBytecodeInterface::readComplexType( - DialectBytecodeReader &reader) const { - Type elementType; - if (failed(reader.readType(elementType))) - return ComplexType(); - return ComplexType::get(elementType); -} - -void BuiltinDialectBytecodeInterface::write( - ComplexType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kComplexType); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// IntegerType - -IntegerType BuiltinDialectBytecodeInterface::readIntegerType( - DialectBytecodeReader &reader) const { - uint64_t encoding; - if (failed(reader.readVarInt(encoding))) - return IntegerType(); - return IntegerType::get( - getContext(), encoding >> 2, - static_cast(encoding & 0x3)); -} - -void BuiltinDialectBytecodeInterface::write( - IntegerType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kIntegerType); - writer.writeVarInt((type.getWidth() << 2) | type.getSignedness()); -} - -//===----------------------------------------------------------------------===// -// FunctionType - -FunctionType BuiltinDialectBytecodeInterface::readFunctionType( - DialectBytecodeReader &reader) const { - SmallVector inputs, results; - if (failed(reader.readTypes(inputs)) || failed(reader.readTypes(results))) - return FunctionType(); - return FunctionType::get(getContext(), inputs, results); -} - -void BuiltinDialectBytecodeInterface::write( - FunctionType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kFunctionType); - writer.writeTypes(type.getInputs()); - 
writer.writeTypes(type.getResults()); -} - -//===----------------------------------------------------------------------===// -// MemRefType - -MemRefType -BuiltinDialectBytecodeInterface::readMemRefType(DialectBytecodeReader &reader, - bool hasMemSpace) const { - Attribute memorySpace; - if (hasMemSpace && failed(reader.readAttribute(memorySpace))) - return MemRefType(); - SmallVector shape; - Type elementType; - MemRefLayoutAttrInterface layout; - if (failed(reader.readSignedVarInts(shape)) || - failed(reader.readType(elementType)) || - failed(reader.readAttribute(layout))) - return MemRefType(); - return MemRefType::get(shape, elementType, layout, memorySpace); -} -void BuiltinDialectBytecodeInterface::write( - MemRefType type, DialectBytecodeWriter &writer) const { - if (Attribute memSpace = type.getMemorySpace()) { - writer.writeVarInt(builtin_encoding::kMemRefTypeWithMemSpace); - writer.writeAttribute(memSpace); - } else { - writer.writeVarInt(builtin_encoding::kMemRefType); + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); - writer.writeAttribute(type.getLayout()); -} - -//===----------------------------------------------------------------------===// -// RankedTensorType -RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType( - DialectBytecodeReader &reader, bool hasEncoding) const { - Attribute encoding; - if (hasEncoding && failed(reader.readAttribute(encoding))) - return RankedTensorType(); - SmallVector shape; - Type elementType; - if (failed(reader.readSignedVarInts(shape)) || - failed(reader.readType(elementType))) - return RankedTensorType(); - return RankedTensorType::get(shape, elementType, encoding); -} + //===--------------------------------------------------------------------===// + // Types -void BuiltinDialectBytecodeInterface::write( - RankedTensorType 
type, DialectBytecodeWriter &writer) const { - if (Attribute encoding = type.getEncoding()) { - writer.writeVarInt(builtin_encoding::kRankedTensorTypeWithEncoding); - writer.writeAttribute(encoding); - } else { - writer.writeVarInt(builtin_encoding::kRankedTensorType); + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// TupleType - -TupleType BuiltinDialectBytecodeInterface::readTupleType( - DialectBytecodeReader &reader) const { - SmallVector elements; - if (failed(reader.readTypes(elements))) - return TupleType(); - return TupleType::get(getContext(), elements); -} - -void BuiltinDialectBytecodeInterface::write( - TupleType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kTupleType); - writer.writeTypes(type.getTypes()); -} - -//===----------------------------------------------------------------------===// -// UnrankedMemRefType - -UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType( - DialectBytecodeReader &reader, bool hasMemSpace) const { - Attribute memorySpace; - if (hasMemSpace && failed(reader.readAttribute(memorySpace))) - return UnrankedMemRefType(); - Type elementType; - if (failed(reader.readType(elementType))) - return UnrankedMemRefType(); - return UnrankedMemRefType::get(elementType, memorySpace); -} -void BuiltinDialectBytecodeInterface::write( - UnrankedMemRefType type, DialectBytecodeWriter &writer) const { - if (Attribute memSpace = type.getMemorySpace()) { - writer.writeVarInt(builtin_encoding::kUnrankedMemRefTypeWithMemSpace); - writer.writeAttribute(memSpace); - } else { - writer.writeVarInt(builtin_encoding::kUnrankedMemRefType); + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); 
} - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// UnrankedTensorType - -UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType( - DialectBytecodeReader &reader) const { - Type elementType; - if (failed(reader.readType(elementType))) - return UnrankedTensorType(); - return UnrankedTensorType::get(elementType); -} - -void BuiltinDialectBytecodeInterface::write( - UnrankedTensorType type, DialectBytecodeWriter &writer) const { - writer.writeVarInt(builtin_encoding::kUnrankedTensorType); - writer.writeType(type.getElementType()); -} - -//===----------------------------------------------------------------------===// -// VectorType - -VectorType -BuiltinDialectBytecodeInterface::readVectorType(DialectBytecodeReader &reader, - bool hasScalableDims) const { - uint64_t numScalableDims = 0; - if (hasScalableDims && failed(reader.readVarInt(numScalableDims))) - return VectorType(); - SmallVector shape; - Type elementType; - if (failed(reader.readSignedVarInts(shape)) || - failed(reader.readType(elementType))) - return VectorType(); - return VectorType::get(shape, elementType, numScalableDims); -} +}; +} // namespace -void BuiltinDialectBytecodeInterface::write( - VectorType type, DialectBytecodeWriter &writer) const { - if (unsigned numScalableDims = type.getNumScalableDims()) { - writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims); - writer.writeVarInt(numScalableDims); - } else { - writer.writeVarInt(builtin_encoding::kVectorType); - } - writer.writeSignedVarInts(type.getShape()); - writer.writeType(type.getElementType()); +void builtin_dialect_detail::addBytecodeInterface(BuiltinDialect *dialect) { + dialect->addInterfaces(); } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -43,6 +43,7 @@ DEPENDS MLIRBuiltinAttributesIncGen MLIRBuiltinAttributeInterfacesIncGen + 
MLIRBuiltinDialectBytecodeIncGen MLIRBuiltinDialectIncGen MLIRBuiltinLocationAttributesIncGen MLIRBuiltinOpsIncGen diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp @@ -0,0 +1,464 @@ +//===- BytecodeDialectGen.cpp - Dialect bytecode read/writer gen ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include + +using namespace llvm; + +static llvm::cl::OptionCategory dialectGenCat("Options for -gen-bytecode"); +static llvm::cl::opt + selectedBcDialect("bytecode-dialect", + llvm::cl::desc("The dialect to gen for"), + llvm::cl::cat(dialectGenCat), llvm::cl::CommaSeparated); + +namespace { + +/// Helper to replace set of from strings to target in `s`. +/// Assumed: non-overlapping replacements. +std::string format(StringRef templ, std::map &&map) { + std::string s = templ.str(); + for (const auto &[from, to] : map) + // All replacements start with $, don't treat as anchor. + s = std::regex_replace(s, std::regex("\\" + from), to); + return s; +} + +/// Helper class to generate C++ bytecode parser helpers. +class Generator { +public: + Generator(raw_ostream &output) : output(output) {} + + /// Returns whether successfully emitted attribute/type parsers. + void emitParse(StringRef kind, Record &x); + + /// Returns whether successfully emitted attribute/type printers. 
+  void emitPrint(StringRef kind, StringRef type,
+                 ArrayRef<std::pair<int64_t, Record *>> vec);
+
+  /// Emits parse dispatch table.
+  void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
+
+  /// Emits print dispatch table.
+  void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
+
+private:
+  /// Emits parse calls to construct given kind.
+  void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
+                       ArrayRef<Init *> args, ArrayRef<std::string> argNames,
+                       StringRef failure, mlir::raw_indented_ostream &ios);
+
+  /// Emits print instructions.
+  void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
+                       StringRef name, mlir::raw_indented_ostream &ios);
+
+  raw_ostream &output;
+};
+
+/// Return `str` with its first character capitalized. An empty input yields
+/// an empty string rather than indexing past the end.
+static std::string capitalize(StringRef str) {
+  if (str.empty())
+    return "";
+  return (Twine(toUpper(str.front())) + str.drop_front()).str();
+}
+
+/// Return the name of the generated lambda that reads one element of the list
+/// member `name` (e.g. `value` -> `readValue`).
+static std::string getListReaderName(StringRef name) {
+  return "read" + capitalize(name);
+}
+
+/// Return the C++ type for the given record. For `Array` records this is a
+/// `SmallVector` of the element's C++ type. The record's `cType` is used when
+/// non-empty, otherwise its (non-anonymous) record name.
+static std::string getCType(Record *def) {
+  std::string pattern = "{0}";
+  if (def->isSubClassOf("Array")) {
+    def = def->getValueAsDef("elemT");
+    pattern = "SmallVector<{0}>";
+  }
+
+  StringRef cType = def->getValueAsString("cType");
+  if (cType.empty()) {
+    if (def->isAnonymous())
+      PrintFatalError(def->getLoc(), "Unable to determine cType");
+
+    return formatv(pattern.c_str(), def->getName().str());
+  }
+  return formatv(pattern.c_str(), cType.str());
+}
+
+/// Emit `read<Kind>(context, reader)`, which reads a varint code and forwards
+/// to the reader of the record at that index in `vec`.
+void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
+  mlir::raw_indented_ostream os(output);
+  char const *head =
+      R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
+  os << formatv(head, capitalize(kind));
+  auto funScope = os.scope(" {\n", "}\n\n");
+
+  os << "uint64_t kind;\n";
+  os << "if (failed(reader.readVarInt(kind)))\n"
+     << "  return " << capitalize(kind) << "();\n";
+  os << "switch (kind) ";
+  {
+    auto switchScope = os.scope("{\n", "}\n");
+    for (const auto &it : llvm::enumerate(vec)) {
+      os << formatv("case {1}:\n  return read{0}(context, reader);\n",
+                    it.value()->getName(), it.index());
+    }
+    // This dispatcher is generated for both attributes and types, so name the
+    // actual kind in the diagnostic instead of hard-coding "attribute".
+    os << "default:\n"
+       << "  reader.emitError() << \"unknown " << kind << " code: \" "
+       << "<< kind;\n"
+       << "  return " << capitalize(kind) << "();\n";
+  }
+  os << "return " << capitalize(kind) << "();\n";
+}
+
+/// Emit `read<Record>(context, reader)`, which parses the record's `members`
+/// and constructs its C++ type through the record's `cBuilder`.
+void Generator::emitParse(StringRef kind, Record &x) {
+  char const *head =
+      R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
+  mlir::raw_indented_ostream os(output);
+  std::string returnType = getCType(&x);
+  os << formatv(head, returnType, x.getName());
+  DagInit *members = x.getValueAsDag("members");
+  SmallVector<std::string> argNames =
+      llvm::to_vector(map_range(members->getArgNames(), [](StringInit *init) {
+        return init->getAsUnquotedString();
+      }));
+  StringRef builder = x.getValueAsString("cBuilder");
+  emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
+                  returnType + "()", os);
+  os << "\n\n";
+}
+
+/// Emit the `if (...) {` header whose condition parses every member that has
+/// a parser into its local variable, joined with `&&`. Non-`Array` members
+/// with an empty `cParser` are skipped. The caller emits the body and the
+/// closing brace.
+static void printParseConditional(mlir::raw_indented_ostream &ios,
+                                  ArrayRef<Init *> args,
+                                  ArrayRef<std::string> argNames) {
+  ios << "if ";
+  auto parenScope = ios.scope("(", ") {");
+  ios.indent();
+
+  // Filter members together with their names. Filtering `args` alone and then
+  // zipping with the unfiltered `argNames` would attach a parser to the name
+  // of a different member as soon as one member is skipped.
+  SmallVector<std::pair<Record *, StringRef>> parsedArgs;
+  for (auto [arg, name] : zip(args, argNames)) {
+    Record *def = cast<DefInit>(arg)->getDef();
+    if (def->isSubClassOf("Array") ||
+        !def->getValueAsString("cParser").empty())
+      parsedArgs.emplace_back(def, name);
+  }
+
+  // Nothing to parse: keep the generated `if` well-formed.
+  if (parsedArgs.empty()) {
+    ios << "true";
+    return;
+  }
+
+  interleave(
+      parsedArgs,
+      [&](const std::pair<Record *, StringRef> &parsedArg) {
+        Record *attr = parsedArg.first;
+        StringRef name = parsedArg.second;
+        std::string parser;
+        if (auto optParser = attr->getValueAsOptionalString("cParser")) {
+          parser = *optParser;
+        } else if (attr->isSubClassOf("Array")) {
+          Record *def = attr->getValueAsDef("elemT");
+          bool composite = def->isSubClassOf("CompositeBytecode");
+          if (!composite && def->isSubClassOf("AttributeKind"))
+            parser = "succeeded($_reader.readAttributes($_var))";
+          else if (!composite && def->isSubClassOf("TypeKind"))
+            parser = "succeeded($_reader.readTypes($_var))";
+          else
+            parser = "succeeded($_reader.readList($_var, " +
+                     getListReaderName(name) + "))";
+        } else {
+          PrintFatalError(attr->getLoc(), "No parser specified");
+        }
+        std::string type = getCType(attr);
+        ios << format(parser, {{"$_reader", "reader"},
+                               {"$_resultType", type},
+                               {"$_var", name.str()}});
+      },
+      [&]() { ios << " &&\n"; });
+}
+
+/// Emit a brace-enclosed body that declares one local per member, emits an
+/// element-reader lambda for each list member whose elements need one, parses
+/// all members and returns `builder` (with `$_resultType`/`$_args`
+/// substituted). If parsing fails, the body returns `failure`. With no
+/// members it returns `get<returnType>(context)`.
+void Generator::emitParseHelper(StringRef kind, StringRef returnType,
+                                StringRef builder, ArrayRef<Init *> args,
+                                ArrayRef<std::string> argNames,
+                                StringRef failure,
+                                mlir::raw_indented_ostream &ios) {
+  auto funScope = ios.scope("{\n", "}");
+
+  if (args.empty()) {
+    ios << formatv("return get<{0}>(context);\n", returnType);
+    return;
+  }
+
+  // Print decls.
+  std::string lastCType = "";
+  for (auto [arg, name] : zip(args, argNames)) {
+    DefInit *first = dyn_cast<DefInit>(arg);
+    if (!first)
+      PrintFatalError("Unexpected type for " + name);
+    Record *def = first->getDef();
+
+    // Create variable decls, if there are a block of same type then create
+    // comma separated list of them.
+    std::string cType = getCType(def);
+    if (lastCType == cType) {
+      ios << ", ";
+    } else {
+      if (!lastCType.empty())
+        ios << ";\n";
+      ios << cType << " ";
+    }
+    ios << name;
+    lastCType = cType;
+  }
+  ios << ";\n";
+
+  // Emit list helper functions. Locals use `elem*` names so they do not
+  // shadow this function's parameters.
+  for (auto [arg, name] : zip(args, argNames)) {
+    Record *attr = cast<DefInit>(arg)->getDef();
+    if (!attr->isSubClassOf("Array"))
+      continue;
+
+    // TODO: Dedupe readers.
+    Record *def = attr->getValueAsDef("elemT");
+    if (!def->isSubClassOf("CompositeBytecode") &&
+        (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
+      continue;
+
+    std::string elemType = getCType(def);
+    ios << "auto " << getListReaderName(name) << " = [&]() -> FailureOr<"
+        << elemType << "> ";
+    SmallVector<Init *> elemArgs;
+    SmallVector<std::string> elemArgNames;
+    if (def->isSubClassOf("CompositeBytecode")) {
+      DagInit *members = def->getValueAsDag("members");
+      elemArgs = llvm::to_vector(members->getArgs());
+      elemArgNames = llvm::to_vector(
+          map_range(members->getArgNames(), [](StringInit *init) {
+            return init->getAsUnquotedString();
+          }));
+    } else {
+      elemArgs = {def->getDefInit()};
+      elemArgNames = {"temp"};
+    }
+    StringRef elemBuilder = def->getValueAsString("cBuilder");
+    emitParseHelper(kind, elemType, elemBuilder, elemArgs, elemArgNames,
+                    "failure()", ios);
+    ios << ";\n";
+  }
+
+  // Print parse conditional.
+  printParseConditional(ios, args, argNames);
+  // FIXME: Split out helper function.
+
+  // Compute args to pass to create method; names starting with `_` are parsed
+  // but not forwarded.
+  auto passedArgs = llvm::to_vector(make_filter_range(
+      argNames, [](StringRef str) { return !str.starts_with("_"); }));
+  std::string argStr;
+  raw_string_ostream argStream(argStr);
+  interleaveComma(passedArgs, argStream,
+                  [&](const std::string &str) { argStream << str; });
+  // Return the invoked constructor.
+  ios << "\nreturn "
+      << format(builder, {{"$_resultType", returnType.str()},
+                          {"$_args", argStream.str()}})
+      << ";\n";
+  // Balances the extra indent() issued by printParseConditional.
+  ios.unindent();
+
+  // TODO: Emit error in debug.
+  // This assumes the result types in error case can always be empty
+  // constructed.
+  ios << "}\nreturn " << failure << ";\n";
+}
+
+/// Emit `write(<type> <kind>, writer)` for all records sharing C++ type
+/// `type`. Each record writes its index in `vec` as a varint code followed by
+/// its members, guarded by its `printerPredicate` when one is set.
+void Generator::emitPrint(StringRef kind, StringRef type,
+                          ArrayRef<std::pair<int64_t, Record *>> vec) {
+  char const *head =
+      R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
+  mlir::raw_indented_ostream os(output);
+  os << formatv(head, type, kind);
+  auto funScope = os.scope("{\n", "}\n\n");
+
+  // Records that share one C++ type can only be told apart at write time by
+  // their printer predicates, so each of them must specify one. Report every
+  // offending record before aborting.
+  if (vec.size() > 1) {
+    bool missingPredicate = false;
+    for (const auto &entry : vec) {
+      Record *rec = entry.second;
+      if (!rec->getValueAsString("printerPredicate").empty())
+        continue;
+      PrintError(rec->getLoc(),
+                 "Requires printer predicate given common cType");
+      missingPredicate = true;
+    }
+    if (missingPredicate)
+      PrintFatalError("Unspecified printer predicate for shared cType " +
+                      type);
+  }
+
+  for (auto [index, rec] : vec) {
+    StringRef pred = rec->getValueAsString("printerPredicate");
+    if (!pred.empty()) {
+      os << "if (" << format(pred, {{"$_val", kind.str()}}) << ") {\n";
+      os.indent();
+    }
+
+    os << "writer.writeVarInt(/* " << rec->getName() << " */ " << index
+       << ");\n";
+
+    auto *members = rec->getValueAsDag("members");
+    for (auto [arg, name] :
+         llvm::zip(members->getArgs(), members->getArgNames())) {
+      DefInit *def = dyn_cast<DefInit>(arg);
+      assert(def);
+      Record *memberRec = def->getDef();
+      emitPrintHelper(memberRec, kind, kind, name->getAsUnquotedString(), os);
+    }
+
+    if (!pred.empty()) {
+      os.unindent();
+      os << "}\n";
+    }
+  }
+}
+
+/// Emit the writer calls for member `name` of `parent`. The member is read
+/// through its `cGetter` template, or `parent.get<Name>()` when the template
+/// is empty. Lists and composite records are handled by recursing into their
+/// elements/members; otherwise the member's `cPrinter` template is emitted
+/// when non-empty.
+void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
+                                StringRef parent, StringRef name,
+                                mlir::raw_indented_ostream &ios) {
+  std::string getter;
+  if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
+      cGetter && !cGetter->empty()) {
+    getter = format(
+        *cGetter,
+        {{"$_attrType", parent.str()},
+         {"$_member", name.str()},
+         {"$_getMember", "get" + convertToCamelFromSnakeCase(name, true)}});
+  } else {
+    getter =
+        formatv("{0}.get{1}()", parent, convertToCamelFromSnakeCase(name, true))
+            .str();
+  }
+
+  if (memberRec->isSubClassOf("Array")) {
+    Record *def = memberRec->getValueAsDef("elemT");
+    if (!def->isSubClassOf("CompositeBytecode")) {
+      if (def->isSubClassOf("AttributeKind")) {
+        ios << "writer.writeAttributes(" << getter << ");\n";
+        return;
+      }
+      if (def->isSubClassOf("TypeKind")) {
+        ios << "writer.writeTypes(" << getter << ");\n";
+        return;
+      }
+    }
+    std::string returnType = getCType(def);
+    ios << "writer.writeList(" << getter << ", [&](" << returnType << " "
+        << kind << ") ";
+    auto lambdaScope = ios.scope("{\n", "});\n");
+    emitPrintHelper(def, kind, kind, kind, ios);
+    return;
+  }
+  if (memberRec->isSubClassOf("CompositeBytecode")) {
+    auto *members = memberRec->getValueAsDag("members");
+    for (auto [arg, argName] :
+         zip(members->getArgs(), members->getArgNames())) {
+      DefInit *def = dyn_cast<DefInit>(arg);
+      assert(def);
+      emitPrintHelper(def->getDef(), kind, parent,
+                      argName->getAsUnquotedString(), ios);
+    }
+  }
+
+  if (std::string printer = memberRec->getValueAsString("cPrinter").str();
+      !printer.empty())
+    ios << format(printer, {{"$_writer", "writer"},
+                            {"$_name", kind.str()},
+                            {"$_getter", getter}})
+        << ";\n";
+}
+
+/// Emit `write<Kind>(<kind>, writer)`, which dispatches on the C++ types in
+/// `vec` via TypeSwitch to the matching `write` overload and returns
+/// failure() for any other type.
+void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
+  mlir::raw_indented_ostream os(output);
+  char const *head = R"(static LogicalResult write{0}({0} {1},
+                                 DialectBytecodeWriter &writer))";
+  os << formatv(head, capitalize(kind), kind);
+  auto funScope = os.scope(" {\n", "}\n\n");
+
+  os << "return TypeSwitch<" << capitalize(kind) << ", LogicalResult>(" << kind
+     << ")";
+  auto switchScope = os.scope("", "");
+  for (StringRef type : vec) {
+    os << "\n.Case([&](" << type << " t)";
+    auto caseScope = os.scope(" {\n", "})");
+    os << "return write(t, writer), success();\n";
+  }
+  os << "\n.Default([&](" << capitalize(kind) << ") { return failure(); });\n";
+}
+
+namespace {
+/// Container of Attribute or Type for Dialect.
+struct AttrOrType {
+  std::vector<Record *> attr, type;
+};
+} // namespace
+
+/// Generate the bytecode readers/writers for exactly one dialect: the only
+/// one in the input, or the one selected via `-bytecode-dialect`.
+static bool emitBCRW(const RecordKeeper &records, raw_ostream &os) {
+  MapVector<StringRef, AttrOrType> dialectAttrOrType;
+  for (auto &it : records.getAllDerivedDefinitions("DialectAttributes")) {
+    if (!selectedBcDialect.empty() &&
+        it->getValueAsString("dialect") != selectedBcDialect)
+      continue;
+    dialectAttrOrType[it->getValueAsString("dialect")].attr =
+        it->getValueAsListOfDefs("elems");
+  }
+  for (auto &it : records.getAllDerivedDefinitions("DialectTypes")) {
+    if (!selectedBcDialect.empty() &&
+        it->getValueAsString("dialect") != selectedBcDialect)
+      continue;
+    dialectAttrOrType[it->getValueAsString("dialect")].type =
+        it->getValueAsListOfDefs("elems");
+  }
+
+  if (dialectAttrOrType.size() != 1)
+    PrintFatalError("Single dialect per invocation required (either only "
+                    "one in input file or specified via dialect option)");
+
+  // Bind by reference: copying the entry would duplicate both record lists.
+  AttrOrType &selected = dialectAttrOrType.front().second;
+  Generator gen(os);
+
+  // Emit readers and writers for one kind ("attribute" or "type").
+  auto emitKind = [&](StringRef kind, ArrayRef<Record *> defs) {
+    // Group records by C++ type; records sharing a type share one writer and
+    // keep their position in `defs` as their bytecode code.
+    std::map<std::string, std::vector<std::pair<int64_t, Record *>>> perType;
+    for (const auto &kt : llvm::enumerate(defs))
+      perType[getCType(kt.value())].emplace_back(kt.index(), kt.value());
+
+    SmallVector<std::string> types;
+    for (const auto &[cType, entries] : perType) {
+      for (const auto &entry : entries)
+        gen.emitParse(kind, *entry.second);
+      gen.emitPrint(kind, cType, entries);
+      types.push_back(cType);
+    }
+    gen.emitParseDispatch(kind, defs);
+    gen.emitPrintDispatch(kind, types);
+  };
+  emitKind("attribute", selected.attr);
+  emitKind("type", selected.type);
+
+  return false;
+}
+} // namespace
+
+static mlir::GenRegistration
+    genBCRW("gen-bytecode", "Generate dialect bytecode readers/writers",
+            [](const RecordKeeper &records, raw_ostream &os) {
+              return emitBCRW(records, os);
+            });
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -9,6 +9,7 @@
   EXPORT MLIR
   AttrOrTypeDefGen.cpp
   AttrOrTypeFormatGen.cpp
+  BytecodeDialectGen.cpp
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -90,6 +90,7 @@
     ],
     includes = ["include"],
     deps = [
+        ":BytecodeTdFiles",
         ":CallInterfacesTdFiles",
         ":CastInterfacesTdFiles",
         ":DataLayoutInterfacesTdFiles",
@@ -116,6 +117,20 @@
     deps = [":BuiltinDialectTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "BuiltinDialectBytecodeGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-bytecode", "-bytecode-dialect=Builtin"],
+            "include/mlir/IR/BuiltinDialectBytecode.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/IR/BuiltinDialectBytecode.td",
+    deps = [":BuiltinDialectTdFiles"],
+)
+
 gentbl_cc_library(
     name = "BuiltinAttributesIncGen",
     strip_include_prefix = "include",
@@ -276,6 +291,7 @@
     deps = [
         ":BuiltinAttributeInterfacesIncGen",
         ":BuiltinAttributesIncGen",
+        ":BuiltinDialectBytecodeGen",
         ":BuiltinDialectIncGen",
         ":BuiltinLocationAttributesIncGen",
         ":BuiltinOpsIncGen",
@@ -926,6 +942,12 @@
     ],
 )
 
+td_library(
+    name = "BytecodeTdFiles",
+    srcs = ["include/mlir/IR/BytecodeBase.td"],
+    includes = ["include"],
+)
+
 td_library(
     name = "CallInterfacesTdFiles",
     srcs = ["include/mlir/Interfaces/CallInterfaces.td"],