diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -0,0 +1,437 @@ +//===- LLVMTypes.h - MLIR LLVM dialect types --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the types for the LLVM dialect in MLIR. These MLIR types +// correspond to the LLVM IR type system. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ +#define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ + +#include "mlir/IR/Types.h" + +namespace llvm { +class ElementCount; +} + +namespace mlir { + +class DialectAsmParser; +class DialectAsmPrinter; + +namespace LLVM { +namespace detail { +struct LLVMFunctionTypeStorage; +struct LLVMIntegerTypeStorage; +struct LLVMPointerTypeStorage; +struct LLVMStructTypeStorage; +struct LLVMTypeAndSizeStorage; +} // namespace detail + +/// Base class for LLVM dialect types. +/// +/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, providing a +/// separate MLIR type for each LLVM IR type. All types are represented as +/// separate subclasses and are compatible with the isa/cast infrastructure. +/// For convenience, the base class provides most of the APIs available on +/// llvm::Type in addition to MLIR-compatible APIs. +/// +/// The LLVM dialect type system is closed: parametric types can only refer to +/// other LLVM dialect types. This is consistent with LLVM IR and enables a more +/// concise pretty-printing format.
+/// +/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR +/// context, have an immutable identifier (for most types except identified +/// structs, the entire type is the identifier) and are thread-safe. +class LLVMTypeNew : public Type { +public: + enum Kind { + // Keep non-parametric types contiguous in the enum. + VOID_TYPE = FIRST_LLVM_TYPE + 1, + HALF_TYPE, + BFLOAT_TYPE, + FLOAT_TYPE, + DOUBLE_TYPE, + FP128_TYPE, + X86_FP80_TYPE, + PPC_FP128_TYPE, + X86_MMX_TYPE, + LABEL_TYPE, + TOKEN_TYPE, + METADATA_TYPE, + // End of non-parametric types. + FUNCTION_TYPE, + INTEGER_TYPE, + POINTER_TYPE, + FIXED_VECTOR_TYPE, + SCALABLE_VECTOR_TYPE, + ARRAY_TYPE, + STRUCT_TYPE, + FIRST_NEW_LLVM_TYPE = VOID_TYPE, + LAST_NEW_LLVM_TYPE = STRUCT_TYPE, + FIRST_TRIVIAL_TYPE = VOID_TYPE, + LAST_TRIVIAL_TYPE = METADATA_TYPE + }; + + /// Inherit base constructors. + using Type::Type; + + /// Support for PointerLikeTypeTraits. + using Type::getAsOpaquePointer; + static LLVMTypeNew getFromOpaquePointer(const void *ptr) { + return LLVMTypeNew(static_cast(const_cast(ptr))); + } + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE; + } +}; + +/// Simple, non-parametric LLVM dialect type. +template +class LLVMTrivialType + : public Type::TypeBase { +public: + using Type::TypeBase::TypeBase; + using Base = LLVMTrivialType; + + static bool kindof(unsigned kind) { return kind == Kind; } + + /// Gets or creates an instance of the LLVM dialect type in the given context. + static Derived get(MLIRContext *context) { + return Type::TypeBase::get(context, + Kind); + } +}; + +// Batch-define trivial types.
+#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \ + class ClassName : public LLVMTrivialType { \ + public: \ + using Base::Base; \ + } + +DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMTypeNew::VOID_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMTypeNew::HALF_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMTypeNew::BFLOAT_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMTypeNew::FLOAT_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMTypeNew::DOUBLE_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMTypeNew::FP128_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMTypeNew::X86_FP80_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMTypeNew::PPC_FP128_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMTypeNew::X86_MMX_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMTypeNew::TOKEN_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMTypeNew::LABEL_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMTypeNew::METADATA_TYPE); + +#undef DEFINE_TRIVIAL_LLVM_TYPE + +/// LLVM dialect array type. It is an aggregate type representing consecutive +/// elements in memory, parameterized by the number of elements and the element +/// type. +class LLVMArrayType : public Type::TypeBase { +public: + /// Inherit base constructors. + using Base::Base; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { return kind == LLVMTypeNew::ARRAY_TYPE; } + + /// Gets or creates an instance of the LLVM dialect array type containing + /// `numElements` of `elementType`, in the same context as `elementType`. + static LLVMArrayType get(LLVMTypeNew elementType, unsigned numElements); + + /// Returns the element type of the array. + LLVMTypeNew getElementType(); + + /// Returns the number of elements in the array type. + unsigned getNumElements(); +}; + +/// LLVM dialect function type. It consists of a single return type (unlike MLIR +/// which can have multiple), a list of parameter types and can optionally be +/// variadic.
+class LLVMFunctionType + : public Type::TypeBase { +public: + /// Inherit base constructors. + using Base::Base; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FUNCTION_TYPE; + } + + /// Gets or creates an instance of the LLVM dialect function type in the same + /// context as the `result` type. + static LLVMFunctionType get(LLVMTypeNew result, + ArrayRef arguments, + bool isVarArg = false); + + /// Returns the result type of the function. + LLVMTypeNew getReturnType(); + + /// Returns the number of arguments to the function. + unsigned getNumParams(); + + /// Returns the `i`-th argument of the function. Asserts on out-of-bounds. + LLVMTypeNew getParamType(unsigned i); + + /// Checks whether the function is variadic. + bool isVarArg(); + + /// Returns a list of argument types of the function. + ArrayRef getParams(); + ArrayRef params() { return getParams(); } +}; + +/// LLVM dialect integer type parameterized by bitwidth. +class LLVMIntegerType : public Type::TypeBase { +public: + /// Inherit base constructor. + using Base::Base; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::INTEGER_TYPE; + } + + /// Gets or creates an instance of the integer of the specified `bitwidth` in + /// the given context. + static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth); + + /// Returns the bitwidth of an integer type. + unsigned getBitWidth(); +}; + +/// LLVM dialect pointer type. This type typically represents a reference to an +/// object in memory. It is parameterized by the element type and the address +/// space. +class LLVMPointerType : public Type::TypeBase { +public: + /// Inherit base constructors. + using Base::Base; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::POINTER_TYPE; + } + + /// Gets or creates an instance of LLVM dialect pointer type pointing to an + /// object of `pointee` type in the given address space.
The pointer type is + /// created in the same context as `pointee`. + static LLVMPointerType get(LLVMTypeNew pointee, unsigned addressSpace = 0); + + /// Returns the pointed-to type. + LLVMTypeNew getElementType(); + + /// Returns the address space of the pointer. + unsigned getAddressSpace(); +}; + +/// LLVM dialect structure type representing a collection of different-typed +/// elements manipulated together. Structures can optionally be packed, meaning +/// that their elements immediately follow each other in memory without +/// accounting for potential alignment. +/// +/// Structure types can be identified (named) or literal. Literal structures +/// are uniquely represented by the list of types they contain and packedness. +/// Literal structure types are immutable after construction. +/// +/// Identified structures are uniquely represented by their name, a string. They +/// have a mutable component, consisting of the list of types they contain, +/// the packedness and the opacity bits. Identified structs can be created +/// without providing the list of element types, making them suitable to +/// represent recursive, i.e. self-referring, structures. Identified structs +/// without body are considered opaque. For such structs, one can set the body. +/// Identified structs can be created as intentionally-opaque, implying that the +/// caller does not intend to ever set the body (e.g. forward-declarations of +/// structs from another module) and wants to disallow further modification of +/// the body. For intentionally-opaque structs or non-opaque structs with the +/// body, one is not allowed to set another body (however, one can set exactly +/// the same body). +/// +/// Note that the packedness of the struct takes part in uniquing of literal +/// structs, but not in uniquing of identified structs. +class LLVMStructType : public Type::TypeBase { +public: + /// Inherit base constructors. + using Base::Base; + + /// Support for isa/cast.
+ static bool kindof(unsigned kind) { return kind == LLVMTypeNew::STRUCT_TYPE; } + + /// Gets or creates an identified struct with the given name in the provided + /// context. Note that unlike llvm::StructType::create, this function will + /// _NOT_ rename a struct in case a struct with the same name already exists + /// in the context. Instead, it will just return the existing struct, + /// similarly to the rest of MLIR type ::get methods. + static LLVMStructType getIdentified(MLIRContext *context, StringRef name); + + /// Gets or creates a literal struct with the given body in the provided + /// context. + static LLVMStructType getLiteral(MLIRContext *context, ArrayRef types, + bool isPacked = false); + + /// Gets or creates an intentionally-opaque identified struct. Such a struct + /// cannot have its body set. To create an opaque struct with mutable body, + /// use `getIdentified`. Note that unlike llvm::StructType::create, this + /// function will _NOT_ rename a struct in case a struct with the same name + /// already exists in the context. Instead, it will just return the existing + /// struct, similarly to the rest of MLIR type ::get methods. + static LLVMStructType getOpaque(StringRef name, MLIRContext *context); + + /// Sets the body of an identified struct. Returns failure if the body could + /// not be set, e.g. if the struct already has a different body or was marked + /// as intentionally opaque. This might happen in a multi-threaded context + /// when a different thread modified the struct after it was created. Most + /// callers are likely to assert this always succeeds, but it is possible to + /// implement a local renaming scheme based on the result of this call. + LogicalResult setBody(ArrayRef types, bool isPacked); + + /// Checks if a struct is packed. + bool isPacked(); + + /// Checks if a struct is identified. + bool isIdentified(); + + /// Checks if a struct is opaque. + bool isOpaque(); + + /// Returns the identifier of an identified struct.
+ StringRef identifier(); + + /// Returns the list of element types contained in a non-opaque struct. + ArrayRef getBody(); +}; + +/// LLVM dialect vector type, represents a sequence of elements that can be +/// processed as one, typically in SIMD context. This is a base class for fixed +/// and scalable vectors. +class LLVMVectorType : public LLVMTypeNew { +public: + /// Inherit base constructor. + using LLVMTypeNew::LLVMTypeNew; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FIXED_VECTOR_TYPE || + kind == LLVMTypeNew::SCALABLE_VECTOR_TYPE; + } + + /// Returns the element type of the vector. + LLVMTypeNew getElementType(); + + /// Returns the number of elements in the vector. + llvm::ElementCount getElementCount(); +}; + +/// LLVM dialect fixed vector type, represents a sequence of elements of known +/// length that can be processed as one. +class LLVMFixedVectorType + : public Type::TypeBase { +public: + /// Inherit base constructor. + using Base::Base; + + /// Support for isa/cast. + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FIXED_VECTOR_TYPE; + } + + /// Gets or creates a fixed vector type containing `numElements` of + /// `elementType` in the same context as `elementType`. + static LLVMFixedVectorType get(LLVMTypeNew elementType, unsigned numElements); + + /// Returns the number of elements in the fixed vector. + unsigned getNumElements(); +}; + +/// LLVM dialect scalable vector type, represents a sequence of elements of +/// unknown length that is known to be divisible by some constant. These +/// elements can be processed as one in SIMD context. +class LLVMScalableVectorType + : public Type::TypeBase { +public: + /// Inherit base constructor. + using Base::Base; + + /// Support for isa/cast.
+ static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::SCALABLE_VECTOR_TYPE; + } + + /// Gets or creates a scalable vector type containing a non-zero multiple of + /// `minNumElements` of `elementType` in the same context as `elementType`. + static LLVMScalableVectorType get(LLVMTypeNew elementType, + unsigned minNumElements); + + /// Returns the scaling factor of the number of elements in the vector. The + /// vector contains at least the resulting number of elements, or any non-zero + /// multiple of this number. + unsigned getMinNumElements(); +}; + +namespace detail { +/// Parses an LLVM dialect type. +LLVMTypeNew parseType(DialectAsmParser &parser); + +/// Prints an LLVM dialect type. +void printType(LLVMTypeNew type, DialectAsmPrinter &printer); +} // namespace detail + +} // namespace LLVM +} // namespace mlir + +namespace llvm { + +// LLVMTypeNew instances hash just like pointers. +template <> +struct DenseMapInfo { + static mlir::LLVM::LLVMTypeNew getEmptyKey() { + void *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::LLVM::LLVMTypeNew( + static_cast(pointer)); + } + static mlir::LLVM::LLVMTypeNew getTombstoneKey() { + void *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::LLVM::LLVMTypeNew( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::LLVM::LLVMTypeNew val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::LLVM::LLVMTypeNew lhs, + mlir::LLVM::LLVMTypeNew rhs) { + return lhs == rhs; + } +}; + +template <> +struct PointerLikeTypeTraits { + static inline void *getAsVoidPointer(mlir::LLVM::LLVMTypeNew I) { + return const_cast(I.getAsOpaquePointer()); + } + static inline mlir::LLVM::LLVMTypeNew getFromVoidPointer(void *P) { + return mlir::LLVM::LLVMTypeNew::getFromOpaquePointer(P); + } + static constexpr int NumLowBitsAvailable = 3; +}; + +} // namespace llvm + +#endif // MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ diff --git a/mlir/include/mlir/IR/DialectImplementation.h
b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -203,6 +203,9 @@ /// Parse a `=` token if present. virtual ParseResult parseOptionalEqual() = 0; + /// Parse a quoted string token if present. + virtual ParseResult parseOptionalString(StringRef *string) = 0; + /// Parse a given keyword. ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { auto loc = getCurrentLocation(); @@ -323,6 +326,8 @@ return success(); } + /// Parse a type if present. + virtual OptionalParseResult parseOptionalType(Type &result) = 0; /// Parse a 'x' separated dimension list. This populates the dimension list, /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on /// `?` otherwise. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/AsmParser/Parser.h" #include "llvm/Bitcode/BitcodeReader.h" @@ -1650,6 +1651,7 @@ namespace mlir { namespace LLVM { + namespace detail { struct LLVMDialectImpl { LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} @@ -1754,10 +1756,6 @@ return success(); } -//===----------------------------------------------------------------------===// -// LLVMType.
-//===----------------------------------------------------------------------===// - namespace mlir { namespace LLVM { namespace detail { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -0,0 +1,1085 @@ +//===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the types for the LLVM dialect in MLIR. These MLIR types +// correspond to the LLVM IR type system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/TypeSupport.h" + +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/TypeSize.h" + +using namespace mlir; +using namespace mlir::LLVM; + +//===----------------------------------------------------------------------===// +// Storage structures. +//===----------------------------------------------------------------------===// + +namespace { + +/// Type trait verifying that all of the given types are the same. +template +struct are_same { + constexpr static bool value = + std::is_same::value && are_same::value; +}; +template +struct are_same { + constexpr static bool value = std::is_same::value; +}; + +/// A simple container that uses `ReservedBits` lowest bits of `UnderlyingTy` +/// integer type to store flags of `FlagTy` type. Since this is intended for +/// storing flags, setter/getter functions expect binary flags (e.g. 0x1, +/// 0x4) or masks (e.g. 0x1 | 0x4) rather than bit positions (e.g. 1, 3 like +/// BitVector does).
This uses low bits for flags to avoid the need for getting +/// the bit width of `UnderlyingTy`. In memory, this stores exactly one +/// `UnderlyingTy` object. +template +class IntAndFlags { +public: + /// Sets the given flags. + template + void setFlags(Arg arg, Args... args) { + static_assert(are_same::value, "incorrect flag type"); + UnderlyingTy mask = flags(arg, args...); + assert(mask == keepBits(mask) && + "bit mask contains values beyond the bit range"); + data |= mask; + } + + /// Resets all flags, preserving the integer value. + void resetFlags() { data &= ~kBitMask; } + + /// Tests if the given flag is set. + bool testFlag(FlagTy flag) const { + return data & static_cast(flag); + } + + /// Gets the integer value (the bits above the flags). + UnderlyingTy value() const { return data >> ReservedBits; } + + /// Sets the integer value, preserving the flags. + void setValue(UnderlyingTy value) { + // If this ever triggers, consider reclaiming the bits currently + // taken for flags and storing the flags separately. + assert((value << ReservedBits) >> ReservedBits == value && + "Value does not fit into IntAndFlags"); + data = value << ReservedBits | keepBits(data); + } + +private: + /// Bit mask selecting the `ReservedBits` lowest bits, which hold the flags. + constexpr static UnderlyingTy kBitMask = (1u << ReservedBits) - 1; + + /// Computes the mask by ORing the flags. + template + static UnderlyingTy flags(Args... args) { + UnderlyingTy mask = 0; + for (auto a : {args...}) + mask |= static_cast(a); + return mask; + } + + /// Keeps only the flag bits of the given value. + static UnderlyingTy keepBits(UnderlyingTy value) { return value & kBitMask; } + + /// Storage for the integer and the flags. + UnderlyingTy data = 0; +}; +} // end namespace + +namespace mlir { +namespace LLVM { +namespace detail { +/// Type storage for LLVM structure types.
+/// +/// Structures are uniqued using: +/// - a bit indicating whether a struct is literal or identified; +/// - for identified structs, in addition to the bit: +/// - a string identifier; +/// - for literal structs, in addition to the bit: +/// - a list of contained types; +/// - a bit indicating whether a literal struct is packed. +/// +/// Identified structures only have a mutable component consisting of: +/// - a list of contained types; +/// - a bit indicating whether an identified struct is packed; +/// - a bit indicating whether an identified struct is intentionally opaque; +/// - a bit indicating whether an identified struct has been initialized. +/// Uninitialized structs are considered opaque by the user, and can be mutated. +/// Initialized and still opaque structs cannot be mutated. +/// +/// The struct storage consists of: +/// - immutable part: +/// - a pointer to the first element of the key (character for identified +/// structs, type for literal structs); +/// - the number of elements in the key packed together with bits indicating +/// whether a struct is literal or identified, and the packedness bit for +/// literal structs only; +/// - mutable part: +/// - a pointer to the first contained type for identified structs only; +/// - the number of contained types packed together with bits of the mutable +/// component, for identified structs only. +struct LLVMStructTypeStorage : public ::mlir::TypeStorage { +private: + /// Flags used in the key component. + enum class KeyFlags { Identified = 0x1, Packed = 0x2 }; + /// Flags used in the mutable component. + enum class MutableFlags { Opaque = 0x1, Packed = 0x2, Initialized = 0x4 }; + +public: + /// Construction/uniquing key class for LLVM dialect structure storage. Note + /// that this is a transient helper data structure that is NOT stored. + /// Therefore, it intentionally avoids bit manipulation and type erasure in + /// pointers to make manipulation more straightforward.
Not all elements of + /// the key participate in uniquing, but all elements participate in + /// construction. + class Key { + public: + /// Constructs a key for an identified struct. + Key(StringRef name, bool opaque) + : name(name), identified(true), packed(false), opaque(opaque) {} + /// Constructs a key for a literal struct. + Key(ArrayRef types, bool packed) + : types(types), identified(false), packed(packed), opaque(false) {} + + /// Checks a specific property of the struct. + bool isIdentified() const { return identified; } + bool isPacked() const { + assert(!isIdentified() && + "'packed' bit is not part of the key for identified structs"); + return packed; + } + bool isOpaque() const { + assert(isIdentified() && + "'opaque' bit is meaningless on literal structs"); + return opaque; + } + + /// Returns the identifier of a key for identified structs. + StringRef identifier() const { + assert(isIdentified() && + "non-identified struct key cannot have an identifier"); + return name; + } + + /// Returns the list of types contained in the key of a literal struct. + ArrayRef typelist() const { + assert(!isIdentified() && + "identified struct key cannot have a type list"); + return types; + } + + /// Returns the hash value of the key. This combines various flags into a + /// single value: the identified flag sets the first bit, and the packedness + /// flag sets the second bit. Opacity bit is only used for construction and + /// does not participate in uniquing. + llvm::hash_code hashValue() const { + unsigned flags = 0; + if (isIdentified()) { + flags |= 1; + return llvm::hash_combine(flags, identifier()); + } + if (isPacked()) + flags |= 2; + return llvm::hash_combine(flags, typelist()); + } + + /// Compares two keys.
+ bool operator==(const Key &other) const { + if (isIdentified() != other.isIdentified()) + return false; + if (isIdentified()) + return other.identifier().equals(identifier()); + return other.isPacked() == isPacked() && other.typelist() == typelist(); + } + + /// Copies dynamically-sized components of the key into the given allocator. + Key copyIntoAllocator(TypeStorageAllocator &allocator) const { + if (isIdentified()) + return Key(allocator.copyInto(name), opaque); + return Key(allocator.copyInto(types), packed); + } + + private: + ArrayRef types; + StringRef name; + bool identified; + bool packed; + bool opaque; + }; + using KeyTy = Key; + + /// Returns the string identifier of an identified struct. + StringRef identifier() const { + assert(isIdentified() && "requested identifier on a non-identified struct"); + return StringRef(static_cast(keyPtr), keySize()); + } + + /// Returns the list of types (partially) identifying a literal struct. + ArrayRef typelist() const { + // If this triggers, use identifiedStructBody() instead. + assert(!isIdentified() && "requested typelist on an identified struct"); + return ArrayRef(static_cast(keyPtr), keySize()); + } + + /// Returns the list of types contained in an identified struct. + ArrayRef identifiedStructBody() const { + // If this triggers, use typelist() instead. + assert(isIdentified() && + "requested struct body on a non-identified struct"); + return ArrayRef(identifiedBodyArray, identifiedBodySize()); + } + + /// Checks whether the struct is identified. + bool isIdentified() const { + return keySizeAndFlags.testFlag(KeyFlags::Identified); + } + + /// Checks whether the struct is packed (both literal and identified structs). + bool isPacked() const { + return isIdentified() + ?
(identifiedBodySizeAndFlags.testFlag(MutableFlags::Packed)) + : (keySizeAndFlags.testFlag(KeyFlags::Packed)); + } + + /// Checks whether a struct is marked as intentionally opaque (an uninitialized + /// struct is also considered opaque by the user, call isInitialized to check + /// that). + bool isOpaque() const { + return identifiedBodySizeAndFlags.testFlag(MutableFlags::Opaque); + } + + /// Checks whether an identified struct has been explicitly initialized either + /// by setting its body or by marking it as intentionally opaque. + bool isInitialized() const { + return identifiedBodySizeAndFlags.testFlag(MutableFlags::Initialized); + } + + /// Constructs the storage from the given key. This sets up the uniquing key + /// components and optionally the mutable component if the construction key + /// has the relevant information. In the latter case, the struct is considered + /// as initialized and can no longer be mutated. + LLVMStructTypeStorage(const KeyTy &key) { + if (key.isIdentified()) { + StringRef name = key.identifier(); + keyPtr = static_cast(name.data()); + setKeySize(name.size()); + keySizeAndFlags.setFlags(KeyFlags::Identified); + + // If the struct is being constructed directly as opaque, mark it as + // initialized. + if (key.isOpaque()) + identifiedBodySizeAndFlags.setFlags(MutableFlags::Initialized, + MutableFlags::Opaque); + } else { + ArrayRef types = key.typelist(); + keyPtr = static_cast(types.data()); + setKeySize(types.size()); + if (key.isPacked()) + keySizeAndFlags.setFlags(KeyFlags::Packed); + } + } + + /// Hook into the type uniquing infrastructure. + bool operator==(const KeyTy &other) const { return getKey() == other; } + static llvm::hash_code hashKey(const KeyTy &key) { return key.hashValue(); } + static LLVMStructTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMStructTypeStorage(key.copyIntoAllocator(allocator)); + } + + /// Sets the body of an identified struct.
If the struct is already + /// initialized, succeeds only if the body is equal to the current body. Fails + /// if the struct is marked as intentionally opaque. The struct will be marked + /// as initialized as a result of this operation and can no longer be changed. + LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef body, + bool packed) { + if (!isIdentified()) + return failure(); + if (isInitialized()) + return success(!isOpaque() && body == identifiedStructBody() && + packed == isPacked()); + + identifiedBodySizeAndFlags.setFlags(MutableFlags::Initialized); + if (packed) + identifiedBodySizeAndFlags.setFlags(MutableFlags::Packed); + + ArrayRef typesInAllocator = allocator.copyInto(body); + identifiedBodyArray = typesInAllocator.data(); + setIdentifiedBodySize(typesInAllocator.size()); + + return success(); + } + +private: + /// Returns the number of elements in the key. + unsigned keySize() const { return keySizeAndFlags.value(); } + + /// Sets the number of elements in the key. + void setKeySize(unsigned value) { keySizeAndFlags.setValue(value); } + + /// Returns the number of types contained in an identified struct. + unsigned identifiedBodySize() const { + return identifiedBodySizeAndFlags.value(); + } + /// Sets the number of types contained in an identified struct. + void setIdentifiedBodySize(unsigned value) { + identifiedBodySizeAndFlags.setValue(value); + } + + /// Returns the key for the current storage. + Key getKey() const { + if (isIdentified()) + return Key(identifier(), isOpaque()); + return Key(typelist(), isPacked()); + } + + /// Pointer to the first element of the uniquing key. + // Note: cannot use PointerUnion because bump-ptr allocator does not guarantee + // address alignment. + const void *keyPtr = nullptr; + + /// Pointer to the first type contained in an identified struct. + const Type *identifiedBodyArray = nullptr; + + /// Size of the uniquing key combined with identified/literal and + /// packedness bits.
+ IntAndFlags<2, KeyFlags> keySizeAndFlags; + + /// Number of the types contained in an identified struct combined with + /// mutable flags. + IntAndFlags<3, MutableFlags> identifiedBodySizeAndFlags; +}; + +/// Type storage for LLVM dialect function types. These are uniqued using the +/// result type, the list of argument types and the vararg bit. +struct LLVMFunctionTypeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple, bool>; + + /// Constructs the storage from the given components. The list is expected to + /// be allocated in the context. + LLVMFunctionTypeStorage(LLVMTypeNew result, ArrayRef arguments, + bool variadic) + : argumentTypes(arguments) { + returnTypeAndVariadic.setPointerAndInt(result, variadic); + } + + /// Hook into the type uniquing infrastructure. + static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMFunctionTypeStorage(std::get<0>(key), + allocator.copyInto(std::get<1>(key)), + std::get<2>(key)); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + // LLVM doesn't like hashing bools in tuples. + return llvm::hash_combine(std::get<0>(key), std::get<1>(key), + static_cast(std::get<2>(key))); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(getReturnType(), getArgumentTypes(), isVariadic()) == + key; + } + + /// Returns the list of function argument types. + ArrayRef getArgumentTypes() const { return argumentTypes; } + + /// Checks whether the function type is variadic. + bool isVariadic() const { return returnTypeAndVariadic.getInt(); } + + /// Returns the function result type. + LLVMTypeNew getReturnType() const { + return returnTypeAndVariadic.getPointer(); + } + +private: + /// Function result type packed with the variadic bit. + llvm::PointerIntPair returnTypeAndVariadic; + /// Argument types. + ArrayRef argumentTypes; +}; + +/// Storage type for LLVM dialect integer types. These are uniqued by bitwidth.
+struct LLVMIntegerTypeStorage : public ::mlir::TypeStorage { + using KeyTy = unsigned; + + LLVMIntegerTypeStorage(unsigned width) : bitwidth(width) {} + + static LLVMIntegerTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMIntegerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { return key == bitwidth; } + + unsigned bitwidth; +}; + +/// Storage type for LLVM dialect pointer types. These are uniqued by a pair of +/// element type and address space. +struct LLVMPointerTypeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple; + + LLVMPointerTypeStorage(const KeyTy &key) + : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {} + + static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMPointerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(pointeeType, addressSpace) == key; + } + + LLVMTypeNew pointeeType; + unsigned addressSpace; +}; + +/// Common storage used for LLVM dialect types that need an element type and a +/// number: arrays, fixed and scalable vectors. The actual semantics of the +/// type is defined by its kind. 
+struct LLVMTypeAndSizeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple; + + LLVMTypeAndSizeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {} + + static LLVMTypeAndSizeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMTypeAndSizeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(elementType, numElements) == key; + } + + LLVMTypeNew elementType; + unsigned numElements; +}; + +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + +//===----------------------------------------------------------------------===// +// Type definitions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Array type. + +LLVMArrayType LLVMArrayType::get(LLVMTypeNew elementType, + unsigned numElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::ARRAY_TYPE, + elementType, numElements); +} + +LLVMTypeNew LLVMArrayType::getElementType() { return getImpl()->elementType; } + +unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; } + +//===----------------------------------------------------------------------===// +// Function type. 
+ +LLVMFunctionType LLVMFunctionType::get(LLVMTypeNew result, + ArrayRef arguments, + bool isVarArg) { + assert(result && "expected non-null result"); + return Base::get(result.getContext(), LLVMTypeNew::FUNCTION_TYPE, result, + arguments, isVarArg); +} + +LLVMTypeNew LLVMFunctionType::getReturnType() { + return getImpl()->getReturnType(); +} + +unsigned LLVMFunctionType::getNumParams() { + return getImpl()->getArgumentTypes().size(); +} + +LLVMTypeNew LLVMFunctionType::getParamType(unsigned i) { + return getImpl()->getArgumentTypes()[i]; +} + +bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); } + +ArrayRef LLVMFunctionType::getParams() { + return getImpl()->getArgumentTypes(); +} + +//===----------------------------------------------------------------------===// +// Integer type. + +LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) { + return Base::get(ctx, LLVMTypeNew::INTEGER_TYPE, bitwidth); +} + +unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; } + +//===----------------------------------------------------------------------===// +// Pointer type. + +LLVMPointerType LLVMPointerType::get(LLVMTypeNew pointee, + unsigned addressSpace) { + assert(pointee && "expected non-null subtype"); + return Base::get(pointee.getContext(), LLVMTypeNew::POINTER_TYPE, pointee, + addressSpace); +} + +LLVMTypeNew LLVMPointerType::getElementType() { return getImpl()->pointeeType; } + +unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; } + +//===----------------------------------------------------------------------===// +// Struct type. 
+ +LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, + StringRef name) { + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, name, /*opaque=*/false); +} + +LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, + ArrayRef types, bool isPacked) { + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, types, isPacked); +} + +LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, name, /*opaque=*/true); +} + +LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { + assert(isIdentified() && "can only set bodies of identified structs"); + return Base::mutate(types, isPacked); +} + +bool LLVMStructType::isPacked() { return getImpl()->isPacked(); } +bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); } +bool LLVMStructType::isOpaque() { + return getImpl()->isOpaque() || !getImpl()->isInitialized(); +} +StringRef LLVMStructType::identifier() { return getImpl()->identifier(); } +ArrayRef LLVMStructType::getBody() { + return isIdentified() ? getImpl()->identifiedStructBody() + : getImpl()->typelist(); +} + +//===----------------------------------------------------------------------===// +// Vector types. + +LLVMTypeNew LLVMVectorType::getElementType() { + // Both derived classes share the implementation type. + return static_cast(impl)->elementType; +} + +llvm::ElementCount LLVMVectorType::getElementCount() { + // Both derived classes share the implementation type. 
+ return llvm::ElementCount( + static_cast(impl)->numElements, + this->isa()); +} + +LLVMFixedVectorType LLVMFixedVectorType::get(LLVMTypeNew elementType, + unsigned numElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::FIXED_VECTOR_TYPE, + elementType, numElements) + .cast(); +} + +unsigned LLVMFixedVectorType::getNumElements() { + return getImpl()->numElements; +} + +LLVMScalableVectorType LLVMScalableVectorType::get(LLVMTypeNew elementType, + unsigned minNumElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::SCALABLE_VECTOR_TYPE, + elementType, minNumElements) + .cast(); +} + +unsigned LLVMScalableVectorType::getMinNumElements() { + return getImpl()->numElements; +} + +//===----------------------------------------------------------------------===// +// Printing. +//===----------------------------------------------------------------------===// + +static void printTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type, + llvm::SetVector &stack); + +/// Returns the keyword to use for the given type. 
+static StringRef getTypeKeyword(LLVMTypeNew type) { + switch (type.getKind()) { + case LLVMTypeNew::VOID_TYPE: + return "void"; + case LLVMTypeNew::HALF_TYPE: + return "half"; + case LLVMTypeNew::BFLOAT_TYPE: + return "bfloat"; + case LLVMTypeNew::FLOAT_TYPE: + return "float"; + case LLVMTypeNew::DOUBLE_TYPE: + return "double"; + case LLVMTypeNew::FP128_TYPE: + return "fp128"; + case LLVMTypeNew::X86_FP80_TYPE: + return "x86_fp80"; + case LLVMTypeNew::PPC_FP128_TYPE: + return "ppc_fp128"; + case LLVMTypeNew::X86_MMX_TYPE: + return "x86_mmx"; + case LLVMTypeNew::TOKEN_TYPE: + return "token"; + case LLVMTypeNew::LABEL_TYPE: + return "label"; + case LLVMTypeNew::METADATA_TYPE: + return "metadata"; + case LLVMTypeNew::FUNCTION_TYPE: + return "func"; + case LLVMTypeNew::INTEGER_TYPE: + return "i"; + case LLVMTypeNew::POINTER_TYPE: + return "ptr"; + case LLVMTypeNew::FIXED_VECTOR_TYPE: + case LLVMTypeNew::SCALABLE_VECTOR_TYPE: + return "vec"; + case LLVMTypeNew::ARRAY_TYPE: + return "array"; + case LLVMTypeNew::STRUCT_TYPE: + return "struct"; + } + llvm_unreachable("unhandled type kind"); +} + +/// Prints the body of a structure type. Uses `stack` to avoid printing +/// recursive structs indefinitely. +static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type, + llvm::SetVector &stack) { + if (type.isIdentified() && type.isOpaque()) { + os << "opaque"; + return; + } + + if (type.isPacked()) + os << "packed "; + + // Put the current type on stack to avoid infinite recursion. + os << '('; + if (type.isIdentified()) + stack.insert(type.identifier()); + llvm::interleaveComma(type.getBody(), os, [&](Type subtype) { + printTypeImpl(os, subtype.cast(), stack); + }); + if (type.isIdentified()) + stack.pop_back(); + os << ')'; +} + +/// Prints a structure type. Uses `stack` to keep track of the identifiers of +/// the structs being printed. Checks if the identifier of a struct is contained +/// in `stack`, i.e. 
whether a self-reference to a recursive stack is being +/// printed, and only prints the name to avoid infinite recursion. +static void printStructType(llvm::raw_ostream &os, LLVMStructType type, + llvm::SetVector &stack) { + os << "<"; + if (type.isIdentified()) { + os << '"' << type.identifier() << '"'; + // If we are printing a reference to one of the enclosing structs, just + // print the name and stop to avoid infinitely long output. + if (stack.count(type.identifier())) { + os << '>'; + return; + } + os << ", "; + } + + printStructTypeBody(os, type, stack); + os << '>'; +} + +/// Prints a type containing a fixed number of elements. +template +static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type, + llvm::SetVector &stack) { + os << '<' << type.getNumElements() << " x "; + printTypeImpl(os, type.getElementType(), stack); + os << '>'; +} + +/// Prints a function type. +static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType, + llvm::SetVector &stack) { + os << '<'; + printTypeImpl(os, funcType.getReturnType(), stack); + os << " ("; + llvm::interleaveComma(funcType.getParams(), os, + [&os, &stack](LLVMTypeNew subtype) { + printTypeImpl(os, subtype, stack); + }); + if (funcType.isVarArg()) { + if (funcType.getNumParams() != 0) + os << ", "; + os << "..."; + } + os << ")>"; +} + +/// Prints the given LLVM dialect type recursively. This leverages closedness of +/// the LLVM dialect type system to avoid printing the dialect prefix +/// repeatedly. For recursive structures, only prints the name of the structure +/// when printing a self-reference. Note that this does not apply to sibling +/// references. For example, +/// struct<"a", (ptr>)> +/// struct<"c", (ptr>)>>, +/// ptr>)>>)> +/// note that "b" is printed twice. 
+static void printTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type, + llvm::SetVector &stack) { + if (!type) { + os << "<>"; + return; + } + + unsigned kind = type.getKind(); + os << getTypeKeyword(type); + + // Trivial types only consist of their keyword. + if (LLVMTypeNew::FIRST_TRIVIAL_TYPE <= kind && + kind <= LLVMTypeNew::LAST_TRIVIAL_TYPE) + return; + + if (auto intType = type.dyn_cast()) { + os << intType.getBitWidth(); + return; + } + + if (auto ptrType = type.dyn_cast()) { + os << '<'; + printTypeImpl(os, ptrType.getElementType(), stack); + if (ptrType.getAddressSpace() != 0) + os << ", " << ptrType.getAddressSpace(); + os << '>'; + return; + } + + if (auto arrayType = type.dyn_cast()) + return printArrayOrVectorType(os, arrayType, stack); + if (auto vectorType = type.dyn_cast()) + return printArrayOrVectorType(os, vectorType, stack); + + if (auto vectorType = type.dyn_cast()) { + os << "'; + return; + } + + if (auto structType = type.dyn_cast()) + return printStructType(os, structType, stack); + + printFunctionType(os, type.cast(), stack); +} + +void mlir::LLVM::detail::printType(LLVMTypeNew type, + DialectAsmPrinter &printer) { + llvm::SetVector stack; + return printTypeImpl(printer.getStream(), type, stack); +} + +//===----------------------------------------------------------------------===// +// Parsing. +//===----------------------------------------------------------------------===// + +static LLVMTypeNew parseTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack); + +/// Helper to be chained with other parsing functions. +static ParseResult parseTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack, + LLVMTypeNew &result) { + result = parseTypeImpl(parser, stack); + return success(result != nullptr); +} + +/// Parses an LLVM dialect function type. +/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? 
`)>` +static LLVMFunctionType parseFunctionType(DialectAsmParser &parser, + llvm::SetVector &stack) { + LLVMTypeNew returnType; + if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) || + parser.parseLParen()) + return LLVMFunctionType(); + + // Function type without arguments. + if (succeeded(parser.parseOptionalRParen())) { + if (succeeded(parser.parseGreater())) + return LLVMFunctionType::get(returnType, {}, /*isVarArg=*/false); + return LLVMFunctionType(); + } + + // Parse arguments. + SmallVector argTypes; + do { + if (succeeded(parser.parseOptionalEllipsis())) { + if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) + return LLVMFunctionType(); + return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/true); + } + + argTypes.push_back(parseTypeImpl(parser, stack)); + if (!argTypes.back()) + return LLVMFunctionType(); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) + return LLVMFunctionType(); + return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/false); +} + +/// Parses an LLVM dialect pointer type. +/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` +static LLVMPointerType parsePointerType(DialectAsmParser &parser, + llvm::SetVector &stack) { + LLVMTypeNew elementType; + if (parser.parseLess() || parseTypeImpl(parser, stack, elementType)) + return LLVMPointerType(); + + unsigned addressSpace = 0; + if (succeeded(parser.parseOptionalComma()) && + failed(parser.parseInteger(addressSpace))) + return LLVMPointerType(); + if (failed(parser.parseGreater())) + return LLVMPointerType(); + return LLVMPointerType::get(elementType, addressSpace); +} + +/// Parses an LLVM dialect vector type. +/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` +/// Supports both fixed and scalable vectors. 
+static LLVMVectorType parseVectorType(DialectAsmParser &parser, + llvm::SetVector &stack) { + SmallVector dims; + llvm::SMLoc dimPos; + LLVMTypeNew elementType; + if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || + parser.parseDimensionList(dims, /*allowDynamic=*/true) || + parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) + return LLVMVectorType(); + + if (dims.empty() || dims.size() > 2 || (dims.size() == 2 ^ dims[0] == -1) || + (dims.size() == 2 && dims[1] == -1)) { + parser.emitError(dimPos) + << "expected '? x x ' or ' x '"; + return LLVMVectorType(); + } + + bool isScalable = dims.size() == 2; + return isScalable ? static_cast( + LLVMScalableVectorType::get(elementType, dims[1])) + : LLVMFixedVectorType::get(elementType, dims[0]); +} + +/// Parses an LLVM dialect array type. +/// llvm-type ::= `array<` integer `x` llvm-type `>` +static LLVMArrayType parseArrayType(DialectAsmParser &parser, + llvm::SetVector &stack) { + SmallVector dims; + llvm::SMLoc sizePos; + LLVMTypeNew elementType; + if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || + parser.parseDimensionList(dims, /*allowDynamic=*/false) || + parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) + return LLVMArrayType(); + + if (dims.size() != 1) { + parser.emitError(sizePos) << "expected ? x "; + return LLVMArrayType(); + } + + return LLVMArrayType::get(elementType, dims[0]); +} + +/// Attempts to set the body of an identified structure type. Reports a parsing +/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the +/// types printed in the error message look like they did when parsed. 
+static LLVMStructType trySetStructBody(LLVMStructType type, + ArrayRef subtypes, bool isPacked, + DialectAsmParser &parser, + llvm::SMLoc subtypesLoc, + llvm::SetVector &stack) { + if (succeeded(type.setBody(subtypes, isPacked))) + return type; + + std::string currentBody; + llvm::raw_string_ostream currentBodyStream(currentBody); + printStructTypeBody(currentBodyStream, type, stack); + (parser.emitError(subtypesLoc) + << "identified type already used with a different body") + .attachNote() + << "existing body: " << currentBodyStream.str(); + return LLVMStructType(); +} + +/// Parses an LLVM dialect structure type. +/// llvm-type ::= `struct<` (string-literal `,`)? `packed`? +/// `(` llvm-type-list `)` `>` +/// | `struct<` string-literal `>` +/// | `struct<` string-literal `, opaque>` +static LLVMStructType parseStructType(DialectAsmParser &parser, + llvm::SetVector &stack) { + MLIRContext *ctx = parser.getBuilder().getContext(); + + if (failed(parser.parseLess())) + return LLVMStructType(); + + // If we are parsing a self-reference to a recursive struct, i.e. the parsing + // stack already contains a struct with the same identifier, bail out after + // the name. + StringRef name; + bool isIdentified = succeeded(parser.parseOptionalString(&name)); + if (isIdentified) { + if (stack.count(name)) { + if (failed(parser.parseGreater())) + return LLVMStructType(); + auto t = LLVMStructType::getIdentified(ctx, name); + return t; + } + if (failed(parser.parseComma())) + return LLVMStructType(); + } + + // Handle intentionally opaque structs. 
+ llvm::SMLoc kwLoc = parser.getCurrentLocation(); + if (succeeded(parser.parseOptionalKeyword("opaque"))) { + if (!isIdentified) + return parser.emitError(kwLoc, "only identified structs can be opaque"), + LLVMStructType(); + if (failed(parser.parseGreater())) + return LLVMStructType(); + auto type = LLVMStructType::getOpaque(name, ctx); + if (!type.isOpaque()) { + parser.emitError(kwLoc, "redeclaring defined struct as opaque"); + return LLVMStructType(); + } + return type; + } + + // Check for packedness. + bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); + if (failed(parser.parseLParen())) + return LLVMStructType(); + + // Fast pass for structs with zero subtypes. + if (succeeded(parser.parseOptionalRParen())) { + if (failed(parser.parseGreater())) + return LLVMStructType(); + if (!isIdentified) + return LLVMStructType::getLiteral(ctx, {}, isPacked); + auto type = LLVMStructType::getIdentified(ctx, name); + return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack); + } + + // Parse subtypes. For identified structs, put the identifier of the struct on + // the stack to support self-references in the recursive calls. + SmallVector subtypes; + llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); + do { + if (isIdentified) + stack.insert(name); + Type type = parseTypeImpl(parser, stack); + if (!type) + return LLVMStructType(); + subtypes.push_back(type); + if (isIdentified) + stack.pop_back(); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen() || parser.parseGreater()) + return LLVMStructType(); + + // Construct the struct with body. + if (!isIdentified) + return LLVMStructType::getLiteral(ctx, subtypes, isPacked); + auto type = LLVMStructType::getIdentified(ctx, name); + return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack); +} + +/// Parses one of the LLVM dialect types. 
+static LLVMTypeNew parseTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack) { + // Special case for integers (i[1-9][0-9]*) that are literals rather than + // keywords for the parser, so they are not caught by the main dispatch below. + // Try parsing it a built-in integer type instead. + Type maybeIntegerType; + MLIRContext *ctx = parser.getBuilder().getContext(); + llvm::SMLoc keyLoc = parser.getCurrentLocation(); + OptionalParseResult result = parser.parseOptionalType(maybeIntegerType); + if (result.hasValue()) { + if (failed(*result)) + return LLVMTypeNew(); + + if (!maybeIntegerType.isSignlessInteger()) { + parser.emitError(keyLoc) << "unexpected type, expected i* or keyword"; + return LLVMTypeNew(); + } + return LLVMIntegerType::get(ctx, maybeIntegerType.getIntOrFloatBitWidth()); + } + + // Dispatch to concrete functions. + StringRef key; + if (failed(parser.parseKeyword(&key))) + return LLVMTypeNew(); + + return llvm::StringSwitch>(key) + .Case("void", [&] { return LLVMVoidType::get(ctx); }) + .Case("half", [&] { return LLVMHalfType::get(ctx); }) + .Case("bfloat", [&] { return LLVMBFloatType::get(ctx); }) + .Case("float", [&] { return LLVMFloatType::get(ctx); }) + .Case("double", [&] { return LLVMDoubleType::get(ctx); }) + .Case("fp128", [&] { return LLVMFP128Type::get(ctx); }) + .Case("x86_fp80", [&] { return LLVMX86FP80Type::get(ctx); }) + .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) + .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); }) + .Case("token", [&] { return LLVMTokenType::get(ctx); }) + .Case("label", [&] { return LLVMLabelType::get(ctx); }) + .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) + .Case("func", [&] { return parseFunctionType(parser, stack); }) + .Case("ptr", [&] { return parsePointerType(parser, stack); }) + .Case("vec", [&] { return parseVectorType(parser, stack); }) + .Case("array", [&] { return parseArrayType(parser, stack); }) + .Case("struct", [&] { return parseStructType(parser, 
stack); }) + .Default([&] { + parser.emitError(keyLoc) << "unknown LLVM type: " << key; + return LLVMTypeNew(); + })(); +} + +LLVMTypeNew mlir::LLVM::detail::parseType(DialectAsmParser &parser) { + llvm::SetVector stack; + return parseTypeImpl(parser, stack); +} diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -237,6 +237,17 @@ return success(parser.consumeIf(Token::star)); } + /// Parses a quoted string token if present. + ParseResult parseOptionalString(StringRef *string) override { + if (!parser.getToken().is(Token::string)) + return failure(); + + if (string) + *string = parser.getTokenSpelling().drop_front().drop_back(); + parser.consumeToken(); + return success(); + } + /// Returns if the current token corresponds to a keyword. bool isCurrentTokenAKeyword() const { return parser.getToken().is(Token::bare_identifier) || @@ -297,6 +308,10 @@ return parser.parseDimensionListRanked(dimensions, allowDynamic); } + OptionalParseResult parseOptionalType(Type &result) override { + return parser.parseOptionalType(result); + } + private: /// The full symbol specification. 
StringRef fullSpec; diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -0,0 +1,101 @@ +// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s + +func @repeated_struct_name() { + "some.op"() : () -> !llvm2.struct<"a", (ptr>)> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: (ptr>)}} + "some.op"() : () -> !llvm2.struct<"a", (i32)> +} + +// ----- + +func @repeated_struct_name_packed() { + "some.op"() : () -> !llvm2.struct<"a", packed (i32)> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: packed (i32)}} + "some.op"() : () -> !llvm2.struct<"a", (i32)> +} + +// ----- + +func @repeated_struct_opaque() { + "some.op"() : () -> !llvm2.struct<"a", opaque> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: opaque}} + "some.op"() : () -> !llvm2.struct<"a", ()> +} + +// ----- + +func @repeated_struct_opaque_non_empty() { + "some.op"() : () -> !llvm2.struct<"a", opaque> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: opaque}} + "some.op"() : () -> !llvm2.struct<"a", (i32, i32)> +} + +// ----- + +func @repeated_struct_opaque_redefinition() { + "some.op"() : () -> !llvm2.struct<"a", ()> + // expected-error @+1 {{redeclaring defined struct as opaque}} + "some.op"() : () -> !llvm2.struct<"a", opaque> +} + +// ----- + +func @struct_literal_opaque() { + // expected-error @+1 {{only identified structs can be opaque}} + "some.op"() : () -> !llvm2.struct +} + +// ----- + +func @unexpected_type() { + // expected-error @+1 {{unexpected type, expected i* or keyword}} + "some.op"() : () -> !llvm2.f32 +} + +// ----- + +func @unexpected_type() { + // 
expected-error @+1 {{unknown LLVM type}} + "some.op"() : () -> !llvm2.ifoo +} + +// ----- + +func @explicitly_opaque_struct() { + "some.op"() : () -> !llvm2.struct<"a", opaque> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: opaque}} + "some.op"() : () -> !llvm2.struct<"a", ()> +} + +// ----- + +func @dynamic_vector() { + // expected-error @+1 {{expected '? x x ' or ' x '}} + "some.op"() : () -> !llvm2.vec +} + +// ----- + +func @dynamic_scalable_vector() { + // expected-error @+1 {{expected '? x x ' or ' x '}} + "some.op"() : () -> !llvm2.vec +} + +// ----- + +func @unscalable_vector() { + // expected-error @+1 {{expected '? x x ' or ' x '}} + "some.op"() : () -> !llvm2.vec<4 x 4 x i32> +} + +// ----- + +func @dynamic_array() { + // expected-error @+1 {{expected type} + "some.op"() : () -> !llvm.array +} diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -0,0 +1,184 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | mlir-opt -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @primitive +func @primitive() { + // CHECK: !llvm2.void + "some.op"() : () -> !llvm2.void + // CHECK: !llvm2.half + "some.op"() : () -> !llvm2.half + // CHECK: !llvm2.bfloat + "some.op"() : () -> !llvm2.bfloat + // CHECK: !llvm2.float + "some.op"() : () -> !llvm2.float + // CHECK: !llvm2.double + "some.op"() : () -> !llvm2.double + // CHECK: !llvm2.fp128 + "some.op"() : () -> !llvm2.fp128 + // CHECK: !llvm2.x86_fp80 + "some.op"() : () -> !llvm2.x86_fp80 + // CHECK: !llvm2.ppc_fp128 + "some.op"() : () -> !llvm2.ppc_fp128 + // CHECK: !llvm2.x86_mmx + "some.op"() : () -> !llvm2.x86_mmx + // CHECK: !llvm2.token + "some.op"() : () -> !llvm2.token + // CHECK: !llvm2.label + "some.op"() : () -> !llvm2.label + // CHECK: !llvm2.metadata + "some.op"() : () -> !llvm2.metadata + return +} + 
+// CHECK-LABEL: @func +func @func() { + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + return +} + +// CHECK-LABEL: @integer +func @integer() { + // CHECK: !llvm2.i1 + "some.op"() : () -> !llvm2.i1 + // CHECK: !llvm2.i8 + "some.op"() : () -> !llvm2.i8 + // CHECK: !llvm2.i16 + "some.op"() : () -> !llvm2.i16 + // CHECK: !llvm2.i32 + "some.op"() : () -> !llvm2.i32 + // CHECK: !llvm2.i64 + "some.op"() : () -> !llvm2.i64 + // CHECK: !llvm2.i57 + "some.op"() : () -> !llvm2.i57 + // CHECK: !llvm2.i129 + "some.op"() : () -> !llvm2.i129 + return +} + +// CHECK-LABEL: @ptr +func @ptr() { + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr> + "some.op"() : () -> !llvm2.ptr> + // CHECK: !llvm2.ptr>>>> + "some.op"() : () -> !llvm2.ptr>>>> + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr, 9> + "some.op"() : () -> !llvm2.ptr, 9> + return +} + +// CHECK-LABEL: @vec +func @vec() { + // CHECK: !llvm2.vec<4 x i32> + "some.op"() : () -> !llvm2.vec<4 x i32> + // CHECK: !llvm2.vec<4 x float> + "some.op"() : () -> !llvm2.vec<4 x float> + // CHECK: !llvm2.vec + "some.op"() : () -> !llvm2.vec + // CHECK: !llvm2.vec + "some.op"() : () -> !llvm2.vec + // CHECK: !llvm2.vec<4 x ptr> + "some.op"() : () -> !llvm2.vec<4 x ptr> + return +} + +// CHECK-LABEL: @array +func @array() { + // CHECK: !llvm2.array<10 x i32> + "some.op"() : () -> !llvm2.array<10 x i32> + // CHECK: !llvm2.array<8 x float> + "some.op"() : () -> 
!llvm2.array<8 x float> + // CHECK: !llvm2.array<10 x ptr> + "some.op"() : () -> !llvm2.array<10 x ptr> + // CHECK: !llvm2.array<10 x array<4 x float>> + "some.op"() : () -> !llvm2.array<10 x array<4 x float>> + return +} + +// CHECK-LABEL: @literal_struct +func @literal_struct() { + // CHECK: !llvm2.struct<()> + "some.op"() : () -> !llvm2.struct<()> + // CHECK: !llvm2.struct<(i32)> + "some.op"() : () -> !llvm2.struct<(i32)> + // CHECK: !llvm2.struct<(float, i32)> + "some.op"() : () -> !llvm2.struct<(float, i32)> + // CHECK: !llvm2.struct<(struct<(i32)>)> + "some.op"() : () -> !llvm2.struct<(struct<(i32)>)> + // CHECK: !llvm2.struct<(i32, struct<(i32)>, float)> + "some.op"() : () -> !llvm2.struct<(i32, struct<(i32)>, float)> + + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct)> + "some.op"() : () -> !llvm2.struct)> + // CHECK: !llvm2.struct, float)> + "some.op"() : () -> !llvm2.struct, float)> + + // CHECK: !llvm2.struct<(struct)> + "some.op"() : () -> !llvm2.struct<(struct)> + // CHECK: !llvm2.struct)> + "some.op"() : () -> !llvm2.struct)> + return +} + +// CHECK-LABEL: @identified_struct +func @identified_struct() { + // CHECK: !llvm2.struct<"empty", ()> + "some.op"() : () -> !llvm2.struct<"empty", ()> + // CHECK: !llvm2.struct<"opaque", opaque> + "some.op"() : () -> !llvm2.struct<"opaque", opaque> + // CHECK: !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr>)> + "some.op"() : () -> !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr>)> + // CHECK: !llvm2.struct<"self-recursive", (ptr>)> + "some.op"() : () -> !llvm2.struct<"self-recursive", (ptr>)> + // CHECK: !llvm2.struct<"unpacked", (i32)> + "some.op"() : () -> !llvm2.struct<"unpacked", (i32)> + // CHECK: !llvm2.struct<"packed", packed (i32)> + "some.op"() : () -> 
!llvm2.struct<"packed", packed (i32)> + // CHECK: !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> + "some.op"() : () -> !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> + + // CHECK: !llvm2.struct<"mutually-a", (ptr, 3>)>>)> + "some.op"() : () -> !llvm2.struct<"mutually-a", (ptr, 3>)>>)> + // CHECK: !llvm2.struct<"mutually-b", (ptr>)>, 3>)> + "some.op"() : () -> !llvm2.struct<"mutually-b", (ptr>)>, 3>)> + // CHECK: !llvm2.struct<"referring-another", (ptr>)> + "some.op"() : () -> !llvm2.struct<"referring-another", (ptr>)> + + // CHECK: !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> + "some.op"() : () -> !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> + // CHECK: !llvm2.array<10 x struct<"array-of-structs", (i32)>> + "some.op"() : () -> !llvm2.array<10 x struct<"array-of-structs", (i32)>> + // CHECK: !llvm2.ptr> + "some.op"() : () -> !llvm2.ptr> + return +} + diff --git a/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp b/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/LLVMIR/LLVMTypeTestDialect.cpp @@ -0,0 +1,50 @@ +#ifndef DIALECT_LLVMIR_LLVMTYPETESTDIALECT_H_ +#define DIALECT_LLVMIR_LLVMTYPETESTDIALECT_H_ + +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Dialect.h" + +namespace mlir { +namespace LLVM { +namespace { +class LLVMDialectNewTypes : public Dialect { +public: + LLVMDialectNewTypes(MLIRContext *ctx) : Dialect(getDialectNamespace(), ctx) { + // clang-format off + addTypes(); + // clang-format on + } + static StringRef getDialectNamespace() { return "llvm2"; } + + Type parseType(DialectAsmParser &parser) const override { + return detail::parseType(parser); + } + void printType(Type type, DialectAsmPrinter &printer) const override { + detail::printType(type.cast(), printer); + } +}; +} // namespace +} // namespace LLVM +} // namespace mlir + +static mlir::DialectRegistration reg; + +#endif // DIALECT_LLVMIR_LLVMTYPETESTDIALECT_H_