diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -45,6 +45,12 @@ namespace detail { struct LLVMTypeStorage; struct LLVMDialectImpl; + +struct LLVMFunctionTypeStorage; +struct LLVMIntegerTypeStorage; +struct LLVMPointerTypeStorage; +struct LLVMStructTypeStorage; +struct LLVMTypeAndSizeStorage; } // namespace detail class LLVMType : public mlir::Type::TypeBase typeBuilder); }; +class LLVMTypeNew : public Type { +public: + enum Kind { + // Keep non-parametric types contiguous in the enum: the printer treats every kind in [FIRST_TRIVIAL_TYPE, LAST_TRIVIAL_TYPE] as a bare keyword. + VOID_TYPE = FIRST_LLVM_TYPE + 1, + HALF_TYPE, + BFLOAT_TYPE, + FLOAT_TYPE, + DOUBLE_TYPE, + FP128_TYPE, + X86_FP80_TYPE, + PPC_FP128_TYPE, + X86_MMX_TYPE, + LABEL_TYPE, + TOKEN_TYPE, + METADATA_TYPE, + // End of non-parametric types. + FUNCTION_TYPE, + INTEGER_TYPE, + POINTER_TYPE, + FIXED_VECTOR_TYPE, + SCALABLE_VECTOR_TYPE, + ARRAY_TYPE, + STRUCT_TYPE, + FIRST_NEW_LLVM_TYPE = VOID_TYPE, + LAST_NEW_LLVM_TYPE = STRUCT_TYPE, + FIRST_TRIVIAL_TYPE = VOID_TYPE, + LAST_TRIVIAL_TYPE = METADATA_TYPE + }; + + using Type::Type; + + static bool kindof(unsigned kind) { + return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE; // Check each bound; `a <= kind <= b` would compare a bool with b. + } +}; + +template +class LLVMTrivialType + : public Type::TypeBase { +public: + using Type::TypeBase::TypeBase; + using Base = LLVMTrivialType; + + static bool kindof(unsigned kind) { return kind == Kind; } + + static Derived get(MLIRContext *context) { + return Type::TypeBase::get(context, + Kind); + } +}; + +#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \ + class ClassName : public LLVMTrivialType { \ + public: \ + using Base::Base; \ + } + +DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMTypeNew::VOID_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMTypeNew::HALF_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMTypeNew::BFLOAT_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType,
LLVMTypeNew::FLOAT_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMTypeNew::DOUBLE_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMTypeNew::FP128_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMTypeNew::X86_FP80_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMTypeNew::PPC_FP128_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMTypeNew::X86_MMX_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMTypeNew::TOKEN_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMTypeNew::LABEL_TYPE); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMTypeNew::METADATA_TYPE); + +#undef DEFINE_TRIVIAL_LLVM_TYPE + +class LLVMArrayType : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == LLVMTypeNew::ARRAY_TYPE; } + + static LLVMArrayType get(LLVMTypeNew elementType, unsigned numElements); + LLVMTypeNew getElementType(); + unsigned getNumElements(); +}; + +class LLVMFunctionType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FUNCTION_TYPE; + } + + static LLVMFunctionType get(LLVMTypeNew result, + ArrayRef arguments, + bool isVarArg = false); + + LLVMTypeNew getReturnType(); + unsigned getNumParams(); + LLVMTypeNew getParamType(unsigned i); + bool isVarArg(); + ArrayRef getParams(); + ArrayRef params() { return getParams(); } +}; + +class LLVMIntegerType : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::INTEGER_TYPE; + } + + static LLVMIntegerType get(MLIRContext *ctx, unsigned bitwidth); + unsigned getBitWidth(); +}; + +class LLVMPointerType : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::POINTER_TYPE; + } + + static LLVMPointerType get(LLVMTypeNew pointee, unsigned addressSpace = 0); + LLVMTypeNew getElementType(); + unsigned getAddressSpace(); +}; + +class LLVMStructType : public
Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == LLVMTypeNew::STRUCT_TYPE; } + + static LLVMStructType getIdentified(MLIRContext *context, StringRef name); + static LLVMStructType getLiteral(MLIRContext *context, ArrayRef types, + bool isPacked = false); + static LLVMStructType getOpaque(StringRef name, MLIRContext *context); + LogicalResult setBody(ArrayRef types, bool isPacked); + + bool isPacked(); + bool isIdentified(); + bool isOpaque(); + StringRef identifier(); + ArrayRef getBody(); +}; + +class LLVMVectorType : public LLVMTypeNew { +public: + using LLVMTypeNew::LLVMTypeNew; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FIXED_VECTOR_TYPE || + kind == LLVMTypeNew::SCALABLE_VECTOR_TYPE; + } + + LLVMTypeNew getElementType(); + llvm::ElementCount getElementCount(); +}; + +class LLVMFixedVectorType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::FIXED_VECTOR_TYPE; + } + + static LLVMFixedVectorType get(LLVMTypeNew elementType, unsigned numElements); + + unsigned getNumElements(); +}; + +class LLVMScalableVectorType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == LLVMTypeNew::SCALABLE_VECTOR_TYPE; + } + + static LLVMScalableVectorType get(LLVMTypeNew elementType, + unsigned minNumElements); + + unsigned getMinNumElements(); +}; + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" @@ -228,7 +432,68 @@ std::unique_ptr cloneModuleIntoNewContext(llvm::LLVMContext *context, llvm::Module *module); +namespace detail { +LLVMTypeNew parseType(DialectAsmParser &parser); +void printType(LLVMTypeNew type, DialectAsmPrinter &printer); +} // end namespace detail + +class LLVMDialectNewTypes : public Dialect { +public: + explicit LLVMDialectNewTypes(MLIRContext *ctx) : Dialect(getDialectNamespace(), ctx) { + // clang-format off +
addTypes(); + // clang-format on + } + static StringRef getDialectNamespace() { return "llvm2"; } + + Type parseType(DialectAsmParser &parser) const override { + return detail::parseType(parser); + } + void printType(Type type, DialectAsmPrinter &printer) const override { + detail::printType(type.cast(), printer); + } +}; + } // end namespace LLVM } // end namespace mlir +namespace llvm { + +// Hash LLVMTypeNew like a pointer. Every hook takes and returns LLVMTypeNew so that DenseMap can hold the empty and tombstone keys as LLVMTypeNew values. +template <> +struct DenseMapInfo<::mlir::LLVM::LLVMTypeNew> { + static ::mlir::LLVM::LLVMTypeNew getEmptyKey() { + void *pointer = llvm::DenseMapInfo::getEmptyKey(); + return ::mlir::LLVM::LLVMTypeNew( + static_cast<::mlir::LLVM::LLVMTypeNew::ImplType *>(pointer)); + } + static ::mlir::LLVM::LLVMTypeNew getTombstoneKey() { + void *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return ::mlir::LLVM::LLVMTypeNew( + static_cast<::mlir::LLVM::LLVMTypeNew::ImplType *>(pointer)); + } + static unsigned getHashValue(::mlir::LLVM::LLVMTypeNew val) { return mlir::hash_value(val); } + static bool isEqual(::mlir::LLVM::LLVMTypeNew LHS, ::mlir::LLVM::LLVMTypeNew RHS) { return LHS == RHS; } +}; + +} // namespace llvm + #endif // MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_ diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -203,6 +203,9 @@ /// Parse a `=` token if present. virtual ParseResult parseOptionalEqual() = 0; + /// Parse a quoted string token if present. + virtual ParseResult parseOptionalString(StringRef *string) = 0; + /// Parse a given keyword. ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { auto loc = getCurrentLocation(); @@ -323,6 +326,9 @@ return success(); } + /// Parse a type if present. The result holds no value if no type is present; otherwise it indicates whether parsing the type succeeded. + virtual OptionalParseResult parseOptionalType(Type &result) = 0; + /// Parse a 'x' separated dimension list. This populates the dimension list, /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on /// `?` otherwise.
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -43,6 +43,7 @@ registerDialect(); registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/Module.h" #include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/AsmParser/Parser.h" #include "llvm/Bitcode/BitcodeReader.h" @@ -1650,6 +1651,7 @@ namespace mlir { namespace LLVM { + namespace detail { struct LLVMDialectImpl { LLVMDialectImpl() : module("LLVMDialectModule", llvmContext) {} @@ -1730,6 +1732,10 @@ return LLVMType::get(getContext(), type); } +/// Prints `type` to `os`. `stack` holds the identifiers of the identified structs currently being printed, so that a reference to an enclosing struct is printed by name only. +void printLLVMTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type, + llvm::SetVector &stack); + /// Print a type registered to this dialect. void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { auto llvmType = type.dyn_cast(); @@ -1758,6 +1764,749 @@ // LLVMType.
//===----------------------------------------------------------------------===// +namespace mlir { +namespace LLVM { +namespace detail { +/// Storage for LLVM struct types: identified structs are keyed by name, literal structs by their element types. +struct LLVMStructTypeStorage : public ::mlir::TypeStorage { + // Key: flags(identified, opaqueOrPacked), name-or-type-list + // for identified -> opaque, for non-identified -> packed + // Body: initialized, type-list and packed for identified + + constexpr static uint8_t IdentifiedFlag = 0x1; + constexpr static uint8_t InitializedFlag = 0x1; + constexpr static uint8_t PackedFlag = 0x2; + constexpr static uint8_t OpaqueFlag = 0x2; + + using KeyTy = std::tuple; + + LLVMStructTypeStorage(const void *ptr, unsigned size, uint8_t flags) + : keyPtr(ptr), keySize(size), keyFlags(flags), mutableFlags(0) {} + + static bool isIdentified(const KeyTy &key) { + return std::get<2>(key) & IdentifiedFlag; + } + + bool isIdentified() const { return keyFlags & IdentifiedFlag; } + + bool isPacked() const { + return isIdentified() ? (mutableFlags & PackedFlag) + : (keyFlags & PackedFlag); + } + + bool isOpaque() const { return isIdentified() && (keyFlags & OpaqueFlag); } // OpaqueFlag shares its bit with PackedFlag; it means "opaque" only for identified structs. + + bool isInitialized() const { return mutableFlags & InitializedFlag; } + + static StringRef identifier(const KeyTy &key) { + return StringRef(static_cast(std::get<0>(key)), + std::get<1>(key)); + } + StringRef identifier() const { + return StringRef(static_cast(keyPtr), keySize); + } + + static ArrayRef typelist(const KeyTy &key) { + return ArrayRef(static_cast(std::get<0>(key)), + std::get<1>(key)); + } + ArrayRef typelist() const { + return ArrayRef(static_cast(keyPtr), keySize); + } + + bool operator==(const KeyTy &other) const { + uint8_t otherFlags = std::get<2>(other); + if (keyFlags != otherFlags) + return false; + + if (isIdentified(other)) + return identifier(other).equals(identifier()); + return typelist(other) == typelist(); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + if (isIdentified(key)) + return llvm::hash_combine(identifier(key), std::get<2>(key)); + return
llvm::hash_combine(typelist(key), std::get<2>(key)); + } + + static LLVMStructTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + const void *ptr = isIdentified(key) + ? static_cast( + allocator.copyInto(identifier(key)).data()) + : allocator.copyInto(typelist(key)).data(); + + return new (allocator.allocate()) + LLVMStructTypeStorage(ptr, std::get<1>(key), std::get<2>(key)); + } + + LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef body, + bool packed) { + if (!isIdentified()) + return failure(); + if (isInitialized()) + return success(body == identifiedStructBody && packed == isPacked()); + + mutableFlags |= InitializedFlag; + if (packed) + mutableFlags |= PackedFlag; + identifiedStructBody = allocator.copyInto(body); + return success(); + } + + // Note: cannot use PointerUnion because bump-ptr allocator does not guarantee + // address alignment. + const void *keyPtr; + unsigned keySize; + ArrayRef identifiedStructBody; + uint8_t keyFlags : 2; + uint8_t mutableFlags : 2; +}; + +/// Storage for LLVM function types, keyed by (result type, parameter types, variadic flag). +struct LLVMFunctionTypeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple, bool>; + + LLVMFunctionTypeStorage(LLVMTypeNew result, ArrayRef arguments, + bool var) + : resultType(result), argumentTypes(arguments), variadic(var) {} + + static LLVMFunctionTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMFunctionTypeStorage(std::get<0>(key), + allocator.copyInto(std::get<1>(key)), + std::get<2>(key)); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + // LLVM doesn't like hashing bools in tuples.
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key), + std::get<2>(key)); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(resultType, argumentTypes, static_cast<bool>(variadic)) == key; // Compare the stored result type, not the key's own. + } + + LLVMTypeNew resultType; + ArrayRef argumentTypes; + bool variadic : 1; // Unsigned bit-field: a signed 1-bit field typically reads `true` back as -1, which never equals the key's `true`. +}; + +/// Storage for LLVM integer types, keyed by bit width. +struct LLVMIntegerTypeStorage : public ::mlir::TypeStorage { + using KeyTy = unsigned; + + LLVMIntegerTypeStorage(unsigned width) : bitwidth(width) {} + + static LLVMIntegerTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMIntegerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { return key == bitwidth; } + + unsigned bitwidth; +}; + +/// Storage for LLVM pointer types, keyed by (pointee type, address space). +struct LLVMPointerTypeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple; + + LLVMPointerTypeStorage(const KeyTy &key) + : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {} + + static LLVMPointerTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMPointerTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(pointeeType, addressSpace) == key; + } + + LLVMTypeNew pointeeType; + unsigned addressSpace; +}; + +/// Common storage used for LLVM dialect types that need an element type and a +/// number: arrays, fixed and scalable vectors. The actual semantics of the +/// type is defined by its kind.
+struct LLVMTypeAndSizeStorage : public ::mlir::TypeStorage { + using KeyTy = std::tuple; + + LLVMTypeAndSizeStorage(const KeyTy &key) + : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {} + + static LLVMTypeAndSizeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + LLVMTypeAndSizeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return std::make_tuple(elementType, numElements) == key; + } + + LLVMTypeNew elementType; + unsigned numElements; +}; + +} // end namespace detail +} // end namespace LLVM +} // end namespace mlir + +// Array type. + +LLVMArrayType LLVMArrayType::get(LLVMTypeNew elementType, + unsigned numElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::ARRAY_TYPE, + elementType, numElements); +} + +LLVMTypeNew LLVMArrayType::getElementType() { return getImpl()->elementType; } + +unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; } + +// +// Function type. +// + +LLVMFunctionType LLVMFunctionType::get(LLVMTypeNew result, + ArrayRef arguments, + bool isVarArg) { + assert(result && "expected non-null result"); + return Base::get(result.getContext(), LLVMTypeNew::FUNCTION_TYPE, result, + arguments, isVarArg); +} + +LLVMTypeNew LLVMFunctionType::getReturnType() { + return getImpl()->resultType.cast(); +} + +unsigned LLVMFunctionType::getNumParams() { + return getImpl()->argumentTypes.size(); +} + +LLVMTypeNew LLVMFunctionType::getParamType(unsigned i) { + return getImpl()->argumentTypes[i].cast(); +} + +bool LLVMFunctionType::isVarArg() { return getImpl()->variadic; } + +ArrayRef LLVMFunctionType::getParams() { + return getImpl()->argumentTypes; +} + +// +// Integer type.
+// + +LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) { + return Base::get(ctx, LLVMTypeNew::INTEGER_TYPE, bitwidth); +} + +unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; } + +// +// Pointer type. +// + +LLVMPointerType LLVMPointerType::get(LLVMTypeNew pointee, + unsigned addressSpace) { + assert(pointee && "expected non-null subtype"); + return Base::get(pointee.getContext(), LLVMTypeNew::POINTER_TYPE, pointee, + addressSpace); +} + +LLVMTypeNew LLVMPointerType::getElementType() { return getImpl()->pointeeType; } + +unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; } + +// +// Struct type. +// + +LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, + StringRef name) { + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, name.data(), name.size(), + detail::LLVMStructTypeStorage::IdentifiedFlag); +} + +LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, + ArrayRef types, bool isPacked) { + uint8_t flags = 0; + if (isPacked) + flags |= detail::LLVMStructTypeStorage::PackedFlag; + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, types.data(), + types.size(), flags); +} + +LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { + auto flags = detail::LLVMStructTypeStorage::OpaqueFlag | + detail::LLVMStructTypeStorage::IdentifiedFlag; + return Base::get(context, LLVMTypeNew::STRUCT_TYPE, name.data(), name.size(), + flags); +} + +LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { + assert(isIdentified() && "can only set bodies of identified structs"); + return Base::mutate(types, isPacked); +} + +bool LLVMStructType::isPacked() { return getImpl()->isPacked(); } +bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); } +bool LLVMStructType::isOpaque() { + return isIdentified() && (getImpl()->isOpaque() || !getImpl()->isInitialized()); // Literal structs are never opaque. +} +StringRef LLVMStructType::identifier() { return getImpl()->identifier(); } +ArrayRef
LLVMStructType::getBody() { + return isIdentified() ? getImpl()->identifiedStructBody + : getImpl()->typelist(); +} + +static StringRef getTypeKeyword(LLVMTypeNew type) { + switch (type.getKind()) { + case LLVMTypeNew::VOID_TYPE: + return "void"; + case LLVMTypeNew::HALF_TYPE: + return "half"; + case LLVMTypeNew::BFLOAT_TYPE: + return "bfloat"; + case LLVMTypeNew::FLOAT_TYPE: + return "float"; + case LLVMTypeNew::DOUBLE_TYPE: + return "double"; + case LLVMTypeNew::FP128_TYPE: + return "fp128"; + case LLVMTypeNew::X86_FP80_TYPE: + return "x86_fp80"; + case LLVMTypeNew::PPC_FP128_TYPE: + return "ppc_fp128"; + case LLVMTypeNew::X86_MMX_TYPE: + return "x86_mmx"; + case LLVMTypeNew::TOKEN_TYPE: + return "token"; + case LLVMTypeNew::LABEL_TYPE: + return "label"; + case LLVMTypeNew::METADATA_TYPE: + return "metadata"; + case LLVMTypeNew::FUNCTION_TYPE: + return "func"; + case LLVMTypeNew::INTEGER_TYPE: + return "i"; + case LLVMTypeNew::POINTER_TYPE: + return "ptr"; + case LLVMTypeNew::FIXED_VECTOR_TYPE: + case LLVMTypeNew::SCALABLE_VECTOR_TYPE: + return "vec"; + case LLVMTypeNew::ARRAY_TYPE: + return "array"; + case LLVMTypeNew::STRUCT_TYPE: + return "struct"; + } + llvm_unreachable("unhandled type kind"); +} + +static void printLLVMStructTypeBody(llvm::raw_ostream &os, LLVMStructType type, + llvm::SetVector &stack) { + if (type.isPacked()) + os << "packed "; + + os << '('; + if (type.isIdentified()) + stack.insert(type.identifier()); + llvm::interleaveComma(type.getBody(), os, [&](Type subtype) { + printLLVMTypeImpl(os, subtype.cast(), stack); + }); + if (type.isIdentified()) + stack.pop_back(); + os << ')'; +} + +static void printLLVMStructTypeImpl(llvm::raw_ostream &os, LLVMStructType type, + llvm::SetVector &stack) { + os << "<"; + if (type.isIdentified()) { + os << '"' << type.identifier() << '"'; + // If we are printing a reference to one of the enclosing structs, just + // print the name and stop to avoid infinitely long output.
+ if (stack.count(type.identifier())) { + os << '>'; + return; + } + os << ", "; + + if (type.isOpaque()) { + os << "opaque>"; + return; + } + } + + printLLVMStructTypeBody(os, type, stack); + os << '>'; +} + +template +static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type, + llvm::SetVector &stack) { + os << '<' << type.getNumElements() << " x "; + printLLVMTypeImpl(os, type.getElementType(), stack); + os << '>'; +} + +static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType, + llvm::SetVector &stack) { + os << '<'; + printLLVMTypeImpl(os, funcType.getReturnType(), stack); + os << " ("; + llvm::interleaveComma(funcType.getParams(), os, + [&os, &stack](LLVMTypeNew subtype) { + printLLVMTypeImpl(os, subtype, stack); + }); + if (funcType.isVarArg()) { + if (funcType.getNumParams() != 0) + os << ", "; + os << "..."; + } + os << ")>"; +} + +void printLLVMTypeImpl(llvm::raw_ostream &os, LLVMTypeNew type, + llvm::SetVector &stack) { + if (!type) { + os << "<>"; + return; + } + + unsigned kind = type.getKind(); + os << getTypeKeyword(type); + + // Trivial types only consist of their keyword.
+ if (LLVMTypeNew::FIRST_TRIVIAL_TYPE <= kind && + kind <= LLVMTypeNew::LAST_TRIVIAL_TYPE) + return; + + if (auto intType = type.dyn_cast()) { + os << intType.getBitWidth(); + return; + } + + if (auto ptrType = type.dyn_cast()) { + os << '<'; + printLLVMTypeImpl(os, ptrType.getElementType(), stack); + if (ptrType.getAddressSpace() != 0) + os << ", " << ptrType.getAddressSpace(); + os << '>'; + return; + } + + if (auto arrayType = type.dyn_cast()) + return printArrayOrVectorType(os, arrayType, stack); + if (auto vectorType = type.dyn_cast()) + return printArrayOrVectorType(os, vectorType, stack); + + if (auto vectorType = type.dyn_cast()) { + os << "'; + return; + } + + if (auto structType = type.dyn_cast()) + return printLLVMStructTypeImpl(os, structType, stack); + + printFunctionType(os, type.cast(), stack); +} + +void mlir::LLVM::detail::printType(LLVMTypeNew type, + DialectAsmPrinter &printer) { + llvm::SetVector stack; + return printLLVMTypeImpl(printer.getStream(), type, stack); +} + +// +// Parsing. +// + +static LLVMTypeNew parseLLVMTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack); + +static ParseResult parseLLVMTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack, + LLVMTypeNew &result) { + result = parseLLVMTypeImpl(parser, stack); + return success(result != nullptr); +} + +static LLVMFunctionType parseFunctionType(DialectAsmParser &parser, + llvm::SetVector &stack) { + LLVMTypeNew returnType; + if (parser.parseLess() || parseLLVMTypeImpl(parser, stack, returnType) || + parser.parseLParen()) + return LLVMFunctionType(); + + // Function type without arguments.
+ if (succeeded(parser.parseOptionalRParen())) { + if (succeeded(parser.parseGreater())) + return LLVMFunctionType::get(returnType, {}, /*isVarArg=*/false); + return LLVMFunctionType(); + } + + SmallVector argTypes; + do { + if (succeeded(parser.parseOptionalEllipsis())) { + if (parser.parseRParen() || parser.parseGreater()) // Non-optional forms emit a diagnostic on mismatch. + return LLVMFunctionType(); + return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/true); + } + + argTypes.push_back(parseLLVMTypeImpl(parser, stack)); + if (!argTypes.back()) + return LLVMFunctionType(); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen() || parser.parseGreater()) // Non-optional forms emit a diagnostic on mismatch. + return LLVMFunctionType(); + return LLVMFunctionType::get(returnType, argTypes, /*isVarArg=*/false); +} + +static LLVMPointerType parsePointerType(DialectAsmParser &parser, + llvm::SetVector &stack) { + LLVMTypeNew elementType; + if (parser.parseLess() || parseLLVMTypeImpl(parser, stack, elementType)) + return LLVMPointerType(); + + unsigned addressSpace = 0; + if (succeeded(parser.parseOptionalComma()) && + failed(parser.parseInteger(addressSpace))) + return LLVMPointerType(); + if (failed(parser.parseGreater())) + return LLVMPointerType(); + return LLVMPointerType::get(elementType, addressSpace); +} + +static LLVMVectorType parseVectorType(DialectAsmParser &parser, + llvm::SetVector &stack) { + SmallVector dims; + llvm::SMLoc dimPos; + LLVMTypeNew elementType; + if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || + parser.parseDimensionList(dims, /*allowDynamic=*/true) || + parseLLVMTypeImpl(parser, stack, elementType) || parser.parseGreater()) + return LLVMVectorType(); + + if (dims.empty() || dims.size() > 2 || ((dims.size() == 2) ^ (dims[0] == -1)) || dims.back() == -1) { // Only the leading dimension of a scalable vector may be '?'. + parser.emitError(dimPos) + << "expected '? x x ' or ' x '"; + return LLVMVectorType(); + } + + bool isScalable = dims.size() == 2; + return isScalable ?
static_cast( + LLVMScalableVectorType::get(elementType, dims[1])) + : LLVMFixedVectorType::get(elementType, dims[0]); +} + +static LLVMArrayType parseArrayType(DialectAsmParser &parser, + llvm::SetVector &stack) { + SmallVector dims; + llvm::SMLoc sizePos; + LLVMTypeNew elementType; + if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || + parser.parseDimensionList(dims, /*allowDynamic=*/false) || + parseLLVMTypeImpl(parser, stack, elementType) || parser.parseGreater()) + return LLVMArrayType(); + + if (dims.size() != 1) { + parser.emitError(sizePos) << "expected ? x "; + return LLVMArrayType(); + } + + return LLVMArrayType::get(elementType, dims[0]); +} + +static LLVMStructType trySetStructBody(LLVMStructType type, + ArrayRef subtypes, bool isPacked, + DialectAsmParser &parser, + llvm::SMLoc subtypesLoc, + llvm::SetVector &stack) { + if (succeeded(type.setBody(subtypes, isPacked))) + return type; + + std::string currentBody; + llvm::raw_string_ostream currentBodyStream(currentBody); + printLLVMStructTypeBody(currentBodyStream, type, stack); + (parser.emitError(subtypesLoc) + << "identified type already used with a different body") + .attachNote() + << "existing body: " << currentBodyStream.str(); + return LLVMStructType(); +} + +static LLVMStructType parseStructType(DialectAsmParser &parser, + llvm::SetVector &stack) { + MLIRContext *ctx = parser.getBuilder().getContext(); + + if (failed(parser.parseLess())) + return LLVMStructType(); + + StringRef name; + bool isIdentified = succeeded(parser.parseOptionalString(&name)); + if (isIdentified) { + if (stack.count(name)) { + if (failed(parser.parseGreater())) + return LLVMStructType(); + return LLVMStructType::getIdentified(ctx, name); + } + if (failed(parser.parseComma())) + return LLVMStructType(); + } + + llvm::SMLoc kwLoc = parser.getCurrentLocation(); + if (succeeded(parser.parseOptionalKeyword("opaque"))) { + if (!isIdentified) + return parser.emitError(kwLoc, "only identified structs can be opaque"),
+ LLVMStructType(); + if (failed(parser.parseGreater())) + return LLVMStructType(); + return LLVMStructType::getOpaque(name, ctx); + } + + bool isPacked = succeeded(parser.parseOptionalKeyword("packed")); + if (failed(parser.parseLParen())) + return LLVMStructType(); + + if (succeeded(parser.parseOptionalRParen())) { + if (failed(parser.parseGreater())) + return LLVMStructType(); + if (!isIdentified) + return LLVMStructType::getLiteral(ctx, {}, isPacked); + auto type = LLVMStructType::getIdentified(ctx, name); + return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack); + } + + SmallVector subtypes; + llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); + do { + if (isIdentified) + stack.insert(name); + Type type = parseLLVMTypeImpl(parser, stack); + if (isIdentified) + stack.pop_back(); // Pop before the failure check so the stack stays balanced. + if (!type) + return LLVMStructType(); + subtypes.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen() || parser.parseGreater()) + return LLVMStructType(); + + if (!isIdentified) + return LLVMStructType::getLiteral(ctx, subtypes, isPacked); + + auto type = LLVMStructType::getIdentified(ctx, name); + return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack); +} + +static LLVMTypeNew parseLLVMTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack) { + // Special case for integers (i[1-9][0-9]*) that are literals rather than + // keywords for the parser, so they are not caught by the main dispatch below. + // Try parsing it as a built-in integer type instead.
+ Type maybeIntegerType; + MLIRContext *ctx = parser.getBuilder().getContext(); + llvm::SMLoc keyLoc = parser.getCurrentLocation(); + OptionalParseResult result = parser.parseOptionalType(maybeIntegerType); + if (result.hasValue()) { + if (failed(*result)) + return LLVMTypeNew(); + + if (!maybeIntegerType.isSignlessInteger()) { + parser.emitError(keyLoc) << "unexpected type, expected i* or keyword"; + return LLVMTypeNew(); + } + return LLVMIntegerType::get(ctx, maybeIntegerType.getIntOrFloatBitWidth()); + } + + StringRef key; + if (failed(parser.parseKeyword(&key))) + return LLVMTypeNew(); + + return llvm::StringSwitch>(key) + .Case("void", [&] { return LLVMVoidType::get(ctx); }) + .Case("half", [&] { return LLVMHalfType::get(ctx); }) + .Case("bfloat", [&] { return LLVMBFloatType::get(ctx); }) + .Case("float", [&] { return LLVMFloatType::get(ctx); }) + .Case("double", [&] { return LLVMDoubleType::get(ctx); }) + .Case("fp128", [&] { return LLVMFP128Type::get(ctx); }) + .Case("x86_fp80", [&] { return LLVMX86FP80Type::get(ctx); }) + .Case("ppc_fp128", [&] { return LLVMPPCFP128Type::get(ctx); }) + .Case("x86_mmx", [&] { return LLVMX86MMXType::get(ctx); }) + .Case("token", [&] { return LLVMTokenType::get(ctx); }) + .Case("label", [&] { return LLVMLabelType::get(ctx); }) + .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) + .Case("func", [&] { return parseFunctionType(parser, stack); }) + .Case("ptr", [&] { return parsePointerType(parser, stack); }) + .Case("vec", [&] { return parseVectorType(parser, stack); }) + .Case("array", [&] { return parseArrayType(parser, stack); }) + .Case("struct", [&] { return parseStructType(parser, stack); }) + .Default([&] { + parser.emitError(keyLoc) << "unknown LLVM type: " << key; + return LLVMTypeNew(); + })(); +} + +LLVMTypeNew mlir::LLVM::detail::parseType(DialectAsmParser &parser) { + llvm::SetVector stack; + return parseLLVMTypeImpl(parser, stack); +} + +// +// Vector types.
+// + +LLVMTypeNew LLVMVectorType::getElementType() { + // Both derived classes (fixed and scalable vectors) share the element-type-and-size implementation type, so casting `impl` is valid for either. + return static_cast(impl)->elementType; +} + +llvm::ElementCount LLVMVectorType::getElementCount() { + // Both derived classes share the implementation type; the same `numElements` field backs getNumElements() (fixed) and getMinNumElements() (scalable). + return llvm::ElementCount( + static_cast(impl)->numElements, + this->isa()); +} + +LLVMFixedVectorType LLVMFixedVectorType::get(LLVMTypeNew elementType, + unsigned numElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::FIXED_VECTOR_TYPE, + elementType, numElements) + .cast(); +} + +unsigned LLVMFixedVectorType::getNumElements() { + return getImpl()->numElements; +} + +LLVMScalableVectorType LLVMScalableVectorType::get(LLVMTypeNew elementType, + unsigned minNumElements) { + assert(elementType && "expected non-null subtype"); + return Base::get(elementType.getContext(), LLVMTypeNew::SCALABLE_VECTOR_TYPE, + elementType, minNumElements) + .cast(); +} + +unsigned LLVMScalableVectorType::getMinNumElements() { + return getImpl()->numElements; +} + namespace mlir { namespace LLVM { namespace detail { diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -237,6 +237,17 @@ return success(parser.consumeIf(Token::star)); } + /// Parses a quoted string token if present. When `string` is non-null it receives the token spelling with the surrounding quotes dropped; escape sequences are not processed. + ParseResult parseOptionalString(StringRef *string) override { + if (!parser.getToken().is(Token::string)) + return failure(); + + if (string) + *string = parser.getTokenSpelling().drop_front().drop_back(); + parser.consumeToken(); + return success(); + } + /// Returns if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const { return parser.getToken().is(Token::bare_identifier) || @@ -297,6 +308,10 @@ return parser.parseDimensionListRanked(dimensions, allowDynamic); } + OptionalParseResult parseOptionalType(Type &result) override { + return parser.parseOptionalType(result); + } + private: /// The full symbol specification. StringRef fullSpec; diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt --allow-unregistered-dialect -split-input-file -verify-diagnostics %s + +func @repeated_struct_name() { + "some.op"() : () -> !llvm2.struct<"a", (ptr>)> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: (ptr>)}} + "some.op"() : () -> !llvm2.struct<"a", (i32)> +} + +// ----- + +func @repeated_struct_name_packed() { + "some.op"() : () -> !llvm2.struct<"a", packed (i32)> + // expected-error @+2 {{identified type already used with a different body}} + // expected-note @+1 {{existing body: packed (i32)}} + "some.op"() : () -> !llvm2.struct<"a", (i32)> +} + +// ----- + +func @unexpected_type() { + // expected-error @+1 {{unexpected type, expected i* or keyword}} + "some.op"() : () -> !llvm2.f32 +} + +// ----- + +func @unexpected_type() { + // expected-error @+1 {{unknown LLVM type}} + "some.op"() : () -> !llvm2.ifoo +} diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -0,0 +1,184 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | mlir-opt -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @primitive +func @primitive() { + // CHECK: !llvm2.void + "some.op"() : () -> !llvm2.void + // CHECK: !llvm2.half + "some.op"() : () -> !llvm2.half + // CHECK: !llvm2.bfloat + "some.op"() : () -> 
!llvm2.bfloat + // CHECK: !llvm2.float + "some.op"() : () -> !llvm2.float + // CHECK: !llvm2.double + "some.op"() : () -> !llvm2.double + // CHECK: !llvm2.fp128 + "some.op"() : () -> !llvm2.fp128 + // CHECK: !llvm2.x86_fp80 + "some.op"() : () -> !llvm2.x86_fp80 + // CHECK: !llvm2.ppc_fp128 + "some.op"() : () -> !llvm2.ppc_fp128 + // CHECK: !llvm2.x86_mmx + "some.op"() : () -> !llvm2.x86_mmx + // CHECK: !llvm2.token + "some.op"() : () -> !llvm2.token + // CHECK: !llvm2.label + "some.op"() : () -> !llvm2.label + // CHECK: !llvm2.metadata + "some.op"() : () -> !llvm2.metadata + return +} + +// CHECK-LABEL: @func +func @func() { + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + // CHECK: !llvm2.func + "some.op"() : () -> !llvm2.func + return +} + +// CHECK-LABEL: @integer +func @integer() { + // CHECK: !llvm2.i1 + "some.op"() : () -> !llvm2.i1 + // CHECK: !llvm2.i8 + "some.op"() : () -> !llvm2.i8 + // CHECK: !llvm2.i16 + "some.op"() : () -> !llvm2.i16 + // CHECK: !llvm2.i32 + "some.op"() : () -> !llvm2.i32 + // CHECK: !llvm2.i64 + "some.op"() : () -> !llvm2.i64 + // CHECK: !llvm2.i57 + "some.op"() : () -> !llvm2.i57 + // CHECK: !llvm2.i129 + "some.op"() : () -> !llvm2.i129 + return +} + +// CHECK-LABEL: @ptr +func @ptr() { + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr> + "some.op"() : () -> !llvm2.ptr> + // CHECK: !llvm2.ptr>>>> + "some.op"() : () -> !llvm2.ptr>>>> + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr + "some.op"() : () -> !llvm2.ptr + // CHECK: !llvm2.ptr, 9> + "some.op"() : () -> 
!llvm2.ptr, 9> + return +} + +// CHECK-LABEL: @vec +func @vec() { + // CHECK: !llvm2.vec<4 x i32> + "some.op"() : () -> !llvm2.vec<4 x i32> + // CHECK: !llvm2.vec<4 x float> + "some.op"() : () -> !llvm2.vec<4 x float> + // CHECK: !llvm2.vec + "some.op"() : () -> !llvm2.vec + // CHECK: !llvm2.vec + "some.op"() : () -> !llvm2.vec + // CHECK: !llvm2.vec<4 x ptr> + "some.op"() : () -> !llvm2.vec<4 x ptr> + return +} + +// CHECK-LABEL: @array +func @array() { + // CHECK: !llvm2.array<10 x i32> + "some.op"() : () -> !llvm2.array<10 x i32> + // CHECK: !llvm2.array<8 x float> + "some.op"() : () -> !llvm2.array<8 x float> + // CHECK: !llvm2.array<10 x ptr> + "some.op"() : () -> !llvm2.array<10 x ptr> + // CHECK: !llvm2.array<10 x array<4 x float>> + "some.op"() : () -> !llvm2.array<10 x array<4 x float>> + return +} + +// CHECK-LABEL: @literal_struct +func @literal_struct() { + // CHECK: !llvm2.struct<()> + "some.op"() : () -> !llvm2.struct<()> + // CHECK: !llvm2.struct<(i32)> + "some.op"() : () -> !llvm2.struct<(i32)> + // CHECK: !llvm2.struct<(float, i32)> + "some.op"() : () -> !llvm2.struct<(float, i32)> + // CHECK: !llvm2.struct<(struct<(i32)>)> + "some.op"() : () -> !llvm2.struct<(struct<(i32)>)> + // CHECK: !llvm2.struct<(i32, struct<(i32)>, float)> + "some.op"() : () -> !llvm2.struct<(i32, struct<(i32)>, float)> + + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct + "some.op"() : () -> !llvm2.struct + // CHECK: !llvm2.struct)> + "some.op"() : () -> !llvm2.struct)> + // CHECK: !llvm2.struct, float)> + "some.op"() : () -> !llvm2.struct, float)> + + // CHECK: !llvm2.struct<(struct)> + "some.op"() : () -> !llvm2.struct<(struct)> + // CHECK: !llvm2.struct)> + "some.op"() : () -> !llvm2.struct)> + return +} + +// CHECK-LABEL: @identified_struct +func @identified_struct() { + // CHECK: !llvm2.struct<"empty", ()> + 
"some.op"() : () -> !llvm2.struct<"empty", ()> + // CHECK: !llvm2.struct<"opaque", opaque> + "some.op"() : () -> !llvm2.struct<"opaque", opaque> + // CHECK: !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr>)> + "some.op"() : () -> !llvm2.struct<"long", (i32, struct<(i32, i1)>, float, ptr>)> + // CHECK: !llvm2.struct<"self-recursive", (ptr>)> + "some.op"() : () -> !llvm2.struct<"self-recursive", (ptr>)> + // CHECK: !llvm2.struct<"unpacked", (i32)> + "some.op"() : () -> !llvm2.struct<"unpacked", (i32)> + // CHECK: !llvm2.struct<"packed", packed (i32)> + "some.op"() : () -> !llvm2.struct<"packed", packed (i32)> + // CHECK: !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> + "some.op"() : () -> !llvm2.struct<"name with spaces and !^$@$#", packed (i32)> + + // CHECK: !llvm2.struct<"mutually-a", (ptr, 3>)>>)> + "some.op"() : () -> !llvm2.struct<"mutually-a", (ptr, 3>)>>)> + // CHECK: !llvm2.struct<"mutually-b", (ptr>)>, 3>)> + "some.op"() : () -> !llvm2.struct<"mutually-b", (ptr>)>, 3>)> + // CHECK: !llvm2.struct<"referring-another", (ptr>)> + "some.op"() : () -> !llvm2.struct<"referring-another", (ptr>)> + + // CHECK: !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> + "some.op"() : () -> !llvm2.struct<"struct-of-arrays", (array<10 x i32>)> + // CHECK: !llvm2.array<10 x struct<"array-of-structs", (i32)>> + "some.op"() : () -> !llvm2.array<10 x struct<"array-of-structs", (i32)>> + // CHECK: !llvm2.ptr> + "some.op"() : () -> !llvm2.ptr> + return +} +