diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -26,6 +26,8 @@ class DialectAsmPrinter; namespace LLVM { +class LLVMDialect; + namespace detail { struct LLVMFunctionTypeStorage; struct LLVMIntegerTypeStorage; @@ -34,6 +36,12 @@ struct LLVMTypeAndSizeStorage; } // namespace detail +class LLVMBFloatType; +class LLVMHalfType; +class LLVMFloatType; +class LLVMDoubleType; +class LLVMIntegerType; + //===----------------------------------------------------------------------===// // LLVMTypeNew. //===----------------------------------------------------------------------===// @@ -53,7 +61,7 @@ /// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR /// context, have an immutable identifier (for most types except identified /// structs, the entire type is the identifier) and are thread-safe. -class LLVMTypeNew : public Type { +class LLVMTypeNew : public Type::TypeBase { public: enum Kind { // Keep non-parametric types contiguous in the enum. @@ -84,7 +92,7 @@ }; /// Inherit base constructors. - using Type::Type; + using Base::Base; /// Support for PointerLikeTypeTraits. using Type::getAsOpaquePointer; @@ -96,6 +104,150 @@ static bool kindof(unsigned kind) { return FIRST_NEW_LLVM_TYPE <= kind && kind <= LAST_NEW_LLVM_TYPE; } + + LLVMDialect &getDialect(); + + /// Floating-point type utilities. + bool isBFloatTy() { return isa(); } + bool isHalfTy() { return isa(); } + bool isFloatTy() { return isa(); } + bool isDoubleTy() { return isa(); } + bool isFloatingPointTy() { + return isa() || isa() || + isa() || isa(); + } + + /// Array type utilities. + LLVMTypeNew getArrayElementType(); + unsigned getArrayNumElements(); + bool isArrayTy(); + + /// Integer type utilities. + bool isIntegerTy() { return isa(); } + bool isIntegerTy(unsigned bitwidth); + unsigned getIntegerBitWidth(); + + /// Vector type utilities. + LLVMTypeNew getVectorElementType(); + unsigned getVectorNumElements(); + llvm::ElementCount getVectorElementCount(); + bool isVectorTy(); + + /// Function type utilities. + LLVMTypeNew getFunctionParamType(unsigned argIdx); + unsigned getFunctionNumParams(); + LLVMTypeNew getFunctionResultType(); + bool isFunctionTy(); + bool isFunctionVarArg(); + + /// Pointer type utilities. + LLVMTypeNew getPointerTo(unsigned addrSpace = 0); + LLVMTypeNew getPointerElementTy(); + bool isPointerTy(); + static bool isValidPointerElementType(LLVMTypeNew type); + + /// Struct type utilities. + LLVMTypeNew getStructElementType(unsigned i); + unsigned getStructNumElements(); + bool isStructTy(); + + /// Utilities used to generate floating point types. + static LLVMTypeNew getDoubleTy(LLVMDialect *dialect); + static LLVMTypeNew getFloatTy(LLVMDialect *dialect); + static LLVMTypeNew getBFloatTy(LLVMDialect *dialect); + static LLVMTypeNew getHalfTy(LLVMDialect *dialect); + static LLVMTypeNew getFP128Ty(LLVMDialect *dialect); + static LLVMTypeNew getX86_FP80Ty(LLVMDialect *dialect); + + /// Utilities used to generate integer types. + static LLVMTypeNew getIntNTy(LLVMDialect *dialect, unsigned numBits); + static LLVMTypeNew getInt1Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/1); + } + static LLVMTypeNew getInt8Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/8); + } + static LLVMTypeNew getInt8PtrTy(LLVMDialect *dialect) { + return getInt8Ty(dialect).getPointerTo(); + } + static LLVMTypeNew getInt16Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/16); + } + static LLVMTypeNew getInt32Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/32); + } + static LLVMTypeNew getInt64Ty(LLVMDialect *dialect) { + return getIntNTy(dialect, /*numBits=*/64); + } + + /// Utilities used to generate other miscellaneous types. + static LLVMTypeNew getArrayTy(LLVMTypeNew elementType, uint64_t numElements); + static LLVMTypeNew getFunctionTy(LLVMTypeNew result, + ArrayRef params, bool isVarArg); + static LLVMTypeNew getFunctionTy(LLVMTypeNew result, bool isVarArg) { + return getFunctionTy(result, llvm::None, isVarArg); + } + static LLVMTypeNew getStructTy(LLVMDialect *dialect, + ArrayRef elements, + bool isPacked = false); + static LLVMTypeNew getStructTy(LLVMDialect *dialect, bool isPacked = false) { + return getStructTy(dialect, llvm::None, isPacked); + } + template + static typename std::enable_if::value, + LLVMTypeNew>::type + getStructTy(LLVMTypeNew elt1, Args... elts) { + SmallVector fields({elt1, elts...}); + return getStructTy(&elt1.getDialect(), fields); + } + static LLVMTypeNew getVectorTy(LLVMTypeNew elementType, unsigned numElements); + + /// Void type utilities. + static LLVMTypeNew getVoidTy(LLVMDialect *dialect); + bool isVoidTy(); + + // Creation and setting of LLVM's identified struct types + static LLVMTypeNew createStructTy(LLVMDialect *dialect, + ArrayRef elements, + Optional name, + bool isPacked = false); + + static LLVMTypeNew createStructTy(LLVMDialect *dialect, + Optional name) { + return createStructTy(dialect, llvm::None, name); + } + + static LLVMTypeNew createStructTy(ArrayRef elements, + Optional name, + bool isPacked = false) { + assert(!elements.empty() && + "This method may not be invoked with an empty list"); + LLVMTypeNew ele0 = elements.front(); + return createStructTy(&ele0.getDialect(), elements, name, isPacked); + } + + template + static + typename std::enable_if_t::value, + LLVMTypeNew> + createStructTy(StringRef name, LLVMTypeNew elt1, Args... elts) { + SmallVector fields({elt1, elts...}); + Optional opt_name(name); + return createStructTy(&elt1.getDialect(), fields, opt_name); + } + + static LLVMTypeNew setStructTyBody(LLVMTypeNew structType, + ArrayRef elements, + bool isPacked = false); + + template + static + typename std::enable_if_t::value, + LLVMTypeNew> + setStructTyBody(LLVMTypeNew structType, LLVMTypeNew elt1, Args... elts) { + SmallVector fields({elt1, elts...}); + return setStructTyBody(structType, fields); + } }; //===----------------------------------------------------------------------===// @@ -323,6 +475,9 @@ /// Checks if a struct is opaque. bool isOpaque(); + /// Checks if a struct is initialized. + bool isInitialized(); + /// Returns the name of an identified struct. StringRef getName(); @@ -337,10 +492,11 @@ /// LLVM dialect vector type, represents a sequence of elements that can be /// processed as one, typically in SIMD context. This is a base class for fixed /// and scalable vectors. -class LLVMVectorType : public LLVMTypeNew { +class LLVMVectorType : public Type::TypeBase { public: /// Inherit base constructor. - using LLVMTypeNew::LLVMTypeNew; + using Base::Base; /// Support for isa/cast. static bool kindof(unsigned kind) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -13,6 +13,7 @@ #include "TypeDetail.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/TypeSupport.h" @@ -22,6 +23,213 @@ using namespace mlir; using namespace mlir::LLVM; +//===----------------------------------------------------------------------===// +// LLVMTypeNew. +//===----------------------------------------------------------------------===// + +// TODO: when these types are registered with the LLVMDialect, this method +// should be removed and the regular Type::getDialect should just work. +LLVMDialect &LLVMTypeNew::getDialect() { + return *getContext()->getRegisteredDialect(); +} + +//----------------------------------------------------------------------------// +// Integer type utilities. + +bool LLVMTypeNew::isIntegerTy(unsigned bitwidth) { + if (auto intType = dyn_cast()) + return intType.getBitWidth() == bitwidth; + return false; +} + +unsigned LLVMTypeNew::getIntegerBitWidth() { + return cast().getBitWidth(); +} + +LLVMTypeNew LLVMTypeNew::getArrayElementType() { + return cast().getElementType(); +} + +//----------------------------------------------------------------------------// +// Array type utilities. + +unsigned LLVMTypeNew::getArrayNumElements() { + return cast().getNumElements(); +} + +bool LLVMTypeNew::isArrayTy() { return isa(); } + +//----------------------------------------------------------------------------// +// Vector type utilities. + +LLVMTypeNew LLVMTypeNew::getVectorElementType() { + return cast().getElementType(); +} + +unsigned LLVMTypeNew::getVectorNumElements() { + return cast().getNumElements(); +} +llvm::ElementCount LLVMTypeNew::getVectorElementCount() { + return cast().getElementCount(); +} + +bool LLVMTypeNew::isVectorTy() { return isa(); } + +//----------------------------------------------------------------------------// +// Function type utilities. + +LLVMTypeNew LLVMTypeNew::getFunctionParamType(unsigned argIdx) { + return cast().getParamType(argIdx); +} + +unsigned LLVMTypeNew::getFunctionNumParams() { + return cast().getNumParams(); +} + +LLVMTypeNew LLVMTypeNew::getFunctionResultType() { + return cast().getReturnType(); +} + +bool LLVMTypeNew::isFunctionTy() { return isa(); } + +bool LLVMTypeNew::isFunctionVarArg() { + return cast().isVarArg(); +} + +//----------------------------------------------------------------------------// +// Pointer type utilities. + +LLVMTypeNew LLVMTypeNew::getPointerTo(unsigned addrSpace) { + return LLVMPointerType::get(*this, addrSpace); +} + +LLVMTypeNew LLVMTypeNew::getPointerElementTy() { + return cast().getElementType(); +} + +bool LLVMTypeNew::isPointerTy() { return isa(); } + +bool LLVMTypeNew::isValidPointerElementType(LLVMTypeNew type) { + return !type.isa() && !type.isa() && + !type.isa() && !type.isa(); +} + +//----------------------------------------------------------------------------// +// Struct type utilities. + +LLVMTypeNew LLVMTypeNew::getStructElementType(unsigned i) { + return cast().getBody()[i]; +} + +unsigned LLVMTypeNew::getStructNumElements() { + return cast().getBody().size(); +} + +bool LLVMTypeNew::isStructTy() { return isa(); } + +//----------------------------------------------------------------------------// +// Utilities used to generate floating point types. + +LLVMTypeNew LLVMTypeNew::getDoubleTy(LLVMDialect *dialect) { + return LLVMDoubleType::get(dialect->getContext()); +} + +LLVMTypeNew LLVMTypeNew::getFloatTy(LLVMDialect *dialect) { + return LLVMFloatType::get(dialect->getContext()); +} + +LLVMTypeNew LLVMTypeNew::getBFloatTy(LLVMDialect *dialect) { + return LLVMBFloatType::get(dialect->getContext()); +} + +LLVMTypeNew LLVMTypeNew::getHalfTy(LLVMDialect *dialect) { + return LLVMHalfType::get(dialect->getContext()); +} + +LLVMTypeNew LLVMTypeNew::getFP128Ty(LLVMDialect *dialect) { + return LLVMFP128Type::get(dialect->getContext()); +} + +LLVMTypeNew LLVMTypeNew::getX86_FP80Ty(LLVMDialect *dialect) { + return LLVMX86FP80Type::get(dialect->getContext()); +} + +//----------------------------------------------------------------------------// +// Utilities used to generate integer types. + +LLVMTypeNew LLVMTypeNew::getIntNTy(LLVMDialect *dialect, unsigned numBits) { + return LLVMIntegerType::get(dialect->getContext(), numBits); +} + +//----------------------------------------------------------------------------// +// Utilities used to generate other miscellaneous types. + +LLVMTypeNew LLVMTypeNew::getArrayTy(LLVMTypeNew elementType, + uint64_t numElements) { + return LLVMArrayType::get(elementType, numElements); +} + +LLVMTypeNew LLVMTypeNew::getFunctionTy(LLVMTypeNew result, + ArrayRef params, + bool isVarArg) { + return LLVMFunctionType::get(result, params, isVarArg); +} + +LLVMTypeNew LLVMTypeNew::getStructTy(LLVMDialect *dialect, + ArrayRef elements, + bool isPacked) { + return LLVMStructType::getLiteral(dialect->getContext(), elements, isPacked); +} + +LLVMTypeNew LLVMTypeNew::getVectorTy(LLVMTypeNew elementType, + unsigned numElements) { + return LLVMFixedVectorType::get(elementType, numElements); +} + +//----------------------------------------------------------------------------// +// Void type utilities. + +LLVMTypeNew LLVMTypeNew::getVoidTy(LLVMDialect *dialect) { + return LLVMVoidType::get(dialect->getContext()); +} + +bool LLVMTypeNew::isVoidTy() { return isa(); } + +//----------------------------------------------------------------------------// +// Creation and setting of LLVM's identified struct types + +LLVMTypeNew LLVMTypeNew::createStructTy(LLVMDialect *dialect, + ArrayRef elements, + Optional name, + bool isPacked) { + assert(name.hasValue() && + "identified structs with no identifier not supported"); + StringRef stringNameBase = name.getValueOr(""); + std::string stringName = stringNameBase.str(); + unsigned counter = 0; + do { + auto type = + LLVMStructType::getIdentified(dialect->getContext(), stringName); + if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { + counter += 1; + stringName = + (Twine(stringNameBase) + "." + std::to_string(counter)).str(); + continue; + } + return type; + } while (true); +} + +LLVMTypeNew LLVMTypeNew::setStructTyBody(LLVMTypeNew structType, + ArrayRef elements, + bool isPacked) { + LogicalResult couldSet = + structType.cast().setBody(elements, isPacked); + assert(succeeded(couldSet) && "failed to set the body"); + (void)couldSet; + return structType; +} + //===----------------------------------------------------------------------===// // Array type. @@ -117,6 +325,7 @@ bool LLVMStructType::isOpaque() { return getImpl()->isOpaque() || !getImpl()->isInitialized(); } +bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } ArrayRef LLVMStructType::getBody() { return isIdentified() ? getImpl()->getIdentifiedStructBody()