diff --git a/mlir/docs/Tutorials/Toy/Ch-6.md b/mlir/docs/Tutorials/Toy/Ch-6.md --- a/mlir/docs/Tutorials/Toy/Ch-6.md +++ b/mlir/docs/Tutorials/Toy/Ch-6.md @@ -37,10 +37,11 @@ // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` - auto llvmI32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect); - auto llvmI8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); - auto llvmFnType = LLVM::LLVMType::getFunctionTy(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); + auto llvmI32Ty = LLVM::LLVMIntegerType::get(context, 32); + auto llvmI8PtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); + auto llvmFnType = LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, + /*isVarArg=*/true); // Insert the printf function into the body of the parent module. PatternRewriter::InsertionGuard insertGuard(rewriter); diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -34,7 +34,6 @@ namespace LLVM { class LLVMDialect; -class LLVMType; class LLVMPointerType; } // namespace LLVM @@ -71,8 +70,8 @@ /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. - LLVM::LLVMType convertFunctionSignature(FunctionType funcTy, bool isVariadic, - SignatureConversion &result); + Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, + SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is @@ -118,14 +117,14 @@ /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. - LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); + Type convertFunctionTypeCWrapper(FunctionType type); /// Returns the data layout to use during and after conversion. const llvm::DataLayout &getDataLayout() { return options.dataLayout; } /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. - LLVM::LLVMType getIndexType(); + Type getIndexType(); /// Gets the bitwidth of the index type when converted to LLVM. unsigned getIndexTypeBitwidth() { return options.indexBitwidth; } @@ -185,8 +184,8 @@ /// - `!llvm.i64`, `!llvm.i64` (sizes), /// - `!llvm.i64`, `!llvm.i64` (strides). /// These types can be recomposed to a memref descriptor struct. - SmallVector - getMemRefDescriptorFields(MemRefType type, bool unpackAggregates); + SmallVector getMemRefDescriptorFields(MemRefType type, + bool unpackAggregates); /// Convert an unranked memref type into a list of non-aggregate LLVM IR types /// that will form the unranked memref descriptor. In particular, this list @@ -197,7 +196,7 @@ /// !llvm.i64 (rank) /// !llvm<"i8*"> (type-erased pointer). /// These types can be recomposed to a unranked memref descriptor struct. - SmallVector getUnrankedMemRefDescriptorFields(); + SmallVector getUnrankedMemRefDescriptorFields(); // Convert an unranked memref type to an LLVM type that captures the // runtime rank and a pointer to the static ranked memref desc @@ -417,31 +416,30 @@ /// Builds IR extracting the allocated pointer from the descriptor. static Value allocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + Value memRefDescPtr, Type elemPtrPtrType); /// Builds IR inserting the allocated pointer into the descriptor. static void setAllocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType, + Value memRefDescPtr, Type elemPtrPtrType, Value allocatedPtr); /// Builds IR extracting the aligned pointer from the descriptor. static Value alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType); + Type elemPtrPtrType); /// Builds IR inserting the aligned pointer into the descriptor. static void setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, + Value memRefDescPtr, Type elemPtrPtrType, Value alignedPtr); /// Builds IR extracting the offset from the descriptor. static Value offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType); + Type elemPtrPtrType); /// Builds IR inserting the offset into the descriptor. static void setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType, Value offset); + Type elemPtrPtrType, Value offset); /// Builds IR extracting the pointer to the first element of the size array. static Value sizeBasePtr(OpBuilder &builder, Location loc, @@ -490,17 +488,17 @@ /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. - LLVM::LLVMType getIndexType() const; + Type getIndexType() const; /// Gets the MLIR type wrapping the LLVM integer type whose bit width /// corresponds to that of a LLVM pointer type. - LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const; + Type getIntPtrType(unsigned addressSpace = 0) const; /// Gets the MLIR type wrapping the LLVM void type. - LLVM::LLVMType getVoidType() const; + Type getVoidType() const; /// Get the MLIR type wrapping the LLVM i8* type. - LLVM::LLVMType getVoidPtrType() const; + Type getVoidPtrType() const; /// Create an LLVM dialect operation defining the given index constant. Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -49,8 +49,8 @@ // LLVM dialect type. def LLVM_Type : DialectType()">, - "LLVM dialect type">; + CPred<"::mlir::LLVM::isCompatibleType($_self)">, + "LLVM dialect-compatible type">; // Type constraint accepting LLVM integer types. def LLVM_AnyInteger : Type< @@ -223,9 +223,9 @@ // or result in the operation. def LLVM_IntrPatterns { string operand = - [{convertType(opInst.getOperand($0).getType().cast())}]; + [{convertType(opInst.getOperand($0).getType())}]; string result = - [{convertType(opInst.getResult($0).getType().cast())}]; + [{convertType(opInst.getResult($0).getType())}]; string structResult = [{convertType(opInst.getResult(0).getType().cast() .getBody()[$0])}]; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -49,9 +49,8 @@ OpBuilderDAG<(ins "Type":$resultType, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attributes), [{ - auto llvmType = resultType.dyn_cast(); (void)llvmType; - assert(llvmType && "result must be an LLVM type"); - assert(llvmType.isa() && + assert(isCompatibleType(resultType) && "result must be an LLVM type"); + assert(resultType.isa() && "for zero-result operands, only 'void' is accepted as result type"); build($_builder, $_state, operands, attributes); }]>; @@ -443,7 +442,7 @@ OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attributes), [{ - LLVMType resultType = func.getType().getReturnType(); + Type resultType = func.getType().getReturnType(); if (!resultType.isa()) $_state.addTypes(resultType); $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func)); @@ -756,23 +755,21 @@ }]; let builders = [ - OpBuilderDAG<(ins "LLVMType":$resType, "StringRef":$name, - CArg<"ArrayRef", "{}">:$attrs), - [{ - $_state.addAttribute("global_name",$_builder.getSymbolRefAttr(name)); - $_state.addAttributes(attrs); - $_state.addTypes(resType);}]>, OpBuilderDAG<(ins "GlobalOp":$global, CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, LLVM::LLVMPointerType::get(global.getType(), global.addr_space()), - global.sym_name(), attrs);}]>, + global.sym_name()); + $_state.addAttributes(attrs); + }]>, OpBuilderDAG<(ins "LLVMFuncOp":$func, CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, - LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]> + LLVM::LLVMPointerType::get(func.getType()), func.getName()); + $_state.addAttributes(attrs); + }]> ]; let extraClassDeclaration = [{ @@ -883,15 +880,15 @@ let regions = (region AnyRegion:$initializer); let builders = [ - OpBuilderDAG<(ins "LLVMType":$type, "bool":$isConstant, "Linkage":$linkage, + OpBuilderDAG<(ins "Type":$type, "bool":$isConstant, "Linkage":$linkage, "StringRef":$name, "Attribute":$value, CArg<"unsigned", "0">:$addrSpace, CArg<"ArrayRef", "{}">:$attrs)> ]; let extraClassDeclaration = [{ /// Return the LLVM type of the global. - LLVMType getType() { - return type().cast(); + Type getType() { + return type(); } /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { @@ -957,7 +954,7 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "StringRef":$name, "LLVMType":$type, + OpBuilderDAG<(ins "StringRef":$name, "Type":$type, CArg<"Linkage", "Linkage::External">:$linkage, CArg<"ArrayRef", "{}">:$attrs, CArg<"ArrayRef", "{}">:$argAttrs)> diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -45,52 +45,13 @@ class LLVMX86FP80Type; class LLVMIntegerType; -//===----------------------------------------------------------------------===// -// LLVMType. -//===----------------------------------------------------------------------===// - -/// Base class for LLVM dialect types. -/// -/// The LLVM dialect in MLIR fully reflects the LLVM IR type system, prodiving a -/// separate MLIR type for each LLVM IR type. All types are represented as -/// separate subclasses and are compatible with the isa/cast infrastructure. -/// -/// The LLVM dialect type system is closed: parametric types can only refer to -/// other LLVM dialect types. This is consistent with LLVM IR and enables a more -/// concise pretty-printing format. -/// -/// Similarly to other MLIR types, LLVM dialect types are owned by the MLIR -/// context, have an immutable identifier (for most types except identified -/// structs, the entire type is the identifier) and are thread-safe. -/// -/// This class is a thin common base class for different types available in the -/// LLVM dialect. It intentionally does not provide the API similar to -/// llvm::Type to avoid confusion and highlight potentially expensive operations -/// (e.g., type creation in MLIR takes a lock, so it's better to cache types). -class LLVMType : public Type { -public: - /// Inherit base constructors. - using Type::Type; - - /// Support for PointerLikeTypeTraits. - using Type::getAsOpaquePointer; - static LLVMType getFromOpaquePointer(const void *ptr) { - return LLVMType(static_cast(const_cast(ptr))); - } - - /// Support for isa/cast. - static bool classof(Type type); - - LLVMDialect &getDialect(); -}; - //===----------------------------------------------------------------------===// // Trivial types. //===----------------------------------------------------------------------===// // Batch-define trivial types. #define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \ - class ClassName : public Type::TypeBase { \ + class ClassName : public Type::TypeBase { \ public: \ using Base::Base; \ } @@ -117,30 +78,30 @@ /// LLVM dialect array type. It is an aggregate type representing consecutive /// elements in memory, parameterized by the number of elements and the element /// type. -class LLVMArrayType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; /// Checks if the given type can be used inside an array type. - static bool isValidElementType(LLVMType type); + static bool isValidElementType(Type type); /// Gets or creates an instance of LLVM dialect array type containing /// `numElements` of `elementType`, in the same context as `elementType`. - static LLVMArrayType get(LLVMType elementType, unsigned numElements); - static LLVMArrayType getChecked(Location loc, LLVMType elementType, + static LLVMArrayType get(Type elementType, unsigned numElements); + static LLVMArrayType getChecked(Location loc, Type elementType, unsigned numElements); /// Returns the element type of the array. - LLVMType getElementType(); + Type getElementType(); /// Returns the number of elements in the array type. unsigned getNumElements(); /// Verifies that the type about to be constructed is well-formed. static LogicalResult verifyConstructionInvariants(Location loc, - LLVMType elementType, + Type elementType, unsigned numElements); }; @@ -152,46 +113,46 @@ /// which can have multiple), a list of parameter types and can optionally be /// variadic. class LLVMFunctionType - : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; /// Checks if the given type can be used an argument in a function type. - static bool isValidArgumentType(LLVMType type); + static bool isValidArgumentType(Type type); /// Checks if the given type can be used as a result in a function type. - static bool isValidResultType(LLVMType type); + static bool isValidResultType(Type type); /// Returns whether the function is variadic. bool isVarArg(); /// Gets or creates an instance of LLVM dialect function in the same context /// as the `result` type. - static LLVMFunctionType get(LLVMType result, ArrayRef arguments, + static LLVMFunctionType get(Type result, ArrayRef arguments, bool isVarArg = false); - static LLVMFunctionType getChecked(Location loc, LLVMType result, - ArrayRef arguments, + static LLVMFunctionType getChecked(Location loc, Type result, + ArrayRef arguments, bool isVarArg = false); /// Returns the result type of the function. - LLVMType getReturnType(); + Type getReturnType(); /// Returns the number of arguments to the function. unsigned getNumParams(); /// Returns `i`-th argument of the function. Asserts on out-of-bounds. - LLVMType getParamType(unsigned i); + Type getParamType(unsigned i); /// Returns a list of argument types of the function. - ArrayRef getParams(); - ArrayRef params() { return getParams(); } + ArrayRef getParams(); + ArrayRef params() { return getParams(); } /// Verifies that the type about to be constructed is well-formed. - static LogicalResult - verifyConstructionInvariants(Location loc, LLVMType result, - ArrayRef arguments, bool); + static LogicalResult verifyConstructionInvariants(Location loc, Type result, + ArrayRef arguments, + bool); }; //===----------------------------------------------------------------------===// @@ -199,7 +160,7 @@ //===----------------------------------------------------------------------===// /// LLVM dialect signless integer type parameterized by bitwidth. -class LLVMIntegerType : public Type::TypeBase { public: /// Inherit base constructor. @@ -225,31 +186,31 @@ /// LLVM dialect pointer type. This type typically represents a reference to an /// object in memory. It is parameterized by the element type and the address /// space. -class LLVMPointerType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; /// Checks if the given type can have a pointer type pointing to it. - static bool isValidElementType(LLVMType type); + static bool isValidElementType(Type type); /// Gets or creates an instance of LLVM dialect pointer type pointing to an /// object of `pointee` type in the given address space. The pointer type is /// created in the same context as `pointee`. - static LLVMPointerType get(LLVMType pointee, unsigned addressSpace = 0); - static LLVMPointerType getChecked(Location loc, LLVMType pointee, + static LLVMPointerType get(Type pointee, unsigned addressSpace = 0); + static LLVMPointerType getChecked(Location loc, Type pointee, unsigned addressSpace = 0); /// Returns the pointed-to type. - LLVMType getElementType(); + Type getElementType(); /// Returns the address space of the pointer. unsigned getAddressSpace(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location loc, - LLVMType pointee, unsigned); + static LogicalResult verifyConstructionInvariants(Location loc, Type pointee, + unsigned); }; //===----------------------------------------------------------------------===// @@ -280,14 +241,14 @@ /// /// Note that the packedness of the struct takes place in uniquing of literal /// structs, but does not in uniquing of identified structs. -class LLVMStructType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; /// Checks if the given type can be contained in a structure type. - static bool isValidElementType(LLVMType type); + static bool isValidElementType(Type type); /// Gets or creates an identified struct with the given name in the provided /// context. Note that unlike llvm::StructType::create, this function will @@ -302,16 +263,14 @@ /// the struct by appending a `.` followed by a number to the name. Renaming /// happens even if the existing struct has the same body. static LLVMStructType getNewIdentified(MLIRContext *context, StringRef name, - ArrayRef elements, + ArrayRef elements, bool isPacked = false); /// Gets or creates a literal struct with the given body in the provided /// context. - static LLVMStructType getLiteral(MLIRContext *context, - ArrayRef types, + static LLVMStructType getLiteral(MLIRContext *context, ArrayRef types, bool isPacked = false); - static LLVMStructType getLiteralChecked(Location loc, - ArrayRef types, + static LLVMStructType getLiteralChecked(Location loc, ArrayRef types, bool isPacked = false); /// Gets or creates an intentionally-opaque identified struct. Such a struct @@ -329,7 +288,7 @@ /// different thread modified the struct after it was created. Most callers /// are likely to assert this always succeeds, but it is possible to implement /// a local renaming scheme based on the result of this call. - LogicalResult setBody(ArrayRef types, bool isPacked); + LogicalResult setBody(ArrayRef types, bool isPacked); /// Checks if a struct is packed. bool isPacked(); @@ -347,12 +306,12 @@ StringRef getName(); /// Returns the list of element types contained in a non-opaque struct. - ArrayRef getBody(); + ArrayRef getBody(); /// Verifies that the type about to be constructed is well-formed. static LogicalResult verifyConstructionInvariants(Location, StringRef, bool); - static LogicalResult - verifyConstructionInvariants(Location loc, ArrayRef types, bool); + static LogicalResult verifyConstructionInvariants(Location loc, + ArrayRef types, bool); }; //===----------------------------------------------------------------------===// @@ -362,26 +321,26 @@ /// LLVM dialect vector type, represents a sequence of elements that can be /// processed as one, typically in SIMD context. This is a base class for fixed /// and scalable vectors. -class LLVMVectorType : public LLVMType { +class LLVMVectorType : public Type { public: /// Inherit base constructor. - using LLVMType::LLVMType; + using Type::Type; /// Support type casting functionality. static bool classof(Type type); /// Checks if the given type can be used in a vector type. - static bool isValidElementType(LLVMType type); + static bool isValidElementType(Type type); /// Returns the element type of the vector. - LLVMType getElementType(); + Type getElementType(); /// Returns the number of elements in the vector. llvm::ElementCount getElementCount(); /// Verifies that the type about to be constructed is well-formed. static LogicalResult verifyConstructionInvariants(Location loc, - LLVMType elementType, + Type elementType, unsigned numElements); }; @@ -401,8 +360,8 @@ /// Gets or creates a fixed vector type containing `numElements` of /// `elementType` in the same context as `elementType`. - static LLVMFixedVectorType get(LLVMType elementType, unsigned numElements); - static LLVMFixedVectorType getChecked(Location loc, LLVMType elementType, + static LLVMFixedVectorType get(Type elementType, unsigned numElements); + static LLVMFixedVectorType getChecked(Location loc, Type elementType, unsigned numElements); /// Returns the number of elements in the fixed vector. @@ -426,9 +385,8 @@ /// Gets or creates a scalable vector type containing a non-zero multiple of /// `minNumElements` of `elementType` in the same context as `elementType`. - static LLVMScalableVectorType get(LLVMType elementType, - unsigned minNumElements); - static LLVMScalableVectorType getChecked(Location loc, LLVMType elementType, + static LLVMScalableVectorType get(Type elementType, unsigned minNumElements); + static LLVMScalableVectorType getChecked(Location loc, Type elementType, unsigned minNumElements); /// Returns the scaling factor of the number of elements in the vector. The @@ -443,10 +401,10 @@ namespace detail { /// Parses an LLVM dialect type. -LLVMType parseType(DialectAsmParser &parser); +Type parseType(DialectAsmParser &parser); /// Prints an LLVM Dialect type. -void printType(LLVMType type, DialectAsmPrinter &printer); +void printType(Type type, DialectAsmPrinter &printer); } // namespace detail //===----------------------------------------------------------------------===// @@ -454,7 +412,30 @@ //===----------------------------------------------------------------------===// /// Returns `true` if the given type is compatible with the LLVM dialect. -inline bool isCompatibleType(Type type) { return type.isa(); } +inline bool isCompatibleType(Type type) { + // clang-format off + return type.isa< + LLVMArrayType, + LLVMBFloatType, + LLVMDoubleType, + LLVMFP128Type, + LLVMFloatType, + LLVMFunctionType, + LLVMHalfType, + LLVMIntegerType, + LLVMLabelType, + LLVMMetadataType, + LLVMPPCFP128Type, + LLVMPointerType, + LLVMStructType, + LLVMTokenType, + LLVMVectorType, + LLVMVoidType, + LLVMX86FP80Type, + LLVMX86MMXType + >(); + // clang-format on +} inline bool isCompatibleFloatingPointType(Type type) { return type.isa -struct DenseMapInfo { - static mlir::LLVM::LLVMType getEmptyKey() { - void *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlir::LLVM::LLVMType( - static_cast(pointer)); - } - static mlir::LLVM::LLVMType getTombstoneKey() { - void *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlir::LLVM::LLVMType( - static_cast(pointer)); - } - static unsigned getHashValue(mlir::LLVM::LLVMType val) { - return mlir::hash_value(val); - } - static bool isEqual(mlir::LLVM::LLVMType lhs, mlir::LLVM::LLVMType rhs) { - return lhs == rhs; - } -}; - -// LLVMType behaves like a pointer similarly to mlir::Type. -template <> -struct PointerLikeTypeTraits { - static inline void *getAsVoidPointer(mlir::LLVM::LLVMType type) { - return const_cast(type.getAsOpaquePointer()); - } - static inline mlir::LLVM::LLVMType getFromVoidPointer(void *ptr) { - return mlir::LLVM::LLVMType::getFromOpaquePointer(ptr); - } - static constexpr int NumLowBitsAvailable = - PointerLikeTypeTraits::NumLowBitsAvailable; -}; - -} // namespace llvm - #endif // MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -149,11 +149,11 @@ LLVM_Type:$slc)>{ string llvmBuilder = [{ $res = createIntrinsicCall(builder, - llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc, + llvm::Intrinsic::amdgcn_buffer_load, {$rsrc, $vindex, $offset, $glc, $slc}, {$_resultType}); }]; let parser = [{ return parseROCDLMubufLoadOp(parser, result); }]; - let printer = [{ + let printer = [{ Operation *op = this->getOperation(); p << op->getName() << " " << op->getOperands() << " : " << op->getResultTypes(); @@ -169,7 +169,7 @@ LLVM_Type:$glc, LLVM_Type:$slc)>{ string llvmBuilder = [{ - auto vdataType = convertType(op.vdata().getType().cast()); + auto vdataType = convertType(op.vdata().getType()); createIntrinsicCall(builder, llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, $offset, $glc, $slc}, {vdataType}); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -104,7 +104,7 @@ llvm::IRBuilder<> &builder); /// Converts the type from MLIR LLVM dialect to LLVM. - llvm::Type *convertType(LLVMType type); + llvm::Type *convertType(Type type); static std::unique_ptr prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, diff --git a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/TypeTranslation.h @@ -24,12 +24,11 @@ namespace mlir { +class Type; class MLIRContext; namespace LLVM { -class LLVMType; - namespace detail { class TypeToLLVMIRTranslatorImpl; class TypeFromLLVMIRTranslatorImpl; @@ -47,11 +46,10 @@ /// that this will perform type conversion and store its results for future /// uses. // TODO: this should be removed when MLIR has proper data layout. - unsigned getPreferredAlignment(LLVM::LLVMType type, - const llvm::DataLayout &layout); + unsigned getPreferredAlignment(Type type, const llvm::DataLayout &layout); /// Translates the given MLIR LLVM dialect type to LLVM IR. - llvm::Type *translateType(LLVM::LLVMType type); + llvm::Type *translateType(Type type); private: /// Private implementation. @@ -67,7 +65,7 @@ ~TypeFromLLVMIRTranslator(); /// Translates the given LLVM IR type to the MLIR LLVM dialect. - LLVM::LLVMType translateType(llvm::Type *type); + Type translateType(llvm::Type *type); private: /// Private implementation. diff --git a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp --- a/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.cpp @@ -36,15 +36,14 @@ OneToOneConvertToLLVMPattern; // Extract an LLVM IR type from the LLVM IR dialect type. -static LLVM::LLVMType unwrap(Type type) { +static Type unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); - auto wrappedLLVMType = type.dyn_cast(); - if (!wrappedLLVMType) + if (!LLVM::isCompatibleType(type)) emitError(UnknownLoc::get(mlirContext), "conversion resulted in a non-LLVM type"); - return wrappedLLVMType; + return type; } static Optional diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -149,7 +149,7 @@ } // Auxiliary coroutine resume intrinsic wrapper. - static LLVM::LLVMType resumeFunctionType(MLIRContext *ctx) { + static Type resumeFunctionType(MLIRContext *ctx) { auto voidTy = LLVM::LLVMVoidType::get(ctx); auto i8Ptr = opaquePointerType(ctx); return LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}, false); @@ -203,13 +203,11 @@ static constexpr const char *kCoroFree = "llvm.coro.free"; static constexpr const char *kCoroResume = "llvm.coro.resume"; -/// Adds an LLVM function declaration to a module. static void addLLVMFuncDecl(ModuleOp module, ImplicitLocOpBuilder &builder, - StringRef name, LLVM::LLVMType ret, - ArrayRef params) { + StringRef name, Type ret, ArrayRef params) { if (module.lookupSymbol(name)) return; - LLVM::LLVMType type = LLVM::LLVMFunctionType::get(ret, params); + Type type = LLVM::LLVMFunctionType::get(ret, params); builder.create(name, type); } @@ -386,8 +384,7 @@ // http://nondot.org/sabre/LLVMNotes/SizeOf-OffsetOf-VariableSizedStructs.txt auto sizeOf = [&](ValueType valueType) -> Value { auto storedType = converter.convertType(valueType.getValueType()); - auto storagePtrType = - LLVM::LLVMPointerType::get(storedType.cast()); + auto storagePtrType = LLVM::LLVMPointerType::get(storedType); // %Size = getelementptr %T* null, int 1 // %SizeI = ptrtoint %T* %Size to i32 @@ -949,8 +946,7 @@ // Cast from i8* to the pointer pointer to LLVM type. auto llvmValueType = getTypeConverter()->convertType(valueType); auto castedStorage = rewriter.create( - loc, LLVM::LLVMPointerType::get(llvmValueType.cast()), - storage.getResult(0)); + loc, LLVM::LLVMPointerType::get(llvmValueType), storage.getResult(0)); // Load from the async value storage. auto loaded = rewriter.create(loc, castedStorage.getResult()); @@ -1015,9 +1011,7 @@ // Cast storage pointer to the yielded value type. auto castedStorage = rewriter.create( - loc, - LLVM::LLVMPointerType::get( - yieldValue.getType().cast()), + loc, LLVM::LLVMPointerType::get(yieldValue.getType()), storage.getResult(0)); // Store the yielded value into the async value storage. diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -52,8 +52,8 @@ class FunctionCallBuilder { public: - FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType, - ArrayRef argumentTypes) + FunctionCallBuilder(StringRef functionName, Type returnType, + ArrayRef argumentTypes) : functionName(functionName), functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {} LLVM::CallOp create(Location loc, OpBuilder &builder, @@ -73,15 +73,14 @@ protected: MLIRContext *context = &this->getTypeConverter()->getContext(); - LLVM::LLVMType llvmVoidType = LLVM::LLVMVoidType::get(context); - LLVM::LLVMType llvmPointerType = + Type llvmVoidType = LLVM::LLVMVoidType::get(context); + Type llvmPointerType = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); - LLVM::LLVMType llvmPointerPointerType = - LLVM::LLVMPointerType::get(llvmPointerType); - LLVM::LLVMType llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8); - LLVM::LLVMType llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32); - LLVM::LLVMType llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64); - LLVM::LLVMType llvmIntPtrType = LLVM::LLVMIntegerType::get( + Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType); + Type llvmInt8Type = LLVM::LLVMIntegerType::get(context, 8); + Type llvmInt32Type = LLVM::LLVMIntegerType::get(context, 32); + Type llvmInt64Type = LLVM::LLVMIntegerType::get(context, 64); + Type llvmIntPtrType = LLVM::LLVMIntegerType::get( context, this->getTypeConverter()->getPointerBitwidth(0)); FunctionCallBuilder moduleLoadCallBuilder = { @@ -321,7 +320,7 @@ static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter) { if (!llvm::all_of(operands, [](Value value) { - return value.getType().isa(); + return LLVM::isCompatibleType(value.getType()); })) return rewriter.notifyMatchFailure( op, "Cannot convert if operands aren't of LLVM type."); @@ -511,10 +510,10 @@ loc, launchOp.getOperands().take_back(numKernelOperands), operands.take_back(numKernelOperands), builder); auto numArguments = arguments.size(); - SmallVector argumentTypes; + SmallVector argumentTypes; argumentTypes.reserve(numArguments); for (auto argument : arguments) - argumentTypes.push_back(argument.getType().cast()); + argumentTypes.push_back(argument.getType()); auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), argumentTypes); auto one = builder.create(loc, llvmInt32Type, diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -38,7 +38,7 @@ uint64_t numElements = type.getNumElements(); auto elementType = typeConverter->convertType(type.getElementType()) - .template cast(); + .template cast(); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); @@ -126,7 +126,7 @@ // memory space and does not support `alloca`s with addrspace(5). auto ptrType = LLVM::LLVMPointerType::get( typeConverter->convertType(type.getElementType()) - .template cast(), + .template cast(), AllocaAddrSpace); Value numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -40,7 +40,6 @@ matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; - using LLVM::LLVMType; static_assert( std::is_base_of, SourceOp>::value, @@ -54,9 +53,8 @@ for (Value operand : operands) castedOperands.push_back(maybeCast(operand, rewriter)); - LLVMType resultType = - castedOperands.front().getType().cast(); - LLVMType funcType = getFunctionType(resultType, castedOperands); + Type resultType = castedOperands.front().getType(); + Type funcType = getFunctionType(resultType, castedOperands); StringRef funcName = getFunctionName( funcType.cast().getReturnType()); if (funcName.empty()) @@ -80,7 +78,7 @@ private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { - LLVM::LLVMType type = operand.getType().cast(); + Type type = operand.getType(); if (!type.isa()) return operand; @@ -89,17 +87,15 @@ operand); } - LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType, - ArrayRef operands) const { - using LLVM::LLVMType; - SmallVector operandTypes; + Type getFunctionType(Type resultType, ArrayRef operands) const { + SmallVector operandTypes; for (Value operand : operands) { - operandTypes.push_back(operand.getType().cast()); + operandTypes.push_back(operand.getType()); } return LLVM::LLVMFunctionType::get(resultType, operandTypes); } - StringRef getFunctionName(LLVM::LLVMType type) const { + StringRef getFunctionName(Type type) const { if (type.isa()) return f32Func; if (type.isa()) @@ -107,8 +103,7 @@ return ""; } - LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, - LLVM::LLVMType funcType, + LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, Operation *op) const { using LLVM::LLVMFuncOp; diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -56,7 +56,7 @@ Location loc = op->getLoc(); gpu::ShuffleOpAdaptor adaptor(operands); - auto valueTy = adaptor.value().getType().cast(); + auto valueTy = adaptor.value().getType(); auto int32Type = LLVM::LLVMIntegerType::get(rewriter.getContext(), 32); auto predTy = LLVM::LLVMIntegerType::get(rewriter.getContext(), 1); auto resultTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -65,7 +65,7 @@ llvmInt64Type = LLVM::LLVMIntegerType::get(&getContext(), 64); } - LLVM::LLVMType getMemRefType(uint32_t rank, LLVM::LLVMType elemenType) { + Type getMemRefType(uint32_t rank, Type elemenType) { // According to the MLIR doc memref argument is converted into a // pointer-to-struct argument of type: // template @@ -89,10 +89,10 @@ llvmArrayRankElementSizeType, llvmArrayRankElementSizeType}); } - LLVM::LLVMType getVoidType() { return llvmVoidType; } - LLVM::LLVMType getPointerType() { return llvmPointerType; } - LLVM::LLVMType getInt32Type() { return llvmInt32Type; } - LLVM::LLVMType getInt64Type() { return llvmInt64Type; } + Type getVoidType() { return llvmVoidType; } + Type getPointerType() { return llvmPointerType; } + Type getInt32Type() { return llvmInt32Type; } + Type getInt64Type() { return llvmInt64Type; } /// Creates an LLVM global for the given `name`. Value createEntryPointNameConstant(StringRef name, Location loc, @@ -128,10 +128,10 @@ /// Deduces a rank and element type from the given 'ptrToMemRefDescriptor`. LogicalResult deduceMemRefRankAndType(Value ptrToMemRefDescriptor, - uint32_t &rank, LLVM::LLVMType &type); + uint32_t &rank, Type &type); /// Returns a string representation from the given `type`. - StringRef stringifyType(LLVM::LLVMType type) { + StringRef stringifyType(Type type) { if (type.isa()) return "Float"; if (type.isa()) @@ -152,11 +152,11 @@ void runOnOperation() override; private: - LLVM::LLVMType llvmFloatType; - LLVM::LLVMType llvmVoidType; - LLVM::LLVMType llvmPointerType; - LLVM::LLVMType llvmInt32Type; - LLVM::LLVMType llvmInt64Type; + Type llvmFloatType; + Type llvmVoidType; + Type llvmPointerType; + Type llvmInt32Type; + Type llvmInt64Type; // TODO: Use an associative array to support multiple vulkan launch calls. std::pair spirvAttributes; @@ -230,7 +230,7 @@ auto ptrToMemRefDescriptor = en.value(); uint32_t rank = 0; - LLVM::LLVMType type; + Type type; if (failed(deduceMemRefRankAndType(ptrToMemRefDescriptor, rank, type))) { cInterfaceVulkanLaunchCallOp.emitError() << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); @@ -258,7 +258,7 @@ } LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( - Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { + Value ptrToMemRefDescriptor, uint32_t &rank, Type &type) { auto llvmPtrDescriptorTy = ptrToMemRefDescriptor.getType().dyn_cast(); if (!llvmPtrDescriptorTy) @@ -324,12 +324,11 @@ } for (unsigned i = 1; i <= 3; i++) { - SmallVector types{ - LLVM::LLVMFloatType::get(&getContext()), - LLVM::LLVMIntegerType::get(&getContext(), 32), - LLVM::LLVMIntegerType::get(&getContext(), 16), - LLVM::LLVMIntegerType::get(&getContext(), 8), - LLVM::LLVMHalfType::get(&getContext())}; + SmallVector types{LLVM::LLVMFloatType::get(&getContext()), + LLVM::LLVMIntegerType::get(&getContext(), 32), + LLVM::LLVMIntegerType::get(&getContext(), 16), + LLVM::LLVMIntegerType::get(&getContext(), 8), + LLVM::LLVMHalfType::get(&getContext())}; for (auto type : types) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -67,11 +67,9 @@ using llvm_return = OperationBuilder; template -static LLVMType getPtrToElementType(T containerType, - LLVMTypeConverter &lowering) { - return lowering.convertType(containerType.getElementType()) - .template cast() - .getPointerTo(); +static Type getPtrToElementType(T containerType, LLVMTypeConverter &lowering) { + return LLVMPointerType::get( + lowering.convertType(containerType.getElementType())); } /// Convert the given range descriptor type to the LLVMIR dialect. @@ -84,8 +82,7 @@ /// }; static Type convertRangeType(RangeType t, LLVMTypeConverter &converter) { auto *context = t.getContext(); - auto int64Ty = converter.convertType(IntegerType::get(context, 64)) - .cast(); + auto int64Ty = converter.convertType(IntegerType::get(context, 64)); return LLVMStructType::getLiteral(context, {int64Ty, int64Ty, int64Ty}); } @@ -206,8 +203,7 @@ BaseViewConversionHelper baseDesc(adaptor.view()); auto memRefType = sliceOp.getBaseViewType(); - auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)) - .cast(); + auto int64Ty = typeConverter->convertType(rewriter.getIntegerType(64)); BaseViewConversionHelper desc( typeConverter->convertType(sliceOp.getShapedType())); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -184,7 +184,7 @@ kernelFunc = rewriter.create( rewriter.getUnknownLoc(), newKernelFuncName, LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context), - ArrayRef())); + ArrayRef())); rewriter.setInsertionPoint(launchOp); } @@ -234,7 +234,7 @@ OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); dstGlobal = rewriter.create( - loc, dstGlobalType.cast(), + loc, dstGlobalType, /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute()); rewriter.setInsertionPoint(launchOp); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -65,7 +65,7 @@ } /// Returns the bit width of LLVMType integer or vector. -static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { +static unsigned getLLVMTypeBitWidth(Type type) { auto vectorType = type.dyn_cast(); return (vectorType ? vectorType.getElementType() : type) .cast() @@ -115,16 +115,15 @@ /// - `BitFieldSExtract` /// - `BitFieldUExtract` /// Truncates or extends the value. If the bitwidth of the value is the same as -/// `dstType` bitwidth, the value remains unchanged. -static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, +/// `llvmType` bitwidth, the value remains unchanged. +static Value optionallyTruncateOrExtend(Location loc, Value value, + Type llvmType, PatternRewriter &rewriter) { auto srcType = value.getType(); - auto llvmType = dstType.cast(); unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); - unsigned valueBitWidth = - srcType.isa() - ? getLLVMTypeBitWidth(srcType.cast()) - : getBitWidth(srcType); + unsigned valueBitWidth = LLVM::isCompatibleType(srcType) + ? getLLVMTypeBitWidth(srcType) + : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) return rewriter.create(loc, llvmType, value); @@ -193,7 +192,7 @@ auto elementsVector = llvm::to_vector<8>( llvm::map_range(type.getElementTypes(), [&](Type elementType) { - return converter.convertType(elementType).cast(); + return converter.convertType(elementType); })); return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, /*isPacked=*/false); @@ -204,7 +203,7 @@ LLVMTypeConverter &converter) { auto elementsVector = llvm::to_vector<8>( llvm::map_range(type.getElementTypes(), [&](Type elementType) { - return converter.convertType(elementType).cast(); + return converter.convertType(elementType); })); return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector, /*isPacked=*/true); @@ -255,8 +254,7 @@ !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride)) return llvm::None; - auto llvmElementType = - converter.convertType(elementType).cast(); + auto llvmElementType = converter.convertType(elementType); unsigned numElements = type.getNumElements(); return LLVM::LLVMArrayType::get(llvmElementType, numElements); } @@ -265,8 +263,7 @@ /// modelled at the moment. static Type convertPointerType(spirv::PointerType type, TypeConverter &converter) { - auto pointeeType = - converter.convertType(type.getPointeeType()).cast(); + auto pointeeType = converter.convertType(type.getPointeeType()); return LLVM::LLVMPointerType::get(pointeeType); } @@ -277,8 +274,7 @@ TypeConverter &converter) { if (type.getArrayStride() != 0) return llvm::None; - auto elementType = - converter.convertType(type.getElementType()).cast(); + auto elementType = converter.convertType(type.getElementType()); return LLVM::LLVMArrayType::get(elementType, 0); } @@ -336,8 +332,7 @@ auto dstType = typeConverter.convertType(op.pointer().getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp( - op, dstType.cast(), op.variable()); + rewriter.replaceOpWithNewOp(op, dstType, op.variable()); return success(); } }; @@ -667,7 +662,7 @@ // int32_t values[]; // optional values // }; auto llvmI32Type = LLVM::LLVMIntegerType::get(context, 32); - SmallVector fields; + SmallVector fields; fields.push_back(llvmI32Type); ArrayAttr values = op.values(); if (!values.empty()) { @@ -757,8 +752,7 @@ ? LLVM::Linkage::Private : LLVM::Linkage::External; rewriter.replaceOpWithNewOp( - op, dstType.cast(), isConstant, linkage, op.sym_name(), - Attribute()); + op, dstType, isConstant, linkage, op.sym_name(), Attribute()); return success(); } }; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -41,15 +41,15 @@ #define PASS_NAME "convert-std-to-llvm" // Extract an LLVM IR type from the LLVM IR dialect type. -static LLVM::LLVMType unwrap(Type type) { +static Type unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); - auto wrappedLLVMType = type.dyn_cast(); - if (!wrappedLLVMType) + if (!LLVM::isCompatibleType(type)) emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type"); - return wrappedLLVMType; + "conversion resulted in a non-LLVM type ") + << type; + return type; } /// Callback to convert function argument types. It converts a MemRef function @@ -120,8 +120,11 @@ [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); addConversion([&](VectorType type) { return convertVectorType(type); }); - // LLVMType is legal, so add a pass-through conversion. - addConversion([](LLVM::LLVMType type) { return type; }); + // LLVM-compatible types are legal, so add a pass-through conversion. + addConversion([](Type type) { + return LLVM::isCompatibleType(type) ? llvm::Optional(type) + : llvm::None; + }); // Materialization for memrefs creates descriptor structs from individual // values constituting them, when descriptors are used, i.e. more than one @@ -170,7 +173,7 @@ return *getDialect()->getContext(); } -LLVM::LLVMType LLVMTypeConverter::getIndexType() { +Type LLVMTypeConverter::getIndexType() { return LLVM::LLVMIntegerType::get(&getContext(), getIndexTypeBitwidth()); } @@ -205,7 +208,7 @@ static constexpr unsigned kRealPosInComplexNumberStruct = 0; static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; Type LLVMTypeConverter::convertComplexType(ComplexType type) { - auto elementType = convertType(type.getElementType()).cast(); + auto elementType = convertType(type.getElementType()); return LLVM::LLVMStructType::getLiteral(&getContext(), {elementType, elementType}); } @@ -214,7 +217,7 @@ // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { SignatureConversion conversion(type.getNumInputs()); - LLVM::LLVMType converted = + Type converted = convertFunctionSignature(type, /*isVariadic=*/false, conversion); return LLVM::LLVMPointerType::get(converted); } @@ -224,7 +227,7 @@ // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. -LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( +Type LLVMTypeConverter::convertFunctionSignature( FunctionType funcTy, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convention. @@ -240,7 +243,7 @@ result.addInputs(en.index(), converted); } - SmallVector argTypes; + SmallVector argTypes; argTypes.reserve(llvm::size(result.getConvertedTypes())); for (Type type : result.getConvertedTypes()) argTypes.push_back(unwrap(type)); @@ -248,10 +251,9 @@ // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. - LLVM::LLVMType resultType = - funcTy.getNumResults() == 0 - ? LLVM::LLVMVoidType::get(&getContext()) - : unwrap(packFunctionResults(funcTy.getResults())); + Type resultType = funcTy.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic); @@ -259,23 +261,21 @@ /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. -LLVM::LLVMType -LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { - SmallVector inputs; +Type LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { + SmallVector inputs; for (Type t : type.getInputs()) { - auto converted = convertType(t).dyn_cast_or_null(); - if (!converted) + auto converted = convertType(t); + if (!converted || !LLVM::isCompatibleType(converted)) return {}; if (t.isa()) converted = LLVM::LLVMPointerType::get(converted); inputs.push_back(converted); } - LLVM::LLVMType resultType = - type.getNumResults() == 0 - ? LLVM::LLVMVoidType::get(&getContext()) - : unwrap(packFunctionResults(type.getResults())); + Type resultType = type.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; @@ -316,19 +316,19 @@ /// Index sizes[Rank]; // omitted when rank == 0 /// Index strides[Rank]; // omitted when rank == 0 /// }; -SmallVector +SmallVector LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, bool unpackAggregates) { assert(isStrided(type) && "Non-strided layout maps must have been normalized away"); - LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); auto indexTy = getIndexType(); - SmallVector results = {ptrTy, ptrTy, indexTy}; + SmallVector results = {ptrTy, ptrTy, indexTy}; auto rank = type.getRank(); if (rank == 0) return results; @@ -345,7 +345,7 @@ Type LLVMTypeConverter::convertMemRefType(MemRefType type) { // When converting a MemRefType to a struct with descriptor fields, do not // unpack the `sizes` and `strides` arrays. - SmallVector types = + SmallVector types = getMemRefDescriptorFields(type, /*unpackAggregates=*/false); return LLVM::LLVMStructType::getLiteral(&getContext(), types); } @@ -360,8 +360,7 @@ /// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be /// stack allocated (alloca) copy of a MemRef descriptor that got casted to /// be unranked. -SmallVector -LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { +SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { return {getIndexType(), LLVM::LLVMPointerType::get( LLVM::LLVMIntegerType::get(&getContext(), 8))}; } @@ -395,7 +394,7 @@ if (ShapedType::isDynamicStrideOrOffset(offset)) return {}; - LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); + Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); @@ -409,7 +408,7 @@ auto elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - LLVM::LLVMType vectorType = + Type vectorType = LLVM::LLVMFixedVectorType::get(elementType, type.getShape().back()); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) @@ -454,10 +453,9 @@ // StructBuilder implementation //===----------------------------------------------------------------------===// -StructBuilder::StructBuilder(Value v) : value(v) { +StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { assert(value != nullptr && "value cannot be null"); - structType = value.getType().dyn_cast(); - assert(structType && "expected llvm type"); + assert(LLVM::isCompatibleType(structType) && "expected llvm type"); } Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, @@ -479,7 +477,7 @@ ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, Location loc, Type type) { - Value val = builder.create(loc, type.cast()); + Value val = builder.create(loc, type); return ComplexStructBuilder(val); } @@ -518,8 +516,7 @@ MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = - builder.create(loc, descriptorType.cast()); + Value descriptor = builder.create(loc, descriptorType); return MemRefDescriptor(descriptor); } @@ -620,9 +617,8 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, int64_t rank) { - auto indexTy = indexType.cast(); - auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy); - auto arrayTy = LLVM::LLVMArrayType::get(indexTy, rank); + auto indexPtrTy = LLVM::LLVMPointerType::get(indexType); + auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); // Copy size values to stack-allocated memory. @@ -774,8 +770,7 @@ UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { - Value descriptor = - builder.create(loc, descriptorType.cast()); + Value descriptor = builder.create(loc, descriptorType); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { @@ -828,7 +823,7 @@ return; // Cache the index type. - LLVM::LLVMType indexType = typeConverter.getIndexType(); + Type indexType = typeConverter.getIndexType(); // Initialize shared constants. Value one = createIndexAttrConstant(builder, loc, indexType, 1); @@ -868,7 +863,7 @@ Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType) { + Type elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -877,7 +872,7 @@ void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType, + Type elemPtrPtrType, Value allocatedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -887,7 +882,7 @@ Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType) { + Type elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -901,7 +896,7 @@ void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType, + Type elemPtrPtrType, Value alignedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -916,7 +911,7 @@ Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType) { + Type elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -932,8 +927,7 @@ void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType, - Value offset) { + Type elemPtrPtrType, Value offset) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); @@ -949,16 +943,15 @@ Value UnrankedMemRefDescriptor::sizeBasePtr( OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { - LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType(); - LLVM::LLVMType indexTy = typeConverter.getIndexType(); - LLVM::LLVMType structPtrTy = + Type elemPtrTy = elemPtrPtrType.getElementType(); + Type indexTy = typeConverter.getIndexType(); + Type structPtrTy = LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral( indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy})); Value structPtr = builder.create(loc, structPtrTy, memRefDescPtr); - LLVM::LLVMType int32_type = - unwrap(typeConverter.convertType(builder.getI32Type())); + Type int32_type = unwrap(typeConverter.convertType(builder.getI32Type())); Value zero = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); Value three = builder.create(loc, int32_type, @@ -970,8 +963,7 @@ Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index) { - LLVM::LLVMType indexPtrTy = - LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); return builder.create(loc, sizeStoreGep); @@ -981,8 +973,7 @@ LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index, Value size) { - LLVM::LLVMType indexPtrTy = - LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); builder.create(loc, size, sizeStoreGep); @@ -991,8 +982,7 @@ Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank) { - LLVM::LLVMType indexPtrTy = - LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); return builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({rank})); } @@ -1001,8 +991,7 @@ LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { - LLVM::LLVMType indexPtrTy = - LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); return builder.create(loc, strideStoreGep); @@ -1012,8 +1001,7 @@ LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { - LLVM::LLVMType indexPtrTy = - LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); builder.create(loc, stride, strideStoreGep); @@ -1028,22 +1016,21 @@ return *getTypeConverter()->getDialect(); } -LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { +Type ConvertToLLVMPattern::getIndexType() const { return getTypeConverter()->getIndexType(); } -LLVM::LLVMType -ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { +Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMIntegerType::get( &getTypeConverter()->getContext(), getTypeConverter()->getPointerBitwidth(addressSpace)); } -LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { +Type ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); } -LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { +Type ConvertToLLVMPattern::getVoidPtrType() const { return LLVM::LLVMPointerType::get( LLVM::LLVMIntegerType::get(&getTypeConverter()->getContext(), 8)); } @@ -1084,7 +1071,7 @@ index ? rewriter.create(loc, index, increment) : increment; } - LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType(); + Type elementPtrType = memRefDescriptor.getElementPtrType(); return index ? rewriter.create(loc, elementPtrType, base, index) : base; } @@ -1161,8 +1148,8 @@ // %0 = getelementptr %elementType* null, %indexType 1 // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. - auto convertedPtrType = LLVM::LLVMPointerType::get( - typeConverter->convertType(type).cast()); + auto convertedPtrType = + LLVM::LLVMPointerType::get(typeConverter->convertType(type)); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( loc, convertedPtrType, @@ -1276,7 +1263,7 @@ FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); - LLVM::LLVMType wrapperType = + Type wrapperType = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument // types. But since it has been applies to a non-wrapper function before, it @@ -1318,8 +1305,7 @@ builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); - auto ptrTy = - LLVM::LLVMPointerType::get(packed.getType().cast()); + auto ptrTy = LLVM::LLVMPointerType::get(packed.getType()); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); @@ -1494,9 +1480,9 @@ // 1-D LLVM vectors. struct NDVectorTypeInfo { // LLVM array struct which encodes n-D vectors. - LLVM::LLVMType llvmArrayTy; + Type llvmArrayTy; // LLVM vector type which encodes the inner 1-D vector type. - LLVM::LLVMType llvmVectorTy; + Type llvmVectorTy; // Multiplicity of llvmArrayTy to llvmVectorTy. SmallVector arraySizes; }; @@ -1510,10 +1496,11 @@ LLVMTypeConverter &converter) { assert(vectorType.getRank() > 1 && "expected >1D vector type"); NDVectorTypeInfo info; - info.llvmArrayTy = - converter.convertType(vectorType).dyn_cast(); - if (!info.llvmArrayTy) + info.llvmArrayTy = converter.convertType(vectorType); + if (!info.llvmArrayTy || !LLVM::isCompatibleType(info.llvmArrayTy)) { + info.llvmArrayTy = nullptr; return info; + } info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmArrayTy; while (llvmTy.isa()) { @@ -1610,14 +1597,14 @@ static LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, - std::function createOperand, + std::function createOperand, ConversionPatternRewriter &rewriter) { auto vectorType = op->getResult(0).getType().dyn_cast(); if (!vectorType) return failure(); auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter); auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; - auto llvmArrayTy = operands[0].getType().cast(); + auto llvmArrayTy = operands[0].getType(); if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return failure(); @@ -1645,14 +1632,14 @@ // Cannot convert ops if their operands are not of LLVM type. if (!llvm::all_of(operands.getTypes(), - [](Type t) { return t.isa(); })) + [](Type t) { return isCompatibleType(t); })) return failure(); - auto llvmArrayTy = operands[0].getType().cast(); + auto llvmArrayTy = operands[0].getType(); if (!llvmArrayTy.isa()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); - auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, + auto callback = [op, targetOp, &rewriter](Type llvmVectorTy, ValueRange operands) { OperationState state(op->getLoc(), targetOp); state.addTypes(llvmVectorTy); @@ -1896,16 +1883,18 @@ ConversionPatternRewriter &rewriter) const override { // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { - auto type = typeConverter->convertType(op.getResult().getType()) - .dyn_cast_or_null(); - if (!type) + auto type = typeConverter->convertType(op.getResult().getType()); + if (!type || !LLVM::isCompatibleType(type)) return rewriter.notifyMatchFailure(op, "failed to convert result type"); - NamedAttrList attrs(op->getAttrDictionary()); - attrs.erase("value"); - rewriter.replaceOpWithNewOp( - op, type.cast(), symbolRef.getValue(), - attrs.getAttrs()); + auto newOp = rewriter.create(op.getLoc(), type, + symbolRef.getValue()); + for (const NamedAttribute &attr : op->getAttrs()) { + if (attr.first.strref() == "value") + continue; + newOp.setAttr(attr.first, attr.second); + } + rewriter.replaceOp(op, newOp->getResults()); return success(); } @@ -1947,11 +1936,11 @@ Value createAllocCall(Location loc, StringRef name, Type ptrType, ArrayRef params, ModuleOp module, ConversionPatternRewriter &rewriter) const { - SmallVector paramTypes; + SmallVector paramTypes; auto allocFuncOp = module.lookupSymbol(name); if (!allocFuncOp) { for (Value param : params) - paramTypes.push_back(param.getType().cast()); + paramTypes.push_back(param.getType()); auto allocFuncType = LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); OpBuilder::InsertionGuard guard(rewriter); @@ -2206,10 +2195,10 @@ // Get frequently used types. MLIRContext *context = builder.getContext(); auto voidType = LLVM::LLVMVoidType::get(context); - LLVM::LLVMType voidPtrType = + Type voidPtrType = LLVM::LLVMPointerType::get(LLVM::LLVMIntegerType::get(context, 8)); auto i1Type = LLVM::LLVMIntegerType::get(context, 1); - LLVM::LLVMType indexType = typeConverter.getIndexType(); + Type indexType = typeConverter.getIndexType(); // Find the malloc and free, or declare them if necessary. auto module = builder.getInsertionPoint()->getParentOfType(); @@ -2389,17 +2378,15 @@ }; /// Returns the LLVM type of the global variable given the memref type `type`. -static LLVM::LLVMType -convertGlobalMemrefTypeToLLVM(MemRefType type, - LLVMTypeConverter &typeConverter) { +static Type convertGlobalMemrefTypeToLLVM(MemRefType type, + LLVMTypeConverter &typeConverter) { // LLVM type for a global memref will be a multi-dimension array. For // declarations or uninitialized global memrefs, we can potentially flatten // this to a 1D array. However, for global_memref's with an initial value, // we do not intend to flatten the ElementsAttribute when going from std -> // LLVM dialect, so the LLVM type needs to me a multi-dimension array. - LLVM::LLVMType elementType = - unwrap(typeConverter.convertType(type.getElementType())); - LLVM::LLVMType arrayTy = elementType; + Type elementType = unwrap(typeConverter.convertType(type.getElementType())); + Type arrayTy = elementType; // Shape has the outermost dim at index 0, so need to walk it backwards for (int64_t dim : llvm::reverse(type.getShape())) arrayTy = LLVM::LLVMArrayType::get(arrayTy, dim); @@ -2417,8 +2404,7 @@ if (!isConvertibleAndHasIdentityMaps(type)) return failure(); - LLVM::LLVMType arrayTy = - convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); + Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); LLVM::Linkage linkage = global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; @@ -2457,17 +2443,15 @@ MemRefType type = getGlobalOp.result().getType().cast(); unsigned memSpace = type.getMemorySpace(); - LLVM::LLVMType arrayTy = - convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); + Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. - LLVM::LLVMType elementType = + Type elementType = unwrap(typeConverter->convertType(type.getElementType())); - LLVM::LLVMType elementPtrType = - LLVM::LLVMPointerType::get(elementType, memSpace); + Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); SmallVector operands = {addressOf}; operands.insert(operands.end(), type.getRank() + 1, @@ -2497,10 +2481,9 @@ matchAndRewrite(RsqrtOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RsqrtOp::Adaptor transformed(operands); - auto operandType = - transformed.operand().getType().dyn_cast(); + auto operandType = transformed.operand().getType(); - if (!operandType) + if (!operandType || !LLVM::isCompatibleType(operandType)) return failure(); auto loc = op.getLoc(); @@ -2528,7 +2511,7 @@ return handleMultidimensionalVectors( op.getOperation(), operands, *getTypeConverter(), - [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { + [&](Type llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get( {llvmVectorTy.cast() @@ -2620,13 +2603,11 @@ // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // castPtr = BitCastOp i8* to structTy* - auto castPtr = rewriter - .create( - loc, - LLVM::LLVMPointerType::get( - targetStructType.cast()), - ptr) - .getResult(); + auto castPtr = + rewriter + .create( + loc, LLVM::LLVMPointerType::get(targetStructType), ptr) + .getResult(); // struct = LoadOp castPtr auto loadOp = rewriter.create(loc, castPtr); rewriter.replaceOp(memRefCastOp, loadOp.getResult()); @@ -2659,9 +2640,8 @@ unsigned memorySpace = operandType.cast().getMemorySpace(); Type elementType = operandType.cast().getElementType(); - LLVM::LLVMType llvmElementType = - unwrap(typeConverter.convertType(elementType)); - LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get( + Type llvmElementType = unwrap(typeConverter.convertType(elementType)); + Type elementPtrPtrType = LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); // Extract pointer to the underlying ranked memref descriptor and cast it to @@ -2809,8 +2789,7 @@ &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. - LLVM::LLVMType llvmElementType = - unwrap(typeConverter->convertType(elementType)); + Type llvmElementType = unwrap(typeConverter->convertType(elementType)); auto elementPtrPtrType = LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, @@ -2835,7 +2814,7 @@ rewriter.create(loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); - LLVM::LLVMType indexType = getTypeConverter()->getIndexType(); + Type indexType = getTypeConverter()->getIndexType(); Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, @@ -2865,7 +2844,7 @@ rewriter.setInsertionPointToStart(bodyBlock); // Copy size from shape to descriptor. - LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); + Type llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); Value sizeLoadGep = rewriter.create( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = rewriter.create(loc, sizeLoadGep); @@ -2957,9 +2936,8 @@ Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, - LLVM::LLVMPointerType::get( - typeConverter->convertType(scalarMemRefType).cast(), - addressSpace), + LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType), + addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. @@ -3435,8 +3413,7 @@ auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = - typeConverter->convertType(sourceMemRefType.getElementType()) - .dyn_cast_or_null(); + typeConverter->convertType(sourceMemRefType.getElementType()); auto viewMemRefType = subViewOp.getType(); auto inferredType = SubViewOp::inferResultType( @@ -3446,11 +3423,12 @@ extractFromI64ArrayAttr(subViewOp.static_strides())) .cast(); auto targetElementTy = - typeConverter->convertType(viewMemRefType.getElementType()) - .dyn_cast(); - auto targetDescTy = typeConverter->convertType(viewMemRefType) - .dyn_cast_or_null(); - if (!sourceElementTy || !targetDescTy) + typeConverter->convertType(viewMemRefType.getElementType()); + auto targetDescTy = typeConverter->convertType(viewMemRefType); + if (!sourceElementTy || !targetDescTy || !targetElementTy || + !LLVM::isCompatibleType(sourceElementTy) || + !LLVM::isCompatibleType(targetElementTy) || + !LLVM::isCompatibleType(targetDescTy)) return failure(); // Extract the offset and strides from the type. @@ -3461,7 +3439,7 @@ return failure(); // Create the descriptor. - if (!operands.front().getType().isa()) + if (!LLVM::isCompatibleType(operands.front().getType())) return failure(); MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); @@ -3650,11 +3628,11 @@ auto viewMemRefType = viewOp.getType(); auto targetElementTy = - typeConverter->convertType(viewMemRefType.getElementType()) - .dyn_cast(); - auto targetDescTy = - typeConverter->convertType(viewMemRefType).dyn_cast(); - if (!targetDescTy) + typeConverter->convertType(viewMemRefType.getElementType()); + auto targetDescTy = typeConverter->convertType(viewMemRefType); + if (!targetDescTy || !targetElementTy || + !LLVM::isCompatibleType(targetElementTy) || + !LLVM::isCompatibleType(targetDescTy)) return viewOp.emitWarning("Target descriptor type not converted to LLVM"), failure(); @@ -3849,9 +3827,7 @@ auto loc = atomicOp.getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); - LLVM::LLVMType valueType = - typeConverter->convertType(atomicOp.getResult().getType()) - .cast(); + Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); @@ -4060,12 +4036,11 @@ if (types.size() == 1) return convertCallingConventionType(types.front()); - SmallVector resultTypes; + SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { - auto converted = - convertCallingConventionType(t).dyn_cast_or_null(); - if (!converted) + auto converted = convertCallingConventionType(t); + if (!converted || !LLVM::isCompatibleType(converted)) return {}; resultTypes.push_back(converted); } @@ -4080,8 +4055,7 @@ auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. - auto ptrType = - LLVM::LLVMPointerType::get(operand.getType().cast()); + auto ptrType = LLVM::LLVMPointerType::get(operand.getType()); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -152,8 +152,7 @@ // stop depending on translation. llvm::LLVMContext llvmContext; align = LLVM::TypeToLLVMIRTranslator(llvmContext) - .getPreferredAlignment(elementTy.cast(), - typeConverter.getDataLayout()); + .getPreferredAlignment(elementTy, typeConverter.getDataLayout()); return success(); } @@ -193,7 +192,7 @@ Value base; if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); - auto pType = LLVM::LLVMPointerType::get(type.template cast()); + auto pType = LLVM::LLVMPointerType::get(type); base = rewriter.create(loc, pType, base); ptr = rewriter.create(loc, pType, base); return success(); @@ -1401,7 +1400,7 @@ // Helper for printer method declaration (first hit) and lookup. static Operation *getPrint(Operation *op, StringRef name, - ArrayRef params) { + ArrayRef params) { auto module = op->getParentOfType(); auto func = module.lookupSymbol(name); if (func) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -30,8 +30,8 @@ static LogicalResult replaceTransferOpWithMubuf( ConversionPatternRewriter &rewriter, ArrayRef operands, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, - LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex, - Value &offsetSizeInBytes, Value &glc, Value &slc) { + Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, + Value &glc, Value &slc) { rewriter.replaceOpWithNewOp( xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc); return success(); @@ -40,8 +40,8 @@ static LogicalResult replaceTransferOpWithMubuf( ConversionPatternRewriter &rewriter, ArrayRef operands, LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, - LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex, - Value &offsetSizeInBytes, Value &glc, Value &slc) { + Type &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, + Value &glc, Value &slc) { auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dwordConfig, vindex, @@ -121,16 +121,16 @@ Type i64Ty = rewriter.getIntegerType(64); Value i64x2Ty = rewriter.create( loc, - LLVM::LLVMFixedVectorType::get( - toLLVMTy(i64Ty).template cast(), 2), + LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast(), + 2), constConfig); Value dataPtrAsI64 = rewriter.create( - loc, toLLVMTy(i64Ty).template cast(), dataPtr); + loc, toLLVMTy(i64Ty).template cast(), dataPtr); Value zero = this->createIndexConstant(rewriter, loc, 0); Value dwordConfig = rewriter.create( loc, - LLVM::LLVMFixedVectorType::get( - toLLVMTy(i64Ty).template cast(), 2), + LLVM::LLVMFixedVectorType::get(toLLVMTy(i64Ty).template cast(), + 2), i64x2Ty, dataPtrAsI64, zero); dwordConfig = rewriter.create(loc, toLLVMTy(i32Vecx4), dwordConfig); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -101,14 +101,14 @@ // The result type is either i1 or a vector type if the inputs are // vectors. - LLVMType resultType = LLVMIntegerType::get(builder.getContext(), 1); - auto argType = type.dyn_cast(); - if (!argType) - return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); - if (auto vecArgType = argType.dyn_cast()) + Type resultType = LLVMIntegerType::get(builder.getContext(), 1); + if (!isCompatibleType(type)) + return parser.emitError(trailingTypeLoc, + "expected LLVM dialect-compatible type"); + if (auto vecArgType = type.dyn_cast()) resultType = LLVMFixedVectorType::get(resultType, vecArgType.getNumElements()); - assert(!argType.isa() && + assert(!type.isa() && "unhandled scalable vector"); result.addTypes({resultType}); @@ -546,21 +546,21 @@ return parser.emitError(trailingTypeLoc, "expected function with 0 or 1 result"); - LLVM::LLVMType llvmResultType; + Type llvmResultType; if (funcType.getNumResults() == 0) { llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); } else { - llvmResultType = funcType.getResult(0).dyn_cast(); - if (!llvmResultType) + llvmResultType = funcType.getResult(0); + if (!isCompatibleType(llvmResultType)) return parser.emitError(trailingTypeLoc, "expected result to have LLVM type"); } - SmallVector argTypes; + SmallVector argTypes; argTypes.reserve(funcType.getNumInputs()); for (Type ty : funcType.getInputs()) { - if (auto argType = ty.dyn_cast()) - argTypes.push_back(argType); + if (isCompatibleType(ty)) + argTypes.push_back(ty); else return parser.emitError(trailingTypeLoc, "expected LLVM types as inputs"); @@ -693,7 +693,7 @@ // Type for the callee, we'll get it differently depending if it is a direct // or indirect call. - LLVMType fnType; + Type fnType; bool isIndirect = false; @@ -704,14 +704,10 @@ if (!op.getNumOperands()) return op.emitOpError( "must have either a `callee` attribute or at least an operand"); - fnType = op.getOperand(0).getType().dyn_cast(); - if (!fnType) - return op.emitOpError("indirect call to a non-llvm type: ") - << op.getOperand(0).getType(); - auto ptrType = fnType.dyn_cast(); + auto ptrType = op.getOperand(0).getType().dyn_cast(); if (!ptrType) return op.emitOpError("indirect call expects a pointer as callee: ") - << fnType; + << ptrType; fnType = ptrType.getElementType(); } else { Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName); @@ -825,21 +821,21 @@ "expected function with 0 or 1 result"); Builder &builder = parser.getBuilder(); - LLVM::LLVMType llvmResultType; + Type llvmResultType; if (funcType.getNumResults() == 0) { llvmResultType = LLVM::LLVMVoidType::get(builder.getContext()); } else { - llvmResultType = funcType.getResult(0).dyn_cast(); - if (!llvmResultType) + llvmResultType = funcType.getResult(0); + if (!isCompatibleType(llvmResultType)) return parser.emitError(trailingTypeLoc, "expected result to have LLVM type"); } - SmallVector argTypes; + SmallVector argTypes; argTypes.reserve(funcType.getNumInputs()); for (int i = 0, e = funcType.getNumInputs(); i < e; ++i) { - auto argType = funcType.getInput(i).dyn_cast(); - if (!argType) + auto argType = funcType.getInput(i); + if (!isCompatibleType(argType)) return parser.emitError(trailingTypeLoc, "expected LLVM types as inputs"); argTypes.push_back(argType); @@ -922,13 +918,13 @@ // `containerType`. Position is an integer array attribute where each value // is a zero-based position of the element in the aggregate type. Return the // resulting type wrapped in MLIR, or nullptr on error. -static LLVM::LLVMType getInsertExtractValueElementType(OpAsmParser &parser, - Type containerType, - ArrayAttr positionAttr, - llvm::SMLoc attributeLoc, - llvm::SMLoc typeLoc) { - auto llvmType = containerType.dyn_cast(); - if (!llvmType) +static Type getInsertExtractValueElementType(OpAsmParser &parser, + Type containerType, + ArrayAttr positionAttr, + llvm::SMLoc attributeLoc, + llvm::SMLoc typeLoc) { + Type llvmType = containerType; + if (!isCompatibleType(containerType)) return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; // Infer the element type from the structure type: iteratively step inside the @@ -1162,7 +1158,7 @@ /// the name of the attribute in ODS. static StringRef getLinkageAttrName() { return "linkage"; } -void GlobalOp::build(OpBuilder &builder, OperationState &result, LLVMType type, +void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type, bool isConstant, Linkage linkage, StringRef name, Attribute value, unsigned addrSpace, ArrayRef attrs) { @@ -1212,14 +1208,13 @@ /// report the error, the user is expected to produce an appropriate message. // TODO: make the size depend on data layout rather than on the conversion // pass option, and pull that information here. -static LogicalResult verifyCastWithIndex(LLVMType llvmType) { +static LogicalResult verifyCastWithIndex(Type llvmType) { return success(llvmType.isa()); } /// Checks if `llvmType` is dialect cast-compatible with built-in `type` and /// reports errors to the location of `op`. -static LogicalResult verifyCast(DialectCastOp op, LLVMType llvmType, - Type type) { +static LogicalResult verifyCast(DialectCastOp op, Type llvmType, Type type) { // Index is compatible with any integer. if (type.isIndex()) { if (succeeded(verifyCastWithIndex(llvmType))) @@ -1387,14 +1382,13 @@ } static LogicalResult verify(DialectCastOp op) { - if (auto llvmType = op.getType().dyn_cast()) - return verifyCast(op, llvmType, op.in().getType()); + if (isCompatibleType(op.getType())) + return verifyCast(op, op.getType(), op.in().getType()); - auto llvmType = op.in().getType().dyn_cast(); - if (!llvmType) + if (!isCompatibleType(op.in().getType())) return op->emitOpError("expected one LLVM type and one built-in type"); - return verifyCast(op, llvmType, op.getType()); + return verifyCast(op, op.in().getType(), op.getType()); } // Parses one of the keywords provided in the list `keywords` and returns the @@ -1597,7 +1591,7 @@ } void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, - StringRef name, LLVMType type, LLVM::Linkage linkage, + StringRef name, Type type, LLVM::Linkage linkage, ArrayRef attrs, ArrayRef argAttrs) { result.addRegion(); @@ -1633,23 +1627,23 @@ } // Convert inputs to LLVM types, exit early on error. - SmallVector llvmInputs; + SmallVector llvmInputs; for (auto t : inputs) { - auto llvmTy = t.dyn_cast(); - if (!llvmTy) { + if (!isCompatibleType(t)) { parser.emitError(loc, "failed to construct function type: expected LLVM " "type for function arguments"); return {}; } - llvmInputs.push_back(llvmTy); + llvmInputs.push_back(t); } // No output is denoted as "void" in LLVM type system. - LLVMType llvmOutput = outputs.empty() ? LLVMVoidType::get(b.getContext()) - : outputs.front().dyn_cast(); - if (!llvmOutput) { + Type llvmOutput = + outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front(); + if (!isCompatibleType(llvmOutput)) { parser.emitError(loc, "failed to construct function type: expected LLVM " - "type for function results"); + "type for function results") + << llvmOutput; return {}; } return LLVMFunctionType::get(llvmOutput, llvmInputs, @@ -1720,7 +1714,7 @@ for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) argTypes.push_back(fnType.getParamType(i)); - LLVMType returnType = fnType.getReturnType(); + Type returnType = fnType.getReturnType(); if (!returnType.isa()) resTypes.push_back(returnType); @@ -1792,11 +1786,10 @@ Block &entryBlock = op.front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); - auto argLLVMType = argType.dyn_cast(); - if (!argLLVMType) + if (!isCompatibleType(argType)) return op.emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (op.getType().getParamType(i) != argLLVMType) + if (op.getType().getParamType(i) != argType) return op.emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -1889,7 +1882,7 @@ // attribute-dict? `:` type static ParseResult parseAtomicRMWOp(OpAsmParser &parser, OperationState &result) { - LLVMType type; + Type type; OpAsmParser::OperandType ptr, val; if (parseAtomicBinOp(parser, result, "bin_op") || parser.parseOperand(ptr) || parser.parseComma() || parser.parseOperand(val) || @@ -1907,11 +1900,11 @@ static LogicalResult verify(AtomicRMWOp op) { auto ptrType = op.ptr().getType().cast(); - auto valType = op.val().getType().cast(); + auto valType = op.val().getType(); if (valType != ptrType.getElementType()) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for operand #1"); - auto resType = op.res().getType().cast(); + auto resType = op.res().getType(); if (resType != valType) return op.emitOpError( "expected LLVM IR result type to match type for operand #1"); @@ -1954,7 +1947,7 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); - LLVMType type; + Type type; OpAsmParser::OperandType ptr, cmp, val; if (parser.parseOperand(ptr) || parser.parseComma() || parser.parseOperand(cmp) || parser.parseComma() || @@ -1981,8 +1974,8 @@ auto ptrType = op.ptr().getType().cast(); if (!ptrType) return op.emitOpError("expected LLVM IR pointer type for operand #0"); - auto cmpType = op.cmp().getType().cast(); - auto valType = op.val().getType().cast(); + auto cmpType = op.cmp().getType(); + auto valType = op.val().getType(); if (cmpType != ptrType.getElementType() || cmpType != valType) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for all other operands"); @@ -2088,7 +2081,7 @@ /// Print a type registered to this dialect. void LLVMDialect::printType(Type type, DialectAsmPrinter &os) const { - return detail::printType(type.cast(), os); + return detail::printType(type, os); } LogicalResult LLVMDialect::verifyDataLayoutString( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -19,11 +19,11 @@ // Printing. //===----------------------------------------------------------------------===// -static void printTypeImpl(llvm::raw_ostream &os, LLVMType type, +static void printTypeImpl(llvm::raw_ostream &os, Type type, llvm::SetVector &stack); /// Returns the keyword to use for the given type. -static StringRef getTypeKeyword(LLVMType type) { +static StringRef getTypeKeyword(Type type) { return TypeSwitch(type) .Case([&](Type) { return "void"; }) .Case([&](Type) { return "half"; }) @@ -64,7 +64,7 @@ os << '('; if (type.isIdentified()) stack.insert(type.getName()); - llvm::interleaveComma(type.getBody(), os, [&](LLVMType subtype) { + llvm::interleaveComma(type.getBody(), os, [&](Type subtype) { printTypeImpl(os, subtype, stack); }); if (type.isIdentified()) @@ -109,9 +109,9 @@ os << '<'; printTypeImpl(os, funcType.getReturnType(), stack); os << " ("; - llvm::interleaveComma( - funcType.getParams(), os, - [&os, &stack](LLVMType subtype) { printTypeImpl(os, subtype, stack); }); + llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) { + printTypeImpl(os, subtype, stack); + }); if (funcType.isVarArg()) { if (funcType.getNumParams() != 0) os << ", "; @@ -129,7 +129,7 @@ /// struct<"c", (ptr>)>>, /// ptr>)>>)> /// note that "b" is printed twice. -static void printTypeImpl(llvm::raw_ostream &os, LLVMType type, +static void printTypeImpl(llvm::raw_ostream &os, Type type, llvm::SetVector &stack) { if (!type) { os << "<>"; @@ -171,7 +171,7 @@ return printFunctionType(os, funcType, stack); } -void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) { +void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) { llvm::SetVector stack; return printTypeImpl(printer.getStream(), type, stack); } @@ -180,13 +180,13 @@ // Parsing. //===----------------------------------------------------------------------===// -static LLVMType parseTypeImpl(DialectAsmParser &parser, - llvm::SetVector &stack); +static Type parseTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack); /// Helper to be chained with other parsing functions. static ParseResult parseTypeImpl(DialectAsmParser &parser, llvm::SetVector &stack, - LLVMType &result) { + Type &result) { result = parseTypeImpl(parser, stack); return success(result != nullptr); } @@ -196,7 +196,7 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser, llvm::SetVector &stack) { Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); - LLVMType returnType; + Type returnType; if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) || parser.parseLParen()) return LLVMFunctionType(); @@ -210,7 +210,7 @@ } // Parse arguments. - SmallVector argTypes; + SmallVector argTypes; do { if (succeeded(parser.parseOptionalEllipsis())) { if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) @@ -235,7 +235,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser, llvm::SetVector &stack) { Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); - LLVMType elementType; + Type elementType; if (parser.parseLess() || parseTypeImpl(parser, stack, elementType)) return LLVMPointerType(); @@ -255,7 +255,7 @@ llvm::SetVector &stack) { SmallVector dims; llvm::SMLoc dimPos; - LLVMType elementType; + Type elementType; Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || parser.parseDimensionList(dims, /*allowDynamic=*/true) || @@ -286,7 +286,7 @@ llvm::SetVector &stack) { SmallVector dims; llvm::SMLoc sizePos; - LLVMType elementType; + Type elementType; Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || parser.parseDimensionList(dims, /*allowDynamic=*/false) || @@ -305,11 +305,11 @@ /// error at `subtypesLoc` in case of failure, uses `stack` to make sure the /// types printed in the error message look like they did when parsed. static LLVMStructType trySetStructBody(LLVMStructType type, - ArrayRef subtypes, - bool isPacked, DialectAsmParser &parser, + ArrayRef subtypes, bool isPacked, + DialectAsmParser &parser, llvm::SMLoc subtypesLoc, llvm::SetVector &stack) { - for (LLVMType t : subtypes) { + for (Type t : subtypes) { if (!LLVMStructType::isValidElementType(t)) { parser.emitError(subtypesLoc) << "invalid LLVM structure element type: " << t; @@ -389,12 +389,12 @@ // Parse subtypes. For identified structs, put the identifier of the struct on // the stack to support self-references in the recursive calls. - SmallVector subtypes; + SmallVector subtypes; llvm::SMLoc subtypesLoc = parser.getCurrentLocation(); do { if (isIdentified) stack.insert(name); - LLVMType type = parseTypeImpl(parser, stack); + Type type = parseTypeImpl(parser, stack); if (!type) return LLVMStructType(); subtypes.push_back(type); @@ -413,8 +413,8 @@ } /// Parses one of the LLVM dialect types. -static LLVMType parseTypeImpl(DialectAsmParser &parser, - llvm::SetVector &stack) { +static Type parseTypeImpl(DialectAsmParser &parser, + llvm::SetVector &stack) { // Special case for integers (i[1-9][0-9]*) that are literals rather than // keywords for the parser, so they are not caught by the main dispatch below. // Try parsing it a built-in integer type instead. @@ -425,11 +425,11 @@ OptionalParseResult result = parser.parseOptionalType(maybeIntegerType); if (result.hasValue()) { if (failed(*result)) - return LLVMType(); + return Type(); if (!maybeIntegerType.isSignlessInteger()) { parser.emitError(keyLoc) << "unexpected type, expected i* or keyword"; - return LLVMType(); + return Type(); } return LLVMIntegerType::getChecked( loc, maybeIntegerType.getIntOrFloatBitWidth()); @@ -438,9 +438,9 @@ // Dispatch to concrete functions. StringRef key; if (failed(parser.parseKeyword(&key))) - return LLVMType(); + return Type(); - return StringSwitch>(key) + return StringSwitch>(key) .Case("void", [&] { return LLVMVoidType::get(ctx); }) .Case("half", [&] { return LLVMHalfType::get(ctx); }) .Case("bfloat", [&] { return LLVMBFloatType::get(ctx); }) @@ -460,11 +460,11 @@ .Case("struct", [&] { return parseStructType(parser, stack); }) .Default([&] { parser.emitError(keyLoc) << "unknown LLVM type: " << key; - return LLVMType(); + return Type(); })(); } -LLVMType mlir::LLVM::detail::parseType(DialectAsmParser &parser) { +Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) { llvm::SetVector stack; return parseTypeImpl(parser, stack); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -24,44 +24,32 @@ using namespace mlir; using namespace mlir::LLVM; -//===----------------------------------------------------------------------===// -// LLVMType. -//===----------------------------------------------------------------------===// - -bool LLVMType::classof(Type type) { - return llvm::isa(type.getDialect()); -} - -LLVMDialect &LLVMType::getDialect() { - return static_cast(Type::getDialect()); -} - //===----------------------------------------------------------------------===// // Array type. //===----------------------------------------------------------------------===// -bool LLVMArrayType::isValidElementType(LLVMType type) { +bool LLVMArrayType::isValidElementType(Type type) { return !type.isa(); } -LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) { +LLVMArrayType LLVMArrayType::get(Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); return Base::get(elementType.getContext(), elementType, numElements); } -LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType, +LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); return Base::getChecked(loc, elementType, numElements); } -LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; } +Type LLVMArrayType::getElementType() { return getImpl()->elementType; } unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; } LogicalResult -LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType, +LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType, unsigned numElements) { if (!isValidElementType(elementType)) return emitError(loc, "invalid array element type: ") << elementType; @@ -72,52 +60,50 @@ // Function type. //===----------------------------------------------------------------------===// -bool LLVMFunctionType::isValidArgumentType(LLVMType type) { +bool LLVMFunctionType::isValidArgumentType(Type type) { return !type.isa(); } -bool LLVMFunctionType::isValidResultType(LLVMType type) { +bool LLVMFunctionType::isValidResultType(Type type) { return !type.isa(); } -LLVMFunctionType LLVMFunctionType::get(LLVMType result, - ArrayRef arguments, +LLVMFunctionType LLVMFunctionType::get(Type result, ArrayRef arguments, bool isVarArg) { assert(result && "expected non-null result"); return Base::get(result.getContext(), result, arguments, isVarArg); } -LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result, - ArrayRef arguments, +LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result, + ArrayRef arguments, bool isVarArg) { assert(result && "expected non-null result"); return Base::getChecked(loc, result, arguments, isVarArg); } -LLVMType LLVMFunctionType::getReturnType() { - return getImpl()->getReturnType(); -} +Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); } unsigned LLVMFunctionType::getNumParams() { return getImpl()->getArgumentTypes().size(); } -LLVMType LLVMFunctionType::getParamType(unsigned i) { +Type LLVMFunctionType::getParamType(unsigned i) { return getImpl()->getArgumentTypes()[i]; } bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); } -ArrayRef LLVMFunctionType::getParams() { +ArrayRef LLVMFunctionType::getParams() { return getImpl()->getArgumentTypes(); } -LogicalResult LLVMFunctionType::verifyConstructionInvariants( - Location loc, LLVMType result, ArrayRef arguments, bool) { +LogicalResult +LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result, + ArrayRef arguments, bool) { if (!isValidResultType(result)) return emitError(loc, "invalid function result type: ") << result; - for (LLVMType arg : arguments) + for (Type arg : arguments) if (!isValidArgumentType(arg)) return emitError(loc, "invalid function argument type: ") << arg; @@ -150,27 +136,27 @@ // Pointer type. //===----------------------------------------------------------------------===// -bool LLVMPointerType::isValidElementType(LLVMType type) { +bool LLVMPointerType::isValidElementType(Type type) { return !type.isa(); } -LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) { +LLVMPointerType LLVMPointerType::get(Type pointee, unsigned addressSpace) { assert(pointee && "expected non-null subtype"); return Base::get(pointee.getContext(), pointee, addressSpace); } -LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee, +LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee, unsigned addressSpace) { return Base::getChecked(loc, pointee, addressSpace); } -LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; } +Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; } unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; } LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc, - LLVMType pointee, + Type pointee, unsigned) { if (!isValidElementType(pointee)) return emitError(loc, "invalid pointer element type: ") << pointee; @@ -181,7 +167,7 @@ // Struct type. //===----------------------------------------------------------------------===// -bool LLVMStructType::isValidElementType(LLVMType type) { +bool LLVMStructType::isValidElementType(Type type) { return !type.isa(); } @@ -198,7 +184,7 @@ LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context, StringRef name, - ArrayRef elements, + ArrayRef elements, bool isPacked) { std::string stringName = name.str(); unsigned counter = 0; @@ -214,13 +200,12 @@ } LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, - ArrayRef types, - bool isPacked) { + ArrayRef types, bool isPacked) { return Base::get(context, types, isPacked); } LLVMStructType LLVMStructType::getLiteralChecked(Location loc, - ArrayRef types, + ArrayRef types, bool isPacked) { return Base::getChecked(loc, types, isPacked); } @@ -233,7 +218,7 @@ return Base::getChecked(loc, name, /*opaque=*/true); } -LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { +LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { assert(isIdentified() && "can only set bodies of identified structs"); assert(llvm::all_of(types, LLVMStructType::isValidElementType) && "expected valid body types"); @@ -248,7 +233,7 @@ } bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } -ArrayRef LLVMStructType::getBody() { +ArrayRef LLVMStructType::getBody() { return isIdentified() ? getImpl()->getIdentifiedStructBody() : getImpl()->getTypeList(); } @@ -258,10 +243,10 @@ return success(); } -LogicalResult -LLVMStructType::verifyConstructionInvariants(Location loc, - ArrayRef types, bool) { - for (LLVMType t : types) +LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc, + ArrayRef types, + bool) { + for (Type t : types) if (!isValidElementType(t)) return emitError(loc, "invalid LLVM structure element type: ") << t; @@ -272,7 +257,7 @@ // Vector types. //===----------------------------------------------------------------------===// -bool LLVMVectorType::isValidElementType(LLVMType type) { +bool LLVMVectorType::isValidElementType(Type type) { return type.isa() || mlir::LLVM::isCompatibleFloatingPointType(type); } @@ -282,7 +267,7 @@ return type.isa(); } -LLVMType LLVMVectorType::getElementType() { +Type LLVMVectorType::getElementType() { // Both derived classes share the implementation type. return static_cast(impl)->elementType; } @@ -296,7 +281,7 @@ /// Verifies that the type about to be constructed is well-formed. LogicalResult -LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType, +LLVMVectorType::verifyConstructionInvariants(Location loc, Type elementType, unsigned numElements) { if (numElements == 0) return emitError(loc, "the number of vector elements must be positive"); @@ -307,14 +292,14 @@ return success(); } -LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType, +LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); return Base::get(elementType.getContext(), elementType, numElements); } LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc, - LLVMType elementType, + Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); return Base::getChecked(loc, elementType, numElements); @@ -324,14 +309,14 @@ return getImpl()->numElements; } -LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType, +LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); return Base::get(elementType.getContext(), elementType, minNumElements); } LLVMScalableVectorType -LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType, +LLVMScalableVectorType::getChecked(Location loc, Type elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); return Base::getChecked(loc, elementType, minNumElements); @@ -351,16 +336,16 @@ return llvm::TypeSwitch(type) .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(16); }) - .Case([](LLVMType) { return llvm::TypeSize::Fixed(32); }) + [](Type) { return llvm::TypeSize::Fixed(16); }) + .Case([](Type) { return llvm::TypeSize::Fixed(32); }) .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(64); }) + [](Type) { return llvm::TypeSize::Fixed(64); }) .Case([](LLVMIntegerType intTy) { return llvm::TypeSize::Fixed(intTy.getBitWidth()); }) - .Case([](LLVMType) { return llvm::TypeSize::Fixed(80); }) + .Case([](Type) { return llvm::TypeSize::Fixed(80); }) .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(128); }) + [](Type) { return llvm::TypeSize::Fixed(128); }) .Case([](LLVMVectorType t) { llvm::TypeSize elementSize = getPrimitiveTypeSizeInBits(t.getElementType()); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -53,19 +53,18 @@ parser.addTypeToList(resultType, result.types)) return failure(); - auto type = resultType.cast(); for (auto &attr : result.attributes) { if (attr.first != "return_value_and_is_valid") continue; - auto structType = type.dyn_cast(); + auto structType = resultType.dyn_cast(); if (structType && !structType.getBody().empty()) - type = structType.getBody()[0]; + resultType = structType.getBody()[0]; break; } auto int32Ty = LLVM::LLVMIntegerType::get(parser.getBuilder().getContext(), 32); - return parser.resolveOperands(ops, {int32Ty, type, int32Ty, int32Ty}, + return parser.resolveOperands(ops, {int32Ty, resultType, int32Ty, int32Ty}, parser.getNameLoc(), result.operands); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h --- a/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h +++ b/mlir/lib/Dialect/LLVMIR/IR/TypeDetail.h @@ -72,7 +72,7 @@ Key(StringRef name, bool opaque) : name(name), identified(true), packed(false), opaque(opaque) {} /// Constructs a key for a literal struct. - Key(ArrayRef types, bool packed) + Key(ArrayRef types, bool packed) : types(types), identified(false), packed(packed), opaque(false) {} /// Checks a specific property of the struct. @@ -96,7 +96,7 @@ } /// Returns the list of type contained in the key of a literal struct. - ArrayRef getTypeList() const { + ArrayRef getTypeList() const { assert(!isIdentified() && "identified struct key cannot have a type list"); return types; @@ -138,7 +138,7 @@ } private: - ArrayRef types; + ArrayRef types; StringRef name; bool identified; bool packed; @@ -153,18 +153,18 @@ } /// Returns the list of types (partially) identifying a literal struct. - ArrayRef getTypeList() const { + ArrayRef getTypeList() const { // If this triggers, use getIdentifiedStructBody() instead. assert(!isIdentified() && "requested typelist on an identified struct"); - return ArrayRef(static_cast(keyPtr), keySize()); + return ArrayRef(static_cast(keyPtr), keySize()); } /// Returns the list of types contained in an identified struct. - ArrayRef getIdentifiedStructBody() const { + ArrayRef getIdentifiedStructBody() const { // If this triggers, use getTypeList() instead. assert(isIdentified() && "requested struct body on a non-identified struct"); - return ArrayRef(identifiedBodyArray, identifiedBodySize()); + return ArrayRef(identifiedBodyArray, identifiedBodySize()); } /// Checks whether the struct is identified. @@ -199,7 +199,7 @@ /// as initialized and can no longer be mutated. LLVMStructTypeStorage(const KeyTy &key) { if (!key.isIdentified()) { - ArrayRef types = key.getTypeList(); + ArrayRef types = key.getTypeList(); keyPtr = static_cast(types.data()); setKeySize(types.size()); llvm::Bitfield::set(keySizeAndFlags, key.isPacked()); @@ -232,7 +232,7 @@ /// initialized, succeeds only if the body is equal to the current body. Fails /// if the struct is marked as intentionally opaque. The struct will be marked /// as initialized as a result of this operation and can no longer be changed. - LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef body, + LogicalResult mutate(TypeStorageAllocator &allocator, ArrayRef body, bool packed) { if (!isIdentified()) return failure(); @@ -244,7 +244,7 @@ true); llvm::Bitfield::set(identifiedBodySizeAndFlags, packed); - ArrayRef typesInAllocator = allocator.copyInto(body); + ArrayRef typesInAllocator = allocator.copyInto(body); identifiedBodyArray = typesInAllocator.data(); setIdentifiedBodySize(typesInAllocator.size()); @@ -310,7 +310,7 @@ const void *keyPtr = nullptr; /// Pointer to the first type contained in an identified struct. - const LLVMType *identifiedBodyArray = nullptr; + const Type *identifiedBodyArray = nullptr; /// Size of the uniquing key combined with identified/literal and /// packedness bits. Must only be used through the Key* bitfields. @@ -328,12 +328,11 @@ /// Type storage for LLVM dialect function types. These are uniqued using the /// list of types they contain and the vararg bit. struct LLVMFunctionTypeStorage : public TypeStorage { - using KeyTy = std::tuple, bool>; + using KeyTy = std::tuple, bool>; /// Construct a storage from the given components. The list is expected to be /// allocated in the context. - LLVMFunctionTypeStorage(LLVMType result, ArrayRef arguments, - bool variadic) + LLVMFunctionTypeStorage(Type result, ArrayRef arguments, bool variadic) : argumentTypes(arguments) { returnTypeAndVariadic.setPointerAndInt(result, variadic); } @@ -359,19 +358,19 @@ } /// Returns the list of function argument types. - ArrayRef getArgumentTypes() const { return argumentTypes; } + ArrayRef getArgumentTypes() const { return argumentTypes; } /// Checks whether the function type is variadic. bool isVariadic() const { return returnTypeAndVariadic.getInt(); } /// Returns the function result type. - LLVMType getReturnType() const { return returnTypeAndVariadic.getPointer(); } + Type getReturnType() const { return returnTypeAndVariadic.getPointer(); } private: /// Function result type packed with the variadic bit. - llvm::PointerIntPair returnTypeAndVariadic; + llvm::PointerIntPair returnTypeAndVariadic; /// Argument types. - ArrayRef argumentTypes; + ArrayRef argumentTypes; }; //===----------------------------------------------------------------------===// @@ -402,7 +401,7 @@ /// Storage type for LLVM dialect pointer types. These are uniqued by a pair of /// element type and address space. struct LLVMPointerTypeStorage : public TypeStorage { - using KeyTy = std::tuple; + using KeyTy = std::tuple; LLVMPointerTypeStorage(const KeyTy &key) : pointeeType(std::get<0>(key)), addressSpace(std::get<1>(key)) {} @@ -417,7 +416,7 @@ return std::make_tuple(pointeeType, addressSpace) == key; } - LLVMType pointeeType; + Type pointeeType; unsigned addressSpace; }; @@ -429,7 +428,7 @@ /// number: arrays, fixed and scalable vectors. The actual semantics of the /// type is defined by its kind. struct LLVMTypeAndSizeStorage : public TypeStorage { - using KeyTy = std::tuple; + using KeyTy = std::tuple; LLVMTypeAndSizeStorage(const KeyTy &key) : elementType(std::get<0>(key)), numElements(std::get<1>(key)) {} @@ -444,7 +443,7 @@ return std::make_tuple(elementType, numElements) == key; } - LLVMType elementType; + Type elementType; unsigned numElements; }; diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -68,8 +68,8 @@ LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); /// Imports `inst` and populates instMap[inst] with the imported Value. LogicalResult processInstruction(llvm::Instruction *inst); - /// Creates an LLVMType for `type`. - LLVMType processType(llvm::Type *type); + /// Creates an LLVM-compatible MLIR type for `type`. + Type processType(llvm::Type *type); /// `value` is an SSA-use. Return the remapped version of `value` or a /// placeholder that will be remapped later if this is an instruction that /// has not yet been visited. @@ -87,7 +87,7 @@ SmallVectorImpl &blockArguments); /// Returns the builtin type equivalent to be used in attributes for the given /// LLVM IR dialect type. - Type getStdTypeForAttr(LLVMType type); + Type getStdTypeForAttr(Type type); /// Return `value` as an attribute to attach to a GlobalOp. Attribute getConstantAsAttr(llvm::Constant *value); /// Return `c` as an MLIR Value. This could either be a ConstantOp, or @@ -150,8 +150,8 @@ context); } -LLVMType Importer::processType(llvm::Type *type) { - if (LLVMType result = typeTranslator.translateType(type)) +Type Importer::processType(llvm::Type *type) { + if (Type result = typeTranslator.translateType(type)) return result; // FIXME: Diagnostic should be able to natively handle types that have @@ -168,7 +168,7 @@ // equivalents. Array types are converted to ranked tensors; nested array types // are converted to multi-dimensional tensors or vectors, depending on the // innermost type being a scalar or a vector. -Type Importer::getStdTypeForAttr(LLVMType type) { +Type Importer::getStdTypeForAttr(Type type) { if (!type) return nullptr; @@ -252,7 +252,7 @@ // Convert constant data to a dense elements attribute. if (auto *cd = dyn_cast(value)) { - LLVMType type = processType(cd->getElementType()); + Type type = processType(cd->getElementType()); if (!type) return nullptr; @@ -315,7 +315,7 @@ Attribute valueAttr; if (GV->hasInitializer()) valueAttr = getConstantAsAttr(GV->getInitializer()); - LLVMType type = processType(GV->getValueType()); + Type type = processType(GV->getValueType()); if (!type) return nullptr; GlobalOp op = b.create( @@ -338,7 +338,7 @@ if (Attribute attr = getConstantAsAttr(c)) { // These constants can be represented as attributes. OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); - LLVMType type = processType(c->getType()); + Type type = processType(c->getType()); if (!type) return nullptr; if (auto symbolRef = attr.dyn_cast()) @@ -347,7 +347,7 @@ return instMap[c] = bEntry.create(unknownLoc, type, attr); } if (auto *cn = dyn_cast(c)) { - LLVMType type = processType(cn->getType()); + Type type = processType(cn->getType()); if (!type) return nullptr; return instMap[c] = bEntry.create(unknownLoc, type); @@ -370,7 +370,7 @@ return instMap[c] = instMap[i]; } if (auto *ue = dyn_cast(c)) { - LLVMType type = processType(ue->getType()); + Type type = processType(ue->getType()); if (!type) return nullptr; return instMap[c] = bEntry.create(UnknownLoc::get(context), type); @@ -388,7 +388,7 @@ // this instruction yet, create an unknown op and remap it later. if (isa(value)) { OperationState state(UnknownLoc::get(context), "llvm.unknown"); - LLVMType type = processType(value->getType()); + Type type = processType(value->getType()); if (!type) return nullptr; state.addTypes(type); @@ -578,7 +578,7 @@ } state.addOperands(ops); if (!inst->getType()->isVoidTy()) { - LLVMType type = processType(inst->getType()); + Type type = processType(inst->getType()); if (!type) return failure(); state.addTypes(type); @@ -629,7 +629,7 @@ return success(); } case llvm::Instruction::PHI: { - LLVMType type = processType(inst->getType()); + Type type = processType(inst->getType()); if (!type) return failure(); v = b.getInsertionBlock()->addArgument(type); @@ -648,7 +648,7 @@ SmallVector tys; if (!ci->getType()->isVoidTy()) { - LLVMType type = processType(inst->getType()); + Type type = processType(inst->getType()); if (!type) return failure(); tys.push_back(type); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -762,17 +762,17 @@ if (auto inlineAsmOp = dyn_cast(opInst)) { // TODO: refactor function type creation which usually occurs in std-LLVM // conversion. - SmallVector operandTypes; + SmallVector operandTypes; operandTypes.reserve(inlineAsmOp.operands().size()); for (auto t : inlineAsmOp.operands().getTypes()) - operandTypes.push_back(t.cast()); + operandTypes.push_back(t); - LLVM::LLVMType resultType; + Type resultType; if (inlineAsmOp.getNumResults() == 0) { resultType = LLVM::LLVMVoidType::get(mlirModule->getContext()); } else { assert(inlineAsmOp.getNumResults() == 1); - resultType = inlineAsmOp.getResultTypes()[0].cast(); + resultType = inlineAsmOp.getResultTypes()[0]; } auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes); llvm::InlineAsm *inlineAsmInst = @@ -813,7 +813,7 @@ } if (auto lpOp = dyn_cast(opInst)) { - llvm::Type *ty = convertType(lpOp.getType().cast()); + llvm::Type *ty = convertType(lpOp.getType()); llvm::LandingPadInst *lpi = builder.CreateLandingPad(ty, lpOp.getNumOperands()); @@ -872,8 +872,8 @@ blockMapping[switchOp.defaultDestination()], switchOp.caseDestinations().size(), branchWeights); - auto *ty = llvm::cast( - convertType(switchOp.value().getType().cast())); + auto *ty = + llvm::cast(convertType(switchOp.value().getType())); for (auto i : llvm::zip(switchOp.case_values()->cast(), switchOp.caseDestinations())) @@ -927,8 +927,8 @@ unsigned numPredecessors = std::distance(predecessors.begin(), predecessors.end()); for (auto arg : bb.getArguments()) { - auto wrappedType = arg.getType().dyn_cast(); - if (!wrappedType) + auto wrappedType = arg.getType(); + if (!isCompatibleType(wrappedType)) return emitError(bb.front().getLoc(), "block argument does not have an LLVM type"); llvm::Type *type = convertType(wrappedType); @@ -1094,7 +1094,7 @@ argIdx, LLVMDialect::getNoAliasAttrName())) { // NB: Attribute already verified to be boolean, so check if we can indeed // attach the attribute to this argument, based on its type. - auto argTy = mlirArg.getType().dyn_cast(); + auto argTy = mlirArg.getType(); if (!argTy.isa()) return func.emitError( "llvm.noalias attribute attached to LLVM non-pointer argument"); @@ -1106,7 +1106,7 @@ argIdx, LLVMDialect::getAlignAttrName())) { // NB: Attribute already verified to be int, so check if we can indeed // attach the attribute to this argument, based on its type. - auto argTy = mlirArg.getType().dyn_cast(); + auto argTy = mlirArg.getType(); if (!argTy.isa()) return func.emitError( "llvm.align attribute attached to LLVM non-pointer argument"); @@ -1190,7 +1190,7 @@ return success(); } -llvm::Type *ModuleTranslation::convertType(LLVMType type) { +llvm::Type *ModuleTranslation::convertType(Type type) { return typeTranslator.translateType(type); } diff --git a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp --- a/mlir/lib/Target/LLVMIR/TypeTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/TypeTranslation.cpp @@ -27,14 +27,14 @@ TypeToLLVMIRTranslatorImpl(llvm::LLVMContext &context) : context(context) {} /// Translates a single type. - llvm::Type *translateType(LLVM::LLVMType type) { + llvm::Type *translateType(Type type) { // If the conversion is already known, just return it. if (knownTranslations.count(type)) return knownTranslations.lookup(type); // Dispatch to an appropriate function. llvm::Type *translated = - llvm::TypeSwitch(type) + llvm::TypeSwitch(type) .Case([this](LLVM::LLVMVoidType) { return llvm::Type::getVoidTy(context); }) @@ -76,7 +76,7 @@ LLVM::LLVMStructType, LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>( [this](auto type) { return this->translate(type); }) - .Default([](LLVM::LLVMType t) -> llvm::Type * { + .Default([](Type t) -> llvm::Type * { llvm_unreachable("unknown LLVM dialect type"); }); @@ -147,7 +147,7 @@ } /// Translates a list of types. - void translateTypes(ArrayRef types, + void translateTypes(ArrayRef types, SmallVectorImpl &result) { result.reserve(result.size() + types.size()); for (auto type : types) @@ -161,7 +161,7 @@ /// results to avoid repeated recursive calls and makes sure identified /// structs with the same name (that is, equal) are resolved to an existing /// type instead of creating a new type. - llvm::DenseMap knownTranslations; + llvm::DenseMap knownTranslations; }; } // end namespace detail } // end namespace LLVM @@ -172,12 +172,12 @@ LLVM::TypeToLLVMIRTranslator::~TypeToLLVMIRTranslator() {} -llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(LLVM::LLVMType type) { +llvm::Type *LLVM::TypeToLLVMIRTranslator::translateType(Type type) { return impl->translateType(type); } unsigned LLVM::TypeToLLVMIRTranslator::getPreferredAlignment( - LLVM::LLVMType type, const llvm::DataLayout &layout) { + Type type, const llvm::DataLayout &layout) { return layout.getPrefTypeAlignment(translateType(type)); } @@ -191,12 +191,12 @@ TypeFromLLVMIRTranslatorImpl(MLIRContext &context) : context(context) {} /// Translates the given type. - LLVM::LLVMType translateType(llvm::Type *type) { + Type translateType(llvm::Type *type) { if (knownTranslations.count(type)) return knownTranslations.lookup(type); - LLVM::LLVMType translated = - llvm::TypeSwitch(type) + Type translated = + llvm::TypeSwitch(type) .Case( @@ -211,7 +211,7 @@ private: /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature, /// type. - LLVM::LLVMType translatePrimitiveType(llvm::Type *type) { + Type translatePrimitiveType(llvm::Type *type) { if (type->isVoidTy()) return LLVM::LLVMVoidType::get(&context); if (type->isHalfTy()) @@ -238,33 +238,33 @@ } /// Translates the given array type. - LLVM::LLVMType translate(llvm::ArrayType *type) { + Type translate(llvm::ArrayType *type) { return LLVM::LLVMArrayType::get(translateType(type->getElementType()), type->getNumElements()); } /// Translates the given function type. - LLVM::LLVMType translate(llvm::FunctionType *type) { - SmallVector paramTypes; + Type translate(llvm::FunctionType *type) { + SmallVector paramTypes; translateTypes(type->params(), paramTypes); return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()), paramTypes, type->isVarArg()); } /// Translates the given integer type. - LLVM::LLVMType translate(llvm::IntegerType *type) { + Type translate(llvm::IntegerType *type) { return LLVM::LLVMIntegerType::get(&context, type->getBitWidth()); } /// Translates the given pointer type. - LLVM::LLVMType translate(llvm::PointerType *type) { + Type translate(llvm::PointerType *type) { return LLVM::LLVMPointerType::get(translateType(type->getElementType()), type->getAddressSpace()); } /// Translates the given structure type. - LLVM::LLVMType translate(llvm::StructType *type) { - SmallVector subtypes; + Type translate(llvm::StructType *type) { + SmallVector subtypes; if (type->isLiteral()) { translateTypes(type->subtypes(), subtypes); return LLVM::LLVMStructType::getLiteral(&context, subtypes, @@ -286,20 +286,20 @@ } /// Translates the given fixed-vector type. - LLVM::LLVMType translate(llvm::FixedVectorType *type) { + Type translate(llvm::FixedVectorType *type) { return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()), type->getNumElements()); } /// Translates the given scalable-vector type. - LLVM::LLVMType translate(llvm::ScalableVectorType *type) { + Type translate(llvm::ScalableVectorType *type) { return LLVM::LLVMScalableVectorType::get( translateType(type->getElementType()), type->getMinNumElements()); } /// Translates a list of types. void translateTypes(ArrayRef types, - SmallVectorImpl &result) { + SmallVectorImpl &result) { result.reserve(result.size() + types.size()); for (llvm::Type *type : types) result.push_back(translateType(type)); @@ -307,7 +307,7 @@ /// Map of known translations. Serves as a cache and as recursion stopper for /// translating recursive structs. - llvm::DenseMap knownTranslations; + llvm::DenseMap knownTranslations; /// The context in which MLIR types are created. MLIRContext &context; @@ -321,6 +321,6 @@ LLVM::TypeFromLLVMIRTranslator::~TypeFromLLVMIRTranslator() {} -LLVM::LLVMType LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { +Type LLVM::TypeFromLLVMIRTranslator::translateType(llvm::Type *type) { return impl->translateType(type); } diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -163,7 +163,7 @@ // ----- func @call_non_llvm_indirect(%arg0 : i32) { - // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}} + // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect-compatible type, but got 'i32'}} "llvm.call"(%arg0) : (i32) -> () } diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -135,7 +135,7 @@ } else if (isResultName(op, name)) { bs << formatv("valueMapping[op.{0}()]", name); } else if (name == "_resultType") { - bs << "convertType(op.getResult().getType().cast())"; + bs << "convertType(op.getResult().getType())"; } else if (name == "_hasResult") { bs << "opInst.getNumResults() == 1"; } else if (name == "_location") {