diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -446,7 +446,8 @@ /// Builds IR extracting the pointer to the first element of the size array. static Value sizeBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + Value memRefDescPtr, + LLVM::LLVMPointerType elemPtrPtrType); /// Builds IR extracting the size[index] from the descriptor. static Value size(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -51,7 +51,7 @@ [{ auto llvmType = resultType.dyn_cast(); (void)llvmType; assert(llvmType && "result must be an LLVM type"); - assert(llvmType.isVoidTy() && + assert(llvmType.isa() && "for zero-result operands, only 'void' is accepted as result type"); build($_builder, $_state, operands, attributes); }]>; @@ -288,7 +288,7 @@ OpBuilderDAG<(ins "Value":$addr, CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile, CArg<"bool", "false">:$isNonTemporal), [{ - auto type = addr.getType().cast().getPointerElementTy(); + auto type = addr.getType().cast().getElementType(); build($_builder, $_state, type, addr, alignment, isVolatile, isNonTemporal); }]>, OpBuilderDAG<(ins "Type":$t, "Value":$addr, @@ -443,8 +443,8 @@ OpBuilderDAG<(ins "LLVMFuncOp":$func, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attributes), [{ - LLVMType resultType = func.getType().getFunctionResultType(); - if (!resultType.isVoidTy()) + LLVMType resultType = func.getType().getReturnType(); + if (!resultType.isa()) $_state.addTypes(resultType); $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func)); $_state.addAttributes(attributes); @@ -515,12 +515,10 @@ OpBuilderDAG<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask, CArg<"ArrayRef", "{}">:$attrs)>]; let verifier = [{ - auto wrappedVectorType1 = v1().getType().cast(); - auto wrappedVectorType2 = v2().getType().cast(); - if (!wrappedVectorType2.isVectorTy()) - return emitOpError("expected LLVM IR Dialect vector type for operand #2"); - if (wrappedVectorType1.getVectorElementType() != - wrappedVectorType2.getVectorElementType()) + auto wrappedVectorType1 = v1().getType().cast(); + auto wrappedVectorType2 = v2().getType().cast(); + if (wrappedVectorType1.getElementType() != + wrappedVectorType2.getElementType()) return emitOpError("expected matching LLVM IR Dialect element types"); return success(); }]; @@ -768,13 +766,13 @@ CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, - global.getType().getPointerTo(global.addr_space()), + LLVM::LLVMPointerType::get(global.getType(), global.addr_space()), global.sym_name(), attrs);}]>, OpBuilderDAG<(ins "LLVMFuncOp":$func, CArg<"ArrayRef", "{}">:$attrs), [{ build($_builder, $_state, - func.getType().getPointerTo(), func.getName(), attrs);}]> + LLVM::LLVMPointerType::get(func.getType()), func.getName(), attrs);}]> ]; let extraClassDeclaration = [{ @@ -970,12 +968,12 @@ // to match the signature of the function. Block *addEntryBlock(); - LLVMType getType() { + LLVMFunctionType getType() { return (*this)->getAttrOfType(getTypeAttrName()) - .getValue().cast(); + .getValue().cast(); } bool isVarArg() { - return getType().isFunctionVarArg(); + return getType().isVarArg(); } // Hook for OpTrait::FunctionLike, returns the number of function arguments`. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -80,58 +80,6 @@ LLVMDialect &getDialect(); - /// Returns the size of a primitive type (including vectors) in bits, for - /// example, the size of !llvm.i16 is 16 and the size of !llvm.vec<4 x i16> - /// is 64. Returns 0 for non-primitive (aggregates such as struct) or types - /// that don't have a size (such as void). - llvm::TypeSize getPrimitiveSizeInBits(); - - /// Floating-point type utilities. - bool isBFloatTy() { return isa(); } - bool isHalfTy() { return isa(); } - bool isFloatTy() { return isa(); } - bool isDoubleTy() { return isa(); } - bool isFP128Ty() { return isa(); } - bool isX86_FP80Ty() { return isa(); } - bool isFloatingPointTy() { - return isa() || isa() || - isa() || isa() || - isa() || isa(); - } - - /// Array type utilities. - LLVMType getArrayElementType(); - unsigned getArrayNumElements(); - bool isArrayTy(); - - /// Integer type utilities. - bool isIntegerTy() { return isa(); } - bool isIntegerTy(unsigned bitwidth); - unsigned getIntegerBitWidth(); - - /// Vector type utilities. - LLVMType getVectorElementType(); - unsigned getVectorNumElements(); - llvm::ElementCount getVectorElementCount(); - bool isVectorTy(); - - /// Function type utilities. - LLVMType getFunctionParamType(unsigned argIdx); - unsigned getFunctionNumParams(); - LLVMType getFunctionResultType(); - bool isFunctionTy(); - bool isFunctionVarArg(); - - /// Pointer type utilities. - LLVMType getPointerTo(unsigned addrSpace = 0); - LLVMType getPointerElementTy(); - bool isPointerTy(); - - /// Struct type utilities. - LLVMType getStructElementType(unsigned i); - unsigned getStructNumElements(); - bool isStructTy(); - /// Utilities used to generate floating point types. static LLVMType getDoubleTy(MLIRContext *context); static LLVMType getFloatTy(MLIRContext *context); @@ -148,9 +96,7 @@ static LLVMType getInt8Ty(MLIRContext *context) { return getIntNTy(context, /*numBits=*/8); } - static LLVMType getInt8PtrTy(MLIRContext *context) { - return getInt8Ty(context).getPointerTo(); - } + static LLVMType getInt8PtrTy(MLIRContext *context); static LLVMType getInt16Ty(MLIRContext *context) { return getIntNTy(context, /*numBits=*/16); } @@ -184,7 +130,6 @@ /// Void type utilities. static LLVMType getVoidTy(MLIRContext *context); - bool isVoidTy(); // Creation and setting of LLVM's identified struct types static LLVMType createStructTy(MLIRContext *context, @@ -585,6 +530,24 @@ void printType(LLVMType type, DialectAsmPrinter &printer); } // namespace detail +//===----------------------------------------------------------------------===// +// Utility functions. +//===----------------------------------------------------------------------===// + +/// Returns `true` if the given type is compatible with the LLVM dialect. +inline bool isCompatibleType(Type type) { return type.isa(); } + +inline bool isCompatibleFloatingPointType(Type type) { + return type.isa(); +} + +/// Returns the size of the given primitive LLVM dialect-compatible type +/// (including vectors) in bits, for example, the size of !llvm.i16 is 16 and +/// the size of !llvm.vec<4 x i16> is 64. Returns 0 for non-primitive +/// (aggregates such as struct) or types that don't have a size (such as void). +llvm::TypeSize getPrimitiveTypeSizeInBits(Type type); + } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -109,10 +109,11 @@ let verifier = [{ if (!(*this)->getAttrOfType("return_value_and_is_valid")) return success(); - auto type = getType().cast(); - if (!type.isStructTy() || type.getStructNumElements() != 2 || - !type.getStructElementType(1).isIntegerTy( - /*Bitwidth=*/1)) + auto type = getType().dyn_cast(); + auto elementType = (type && type.getBody().size() == 2) + ? type.getBody()[1].dyn_cast() + : nullptr; + if (!elementType || elementType.getBitWidth() != 1) return emitError("expected return type to be a two-element struct with " "i1 as the second element"); return success(); diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -79,7 +79,7 @@ static FunctionType executeFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); - auto resume = resumeFunctionType(ctx).getPointerTo(); + auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {hdl, resume}, {}); } @@ -91,13 +91,13 @@ static FunctionType awaitAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); - auto resume = resumeFunctionType(ctx).getPointerTo(); + auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {TokenType::get(ctx), hdl, resume}, {}); } static FunctionType awaitAllAndExecuteFunctionType(MLIRContext *ctx) { auto hdl = LLVM::LLVMType::getInt8PtrTy(ctx); - auto resume = resumeFunctionType(ctx).getPointerTo(); + auto resume = LLVM::LLVMPointerType::get(resumeFunctionType(ctx)); return FunctionType::get(ctx, {GroupType::get(ctx), hdl, resume}, {}); } @@ -507,7 +507,7 @@ // A pointer to coroutine resume intrinsic wrapper. auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); auto resumePtr = builder.create( - loc, resumeFnTy.getPointerTo(), kResume); + loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( @@ -750,7 +750,7 @@ // A pointer to coroutine resume intrinsic wrapper. auto resumeFnTy = AsyncAPI::resumeFunctionType(ctx); auto resumePtr = builder.create( - loc, resumeFnTy.getPointerTo(), kResume); + loc, LLVM::LLVMPointerType::get(resumeFnTy), kResume); // Save the coroutine state: @llvm.coro.save auto coroSave = builder.create( diff --git a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp --- a/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp @@ -55,14 +55,14 @@ FunctionCallBuilder(StringRef functionName, LLVM::LLVMType returnType, ArrayRef argumentTypes) : functionName(functionName), - functionType(LLVM::LLVMType::getFunctionTy(returnType, argumentTypes, - /*isVarArg=*/false)) {} + functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes, + /*isVarArg=*/false)) {} LLVM::CallOp create(Location loc, OpBuilder &builder, ArrayRef arguments) const; private: StringRef functionName; - LLVM::LLVMType functionType; + LLVM::LLVMFunctionType functionType; }; template @@ -76,7 +76,8 @@ LLVM::LLVMType llvmVoidType = LLVM::LLVMType::getVoidTy(context); LLVM::LLVMType llvmPointerType = LLVM::LLVMType::getInt8PtrTy(context); - LLVM::LLVMType llvmPointerPointerType = llvmPointerType.getPointerTo(); + LLVM::LLVMType llvmPointerPointerType = + LLVM::LLVMPointerType::get(llvmPointerType); LLVM::LLVMType llvmInt8Type = LLVM::LLVMType::getInt8Ty(context); LLVM::LLVMType llvmInt32Type = LLVM::LLVMType::getInt32Ty(context); LLVM::LLVMType llvmInt64Type = LLVM::LLVMType::getInt64Ty(context); @@ -292,7 +293,7 @@ .create(loc, functionName, functionType); }(); return builder.create( - loc, const_cast(functionType).getFunctionResultType(), + loc, const_cast(functionType).getReturnType(), builder.getSymbolRefAttr(function), arguments); } @@ -498,7 +499,7 @@ auto one = builder.create(loc, llvmInt32Type, builder.getI32IntegerAttr(1)); auto structPtr = builder.create( - loc, structType.getPointerTo(), one, /*alignment=*/0); + loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0); auto arraySize = builder.create( loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments)); auto arrayPtr = builder.create(loc, llvmPointerPointerType, @@ -509,7 +510,7 @@ auto index = builder.create( loc, llvmInt32Type, builder.getI32IntegerAttr(en.index())); auto fieldPtr = builder.create( - loc, argumentTypes[en.index()].getPointerTo(), structPtr, + loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr, ArrayRef{zero, index.getResult()}); builder.create(loc, en.value(), fieldPtr); auto elementPtr = builder.create(loc, llvmPointerPointerType, diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -51,8 +51,8 @@ // Rewrite the original GPU function to an LLVM function. auto funcType = typeConverter->convertType(gpuFuncOp.getType()) - .template cast() - .getPointerElementTy(); + .template cast() + .getElementType(); // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( @@ -94,10 +94,11 @@ for (auto en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); Value address = rewriter.create(loc, global); - auto elementType = global.getType().getArrayElementType(); + auto elementType = + global.getType().cast().getElementType(); Value memory = rewriter.create( - loc, elementType.getPointerTo(global.addr_space()), address, - ArrayRef{zero, zero}); + loc, LLVM::LLVMPointerType::get(elementType, global.addr_space()), + address, ArrayRef{zero, zero}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than @@ -123,9 +124,10 @@ // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). - auto ptrType = typeConverter->convertType(type.getElementType()) - .template cast() - .getPointerTo(AllocaAddrSpace); + auto ptrType = LLVM::LLVMPointerType::get( + typeConverter->convertType(type.getElementType()) + .template cast(), + AllocaAddrSpace); Value numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, rewriter.getI64IntegerAttr(type.getNumElements())); diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -57,7 +57,8 @@ LLVMType resultType = castedOperands.front().getType().cast(); LLVMType funcType = getFunctionType(resultType, castedOperands); - StringRef funcName = getFunctionName(funcType.getFunctionResultType()); + StringRef funcName = getFunctionName( + funcType.cast().getReturnType()); if (funcName.empty()) return failure(); @@ -80,7 +81,7 @@ private: Value maybeCast(Value operand, PatternRewriter &rewriter) const { LLVM::LLVMType type = operand.getType().cast(); - if (!type.isHalfTy()) + if (!type.isa()) return operand; return rewriter.create( @@ -100,9 +101,9 @@ } StringRef getFunctionName(LLVM::LLVMType type) const { - if (type.isFloatTy()) + if (type.isa()) return f32Func; - if (type.isDoubleTy()) + if (type.isa()) return f64Func; return ""; } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -75,7 +75,7 @@ // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; - auto llvmPtrToElementType = elemenType.getPointerTo(); + auto llvmPtrToElementType = LLVM::LLVMPointerType::get(elemenType); auto llvmArrayRankElementSizeType = LLVM::LLVMType::getArrayTy(getInt64Type(), rank); @@ -131,16 +131,18 @@ /// Returns a string representation from the given `type`. StringRef stringifyType(LLVM::LLVMType type) { - if (type.isFloatTy()) + if (type.isa()) return "Float"; - if (type.isHalfTy()) + if (type.isa()) return "Half"; - if (type.isIntegerTy(32)) - return "Int32"; - if (type.isIntegerTy(16)) - return "Int16"; - if (type.isIntegerTy(8)) - return "Int8"; + if (auto intType = type.dyn_cast()) { + if (intType.getBitWidth() == 32) + return "Int32"; + if (intType.getBitWidth() == 16) + return "Int16"; + if (intType.getBitWidth() == 8) + return "Int8"; + } llvm_unreachable("unsupported type"); } @@ -238,11 +240,11 @@ llvm::formatv("bindMemRef{0}D{1}", rank, stringifyType(type)).str(); // Special case for fp16 type. Since it is not a supported type in C we use // int16_t and bitcast the descriptor. - if (type.isHalfTy()) { + if (type.isa()) { auto memRefTy = getMemRefType(rank, LLVM::LLVMType::getInt16Ty(&getContext())); ptrToMemRefDescriptor = builder.create( - loc, memRefTy.getPointerTo(), ptrToMemRefDescriptor); + loc, LLVM::LLVMPointerType::get(memRefTy), ptrToMemRefDescriptor); } // Create call to `bindMemRef`. builder.create( @@ -257,11 +259,12 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType( Value ptrToMemRefDescriptor, uint32_t &rank, LLVM::LLVMType &type) { auto llvmPtrDescriptorTy = - ptrToMemRefDescriptor.getType().dyn_cast(); + ptrToMemRefDescriptor.getType().dyn_cast(); if (!llvmPtrDescriptorTy) return failure(); - auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); + auto llvmDescriptorTy = + llvmPtrDescriptorTy.getElementType().dyn_cast(); // template // struct { // Elem *allocated; @@ -270,15 +273,19 @@ // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; - if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) + if (!llvmDescriptorTy) return failure(); - type = llvmDescriptorTy.getStructElementType(0).getPointerElementTy(); - if (llvmDescriptorTy.getStructNumElements() == 3) { + type = llvmDescriptorTy.getBody()[0] + .cast() + .getElementType(); + if (llvmDescriptorTy.getBody().size() == 3) { rank = 0; return success(); } - rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); + rank = llvmDescriptorTy.getBody()[3] + .cast() + .getNumElements(); return success(); } @@ -326,13 +333,13 @@ LLVM::LLVMType::getHalfTy(&getContext())}) { std::string fnName = "bindMemRef" + std::to_string(i) + "D" + std::string(stringifyType(type)); - if (type.isHalfTy()) + if (type.isa()) type = LLVM::LLVMType::getInt16Ty(&getContext()); if (!module.lookupSymbol(fnName)) { auto fnType = LLVM::LLVMType::getFunctionTy( getVoidType(), {getPointerType(), getInt32Type(), getInt32Type(), - getMemRefType(i, type).getPointerTo()}, + LLVM::LLVMPointerType::get(getMemRefType(i, type))}, /*isVarArg=*/false); builder.create(loc, fnName, fnType); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -66,8 +66,10 @@ /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { - return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth() - : type.getIntegerBitWidth(); + auto vectorType = type.dyn_cast(); + return (vectorType ? vectorType.getElementType() : type) + .cast() + .getBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type @@ -265,7 +267,7 @@ TypeConverter &converter) { auto pointeeType = converter.convertType(type.getPointeeType()).cast(); - return pointeeType.getPointerTo(); + return LLVM::LLVMPointerType::get(pointeeType); } /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -215,7 +215,7 @@ SignatureConversion conversion(type.getNumInputs()); LLVM::LLVMType converted = convertFunctionSignature(type, /*isVariadic=*/false, conversion); - return converted.getPointerTo(); + return LLVM::LLVMPointerType::get(converted); } @@ -267,7 +267,7 @@ if (!converted) return {}; if (t.isa()) - converted = converted.getPointerTo(); + converted = LLVM::LLVMPointerType::get(converted); inputs.push_back(converted); } @@ -324,7 +324,7 @@ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); + auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); auto indexTy = getIndexType(); SmallVector results = {ptrTy, ptrTy, indexTy}; @@ -396,7 +396,7 @@ LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - return elementType.getPointerTo(type.getMemorySpace()); + return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); } // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when @@ -460,7 +460,7 @@ Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) { - Type type = structType.cast().getStructElementType(pos); + Type type = structType.cast().getBody()[pos]; return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } @@ -507,8 +507,9 @@ MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); - indexType = value.getType().cast().getStructElementType( - kOffsetPosInMemRefDescriptor); + indexType = value.getType() + .cast() + .getBody()[kOffsetPosInMemRefDescriptor]; } /// Builds IR creating an `undef` value of the descriptor type. @@ -618,9 +619,9 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, int64_t rank) { auto indexTy = indexType.cast(); - auto indexPtrTy = indexTy.getPointerTo(); + auto indexPtrTy = LLVM::LLVMPointerType::get(indexTy); auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank); - auto arrayPtrTy = arrayTy.getPointerTo(); + auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); // Copy size values to stack-allocated memory. auto zero = createIndexAttrConstant(builder, loc, indexType, 0); @@ -675,8 +676,8 @@ LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { return value.getType() - .cast() - .getStructElementType(kAlignedPtrPosInMemRefDescriptor) + .cast() + .getBody()[kAlignedPtrPosInMemRefDescriptor] .cast(); } @@ -922,7 +923,7 @@ Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); offsetGep = builder.create( - loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); return builder.create(loc, offsetGep); } @@ -939,19 +940,17 @@ Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); offsetGep = builder.create( - loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); builder.create(loc, offset, offsetGep); } -Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - LLVM::LLVMType elemPtrPtrType) { - LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy(); +Value UnrankedMemRefDescriptor::sizeBasePtr( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { + LLVM::LLVMType elemPtrTy = elemPtrPtrType.getElementType(); LLVM::LLVMType indexTy = typeConverter.getIndexType(); - LLVM::LLVMType structPtrTy = - LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy) - .getPointerTo(); + LLVM::LLVMType structPtrTy = LLVM::LLVMPointerType::get( + LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy)); Value structPtr = builder.create(loc, structPtrTy, memRefDescPtr); @@ -961,14 +960,15 @@ createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); Value three = builder.create(loc, int32_type, builder.getI32IntegerAttr(3)); - return builder.create(loc, indexTy.getPointerTo(), structPtr, - ValueRange({zero, three})); + return builder.create(loc, LLVM::LLVMPointerType::get(indexTy), + structPtr, ValueRange({zero, three})); } Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index) { - LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + LLVM::LLVMType indexPtrTy = + LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); return builder.create(loc, sizeStoreGep); @@ -978,7 +978,8 @@ LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index, Value size) { - LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + LLVM::LLVMType indexPtrTy = + LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); builder.create(loc, size, sizeStoreGep); @@ -987,7 +988,8 @@ Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank) { - LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + LLVM::LLVMType indexPtrTy = + LLVM::LLVMPointerType::get(typeConverter.getIndexType()); return builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({rank})); } @@ -996,7 +998,8 @@ LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { - LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + LLVM::LLVMType indexPtrTy = + LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); return builder.create(loc, strideStoreGep); @@ -1006,7 +1009,8 @@ LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { - LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + LLVM::LLVMType indexPtrTy = + LLVM::LLVMPointerType::get(typeConverter.getIndexType()); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); builder.create(loc, stride, strideStoreGep); @@ -1100,7 +1104,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = unwrap(typeConverter->convertType(elementType)); - return structElementType.getPointerTo(type.getMemorySpace()); + return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace()); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( @@ -1158,8 +1162,8 @@ // %0 = getelementptr %elementType* null, %indexType 1 // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. - auto convertedPtrType = - typeConverter->convertType(type).cast().getPointerTo(); + auto convertedPtrType = LLVM::LLVMPointerType::get( + typeConverter->convertType(type).cast()); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( loc, convertedPtrType, @@ -1315,7 +1319,8 @@ builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); - auto ptrTy = packed.getType().cast().getPointerTo(); + auto ptrTy = + LLVM::LLVMPointerType::get(packed.getType().cast()); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); @@ -1512,11 +1517,12 @@ return info; info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmArrayTy; - while (llvmTy.isArrayTy()) { - info.arraySizes.push_back(llvmTy.getArrayNumElements()); - llvmTy = llvmTy.getArrayElementType(); + while (llvmTy.isa()) { + info.arraySizes.push_back( + llvmTy.cast().getNumElements()); + llvmTy = llvmTy.cast().getElementType(); } - if (!llvmTy.isVectorTy()) + if (!llvmTy.isa()) return info; info.llvmVectorTy = llvmTy; return info; @@ -1644,7 +1650,7 @@ return failure(); auto llvmArrayTy = operands[0].getType().cast(); - if (!llvmArrayTy.isArrayTy()) + if (!llvmArrayTy.isa()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, @@ -2457,13 +2463,14 @@ LLVM::LLVMType arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( - loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name()); + loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. LLVM::LLVMType elementType = unwrap(typeConverter->convertType(type.getElementType())); - LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace); + LLVM::LLVMType elementPtrType = + LLVM::LLVMPointerType::get(elementType, memSpace); SmallVector operands = {addressOf}; operands.insert(operands.end(), type.getRank() + 1, @@ -2504,9 +2511,9 @@ auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); - if (!operandType.isArrayTy()) { + if (!operandType.isa()) { LLVM::ConstantOp one; - if (operandType.isVectorTy()) { + if (operandType.isa()) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)); @@ -2526,8 +2533,10 @@ op.getOperation(), operands, *getTypeConverter(), [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, - floatType), + mlir::VectorType::get( + {llvmVectorTy.cast() + .getNumElements()}, + floatType), floatOne); auto one = rewriter.create(loc, llvmVectorTy, splatAttr); @@ -2614,12 +2623,13 @@ // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // castPtr = BitCastOp i8* to structTy* - auto castPtr = - rewriter - .create( - loc, targetStructType.cast().getPointerTo(), - ptr) - .getResult(); + auto castPtr = rewriter + .create( + loc, + LLVM::LLVMPointerType::get( + targetStructType.cast()), + ptr) + .getResult(); // struct = LoadOp castPtr auto loadOp = rewriter.create(loc, castPtr); rewriter.replaceOp(memRefCastOp, loadOp.getResult()); @@ -2654,8 +2664,8 @@ Type elementType = operandType.cast().getElementType(); LLVM::LLVMType llvmElementType = unwrap(typeConverter.convertType(elementType)); - LLVM::LLVMType elementPtrPtrType = - llvmElementType.getPointerTo(memorySpace).getPointerTo(); + LLVM::LLVMType elementPtrPtrType = LLVM::LLVMPointerType::get( + LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); // Extract pointer to the underlying ranked memref descriptor and cast it to // ElemType**. @@ -2700,8 +2710,8 @@ MemRefType targetMemRefType = castOp.getResult().getType().cast(); auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); - if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy) return failure(); // Create descriptor. @@ -2804,8 +2814,8 @@ // Set pointers and offset. LLVM::LLVMType llvmElementType = unwrap(typeConverter->convertType(elementType)); - LLVM::LLVMType elementPtrPtrType = - llvmElementType.getPointerTo(addressSpace).getPointerTo(); + auto elementPtrPtrType = LLVM::LLVMPointerType::get( + LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), @@ -2858,7 +2868,7 @@ rewriter.setInsertionPointToStart(bodyBlock); // Copy size from shape to descriptor. - LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); + LLVM::LLVMType llvmIndexPtrType = LLVM::LLVMPointerType::get(indexType); Value sizeLoadGep = rewriter.create( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = rewriter.create(loc, sizeLoadGep); @@ -2950,14 +2960,14 @@ Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, - typeConverter->convertType(scalarMemRefType) - .cast() - .getPointerTo(addressSpace), + LLVM::LLVMPointerType::get( + typeConverter->convertType(scalarMemRefType).cast(), + addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. - Type indexPtrTy = - getTypeConverter()->getIndexType().getPointerTo(addressSpace); + Type indexPtrTy = LLVM::LLVMPointerType::get( + getTypeConverter()->getIndexType(), addressSpace); Value two = rewriter.create( loc, typeConverter->convertType(rewriter.getI32Type()), rewriter.getI32IntegerAttr(2)); @@ -3120,10 +3130,10 @@ auto targetType = typeConverter->convertType(indexCastOp.getResult().getType()) - .cast(); - auto sourceType = transformed.in().getType().cast(); - unsigned targetBits = targetType.getIntegerBitWidth(); - unsigned sourceBits = sourceType.getIntegerBitWidth(); + .cast(); + auto sourceType = transformed.in().getType().cast(); + unsigned targetBits = targetType.getBitWidth(); + unsigned sourceBits = sourceType.getBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(indexCastOp, transformed.in()); @@ -3462,14 +3472,18 @@ // Copy the buffer pointer from the old descriptor to the new one. Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), + loc, + LLVM::LLVMPointerType::get(targetElementTy, + viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Copy the buffer pointer from the old descriptor to the new one. extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), + loc, + LLVM::LLVMPointerType::get(targetElementTy, + viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); @@ -3662,7 +3676,9 @@ Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.source().getType().cast(); Value bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), + loc, + LLVM::LLVMPointerType::get(targetElementTy, + srcMemRefType.getMemorySpace()), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -3671,7 +3687,9 @@ alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( - loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), + loc, + LLVM::LLVMPointerType::get(targetElementTy, + srcMemRefType.getMemorySpace()), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); @@ -4064,7 +4082,8 @@ auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. - auto ptrType = operand.getType().cast().getPointerTo(); + auto ptrType = + LLVM::LLVMPointerType::get(operand.getType().cast()); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -195,7 +195,7 @@ Value base; if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); - auto pType = type.template cast().getPointerTo(); + auto pType = LLVM::LLVMPointerType::get(type.template cast()); base = rewriter.create(loc, pType, base); ptr = rewriter.create(loc, pType, base); return success(); @@ -1094,14 +1094,14 @@ return failure(); auto llvmSourceDescriptorTy = - operands[0].getType().dyn_cast(); - if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) + operands[0].getType().dyn_cast(); + if (!llvmSourceDescriptorTy) return failure(); MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = typeConverter->convertType(targetMemRefType) - .dyn_cast_or_null(); - if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy) return failure(); // Only contiguous source buffers supported atm. @@ -1223,15 +1223,15 @@ // TODO: support alignment when possible. Value dataPtr = this->getStridedElementPtr( loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); - auto vecTy = - toLLVMTy(xferOp.getVectorType()).template cast(); + auto vecTy = toLLVMTy(xferOp.getVectorType()) + .template cast(); Value vectorDataPtr; if (memRefType.getMemorySpace() == 0) - vectorDataPtr = - rewriter.create(loc, vecTy.getPointerTo(), dataPtr); + vectorDataPtr = rewriter.create( + loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); else vectorDataPtr = rewriter.create( - loc, vecTy.getPointerTo(), dataPtr); + loc, LLVM::LLVMPointerType::get(vecTy), dataPtr); if (!xferOp.isMaskedDim(0)) return replaceTransferOpWithLoadOrStore(rewriter, @@ -1245,7 +1245,7 @@ // // TODO: when the leaf transfer rank is k > 1, we need the last `k` // dimensions here. - unsigned vecWidth = vecTy.getVectorNumElements(); + unsigned vecWidth = vecTy.getNumElements(); unsigned lastIndex = llvm::size(xferOp.indices()) - 1; Value off = xferOp.indices()[lastIndex]; Value dim = rewriter.create(loc, xferOp.memref(), lastIndex); diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -78,9 +78,9 @@ auto toLLVMTy = [&](Type t) { return this->getTypeConverter()->convertType(t); }; - LLVM::LLVMType vecTy = - toLLVMTy(xferOp.getVectorType()).template cast(); - unsigned vecWidth = vecTy.getVectorNumElements(); + auto vecTy = toLLVMTy(xferOp.getVectorType()) + .template cast(); + unsigned vecWidth = vecTy.getNumElements(); Location loc = xferOp->getLoc(); // The backend result vector scalarization have trouble scalarize diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -105,9 +105,10 @@ auto argType = type.dyn_cast(); if (!argType) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); - if (argType.isVectorTy()) - resultType = - LLVMType::getVectorTy(resultType, argType.getVectorNumElements()); + if (auto vecArgType = argType.dyn_cast()) + resultType = LLVMType::getVectorTy(resultType, vecArgType.getNumElements()); + assert(!argType.isa() && + "unhandled scalable vector"); result.addTypes({resultType}); return success(); @@ -118,7 +119,7 @@ //===----------------------------------------------------------------------===// static void printAllocaOp(OpAsmPrinter &p, AllocaOp &op) { - auto elemTy = op.getType().cast().getPointerElementTy(); + auto elemTy = op.getType().cast().getElementType(); auto funcTy = FunctionType::get(op.getContext(), {op.arraySize().getType()}, {op.getType()}); @@ -363,14 +364,11 @@ // the resulting type wrapped in MLIR, or nullptr on error. static Type getLoadStoreElementType(OpAsmParser &parser, Type type, llvm::SMLoc trailingTypeLoc) { - auto llvmTy = type.dyn_cast(); + auto llvmTy = type.dyn_cast(); if (!llvmTy) - return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"), - nullptr; - if (!llvmTy.isPointerTy()) return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), nullptr; - return llvmTy.getPointerElementTy(); + return llvmTy.getElementType(); } // ::= `llvm.load` `volatile` ssa-use attribute-dict? `:` type @@ -569,7 +567,7 @@ auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, /*isVarArg=*/false); - auto wrappedFuncType = llvmFuncType.getPointerTo(); + auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); auto funcArguments = llvm::makeArrayRef(operands).drop_front(); @@ -613,7 +611,7 @@ for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { value = op.getOperand(idx); - bool isFilter = value.getType().cast().isArrayTy(); + bool isFilter = value.getType().isa(); if (isFilter) { // FIXME: Verify filter clauses when arrays are appropriately handled } else { @@ -646,7 +644,7 @@ for (auto value : op.getOperands()) { // Similar to llvm - if clause is an array type then it is filter // clause else catch clause - bool isArrayTy = value.getType().cast().isArrayTy(); + bool isArrayTy = value.getType().isa(); p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : " << value.getType() << ") "; } @@ -728,37 +726,37 @@ fnType = fn.getType(); } - if (!fnType.isFunctionTy()) + + LLVMFunctionType funcType = fnType.dyn_cast(); + if (!funcType) return op.emitOpError("callee does not have a functional type: ") << fnType; // Verify that the operand and result types match the callee. - if (!fnType.isFunctionVarArg() && - fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect)) + if (!funcType.isVarArg() && + funcType.getNumParams() != (op.getNumOperands() - isIndirect)) return op.emitOpError() << "incorrect number of operands (" << (op.getNumOperands() - isIndirect) - << ") for callee (expecting: " << fnType.getFunctionNumParams() - << ")"; + << ") for callee (expecting: " << funcType.getNumParams() << ")"; - if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect)) + if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) return op.emitOpError() << "incorrect number of operands (" << (op.getNumOperands() - isIndirect) << ") for varargs callee (expecting at least: " - << fnType.getFunctionNumParams() << ")"; + << funcType.getNumParams() << ")"; - for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i) - if (op.getOperand(i + isIndirect).getType() != - fnType.getFunctionParamType(i)) + for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) + if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) return op.emitOpError() << "operand type mismatch for operand " << i << ": " << op.getOperand(i + isIndirect).getType() - << " != " << fnType.getFunctionParamType(i); + << " != " << funcType.getParamType(i); if (op.getNumResults() && - op.getResult(0).getType() != fnType.getFunctionResultType()) + op.getResult(0).getType() != funcType.getReturnType()) return op.emitOpError() << "result type mismatch: " << op.getResult(0).getType() - << " != " << fnType.getFunctionResultType(); + << " != " << funcType.getReturnType(); return success(); } @@ -848,7 +846,7 @@ } auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes, /*isVarArg=*/false); - auto wrappedFuncType = llvmFuncType.getPointerTo(); + auto wrappedFuncType = LLVM::LLVMPointerType::get(llvmFuncType); auto funcArguments = ArrayRef(operands).drop_front(); @@ -875,8 +873,8 @@ void LLVM::ExtractElementOp::build(OpBuilder &b, OperationState &result, Value vector, Value position, ArrayRef attrs) { - auto wrappedVectorType = vector.getType().cast(); - auto llvmType = wrappedVectorType.getVectorElementType(); + auto vectorType = vector.getType().cast(); + auto llvmType = vectorType.getElementType(); build(b, result, llvmType, vector, position); result.addAttributes(attrs); } @@ -903,11 +901,11 @@ parser.resolveOperand(vector, type, result.operands) || parser.resolveOperand(position, positionType, result.operands)) return failure(); - auto wrappedVectorType = type.dyn_cast(); - if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) + auto vectorType = type.dyn_cast(); + if (!vectorType) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); - result.addTypes(wrappedVectorType.getVectorElementType()); + result.addTypes(vectorType.getElementType()); return success(); } @@ -930,8 +928,8 @@ ArrayAttr positionAttr, llvm::SMLoc attributeLoc, llvm::SMLoc typeLoc) { - auto wrappedContainerType = containerType.dyn_cast(); - if (!wrappedContainerType) + auto llvmType = containerType.dyn_cast(); + if (!llvmType) return parser.emitError(typeLoc, "expected LLVM IR Dialect type"), nullptr; // Infer the element type from the structure type: iteratively step inside the @@ -945,26 +943,24 @@ "expected an array of integer literals"), nullptr; int position = positionElementAttr.getInt(); - if (wrappedContainerType.isArrayTy()) { - if (position < 0 || static_cast(position) >= - wrappedContainerType.getArrayNumElements()) + if (auto arrayType = llvmType.dyn_cast()) { + if (position < 0 || + static_cast(position) >= arrayType.getNumElements()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; - wrappedContainerType = wrappedContainerType.getArrayElementType(); - } else if (wrappedContainerType.isStructTy()) { - if (position < 0 || static_cast(position) >= - wrappedContainerType.getStructNumElements()) + llvmType = arrayType.getElementType(); + } else if (auto structType = llvmType.dyn_cast()) { + if (position < 0 || + static_cast(position) >= structType.getBody().size()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; - wrappedContainerType = - wrappedContainerType.getStructElementType(position); + llvmType = structType.getBody()[position]; } else { - return parser.emitError(typeLoc, - "expected wrapped LLVM IR structure/array type"), + return parser.emitError(typeLoc, "expected LLVM IR structure/array type"), nullptr; } } - return wrappedContainerType; + return llvmType; } // ::= `llvm.extractvalue` ssa-use @@ -1021,11 +1017,11 @@ parser.parseColonType(vectorType)) return failure(); - auto wrappedVectorType = vectorType.dyn_cast(); - if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) + auto llvmVectorType = vectorType.dyn_cast(); + if (!llvmVectorType) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); - auto valueType = wrappedVectorType.getVectorElementType(); + Type valueType = llvmVectorType.getElementType(); if (!valueType) return failure(); @@ -1145,12 +1141,14 @@ return op.emitOpError( "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); - if (global && global.getType().getPointerTo(global.addr_space()) != - op.getResult().getType()) + if (global && + LLVM::LLVMPointerType::get(global.getType(), global.addr_space()) != + op.getResult().getType()) return op.emitOpError( "the type must be a pointer to the type of the referenced global"); - if (function && function.getType().getPointerTo() != op.getResult().getType()) + if (function && LLVM::LLVMPointerType::get(function.getType()) != + op.getResult().getType()) return op.emitOpError( "the type must be a pointer to the type of the referenced function"); @@ -1276,11 +1274,11 @@ if (vectorType.getRank() != 1) return op->emitOpError("only 1-d vector is allowed"); - auto llvmVector = llvmType.dyn_cast(); - if (llvmVector.isa()) + auto llvmVector = llvmType.dyn_cast(); + if (!llvmVector) return op->emitOpError("only fixed-sized vector is allowed"); - if (vectorType.getDimSize(0) != llvmVector.getVectorNumElements()) + if (vectorType.getDimSize(0) != llvmVector.getNumElements()) return op->emitOpError( "invalid cast between vectors with mismatching sizes"); @@ -1375,7 +1373,10 @@ "be an index-compatible integer"); auto ptrType = structType.getBody()[1].dyn_cast(); - if (!ptrType || !ptrType.getPointerElementTy().isIntegerTy(8)) + auto ptrElementType = + ptrType ? ptrType.getElementType().dyn_cast() + : nullptr; + if (!ptrElementType || ptrElementType.getBitWidth() != 8) return op->emitOpError("expected second element of a memref descriptor " "to be an !llvm.ptr"); @@ -1503,9 +1504,11 @@ return op.emitOpError("must appear at the module level"); if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { - auto type = op.getType(); - if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) || - type.getArrayNumElements() != strAttr.getValue().size()) + auto type = op.getType().dyn_cast(); + LLVMIntegerType elementType = + type ? type.getElementType().dyn_cast() : nullptr; + if (!elementType || elementType.getBitWidth() != 8 || + type.getNumElements() != strAttr.getValue().size()) return op.emitOpError( "requires an i8 array type of the length equal to that of the string " "attribute"); @@ -1534,9 +1537,9 @@ void LLVM::ShuffleVectorOp::build(OpBuilder &b, OperationState &result, Value v1, Value v2, ArrayAttr mask, ArrayRef attrs) { - auto wrappedContainerType1 = v1.getType().cast(); - auto vType = LLVMType::getVectorTy( - wrappedContainerType1.getVectorElementType(), mask.size()); + auto containerType = v1.getType().cast(); + auto vType = + LLVMType::getVectorTy(containerType.getElementType(), mask.size()); build(b, result, vType, v1, v2, mask); result.addAttributes(attrs); } @@ -1566,12 +1569,12 @@ parser.resolveOperand(v1, typeV1, result.operands) || parser.resolveOperand(v2, typeV2, result.operands)) return failure(); - auto wrappedContainerType1 = typeV1.dyn_cast(); - if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy()) + auto containerType = typeV1.dyn_cast(); + if (!containerType) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); - auto vType = LLVMType::getVectorTy( - wrappedContainerType1.getVectorElementType(), maskAttr.size()); + auto vType = + LLVMType::getVectorTy(containerType.getElementType(), maskAttr.size()); result.addTypes(vType); return success(); } @@ -1588,9 +1591,9 @@ auto *entry = new Block; push_back(entry); - LLVMType type = getType(); - for (unsigned i = 0, e = type.getFunctionNumParams(); i < e; ++i) - entry->addArgument(type.getFunctionParamType(i)); + LLVMFunctionType type = getType(); + for (unsigned i = 0, e = type.getNumParams(); i < e; ++i) + entry->addArgument(type.getParamType(i)); return entry; } @@ -1608,7 +1611,7 @@ if (argAttrs.empty()) return; - unsigned numInputs = type.getFunctionNumParams(); + unsigned numInputs = type.cast().getNumParams(); assert(numInputs == argAttrs.size() && "expected as many argument attribute lists as arguments"); SmallString<8> argAttrName; @@ -1711,15 +1714,15 @@ p << stringifyLinkage(op.linkage()) << ' '; p.printSymbolName(op.getName()); - LLVMType fnType = op.getType(); + LLVMFunctionType fnType = op.getType(); SmallVector argTypes; SmallVector resTypes; - argTypes.reserve(fnType.getFunctionNumParams()); - for (unsigned i = 0, e = fnType.getFunctionNumParams(); i < e; ++i) - argTypes.push_back(fnType.getFunctionParamType(i)); + argTypes.reserve(fnType.getNumParams()); + for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i) + argTypes.push_back(fnType.getParamType(i)); - LLVMType returnType = fnType.getFunctionResultType(); - if (!returnType.isVoidTy()) + LLVMType returnType = fnType.getReturnType(); + if (!returnType.isa()) resTypes.push_back(returnType); impl::printFunctionSignature(p, op, argTypes, op.isVarArg(), resTypes); @@ -1737,8 +1740,8 @@ // attribute is present. This can check for preconditions of the // getNumArguments hook not failing. LogicalResult LLVMFuncOp::verifyType() { - auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); - if (!llvmType || !llvmType.isFunctionTy()) + auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); + if (!llvmType) return emitOpError("requires '" + getTypeAttrName() + "' attribute of wrapped LLVM function type"); @@ -1747,9 +1750,7 @@ // Hook for OpTrait::FunctionLike, returns the number of function arguments. // Depends on the type attribute being correct as checked by verifyType -unsigned LLVMFuncOp::getNumFuncArguments() { - return getType().getFunctionNumParams(); -} +unsigned LLVMFuncOp::getNumFuncArguments() { return getType().getNumParams(); } // Hook for OpTrait::FunctionLike, returns the number of function results. // Depends on the type attribute being correct as checked by verifyType @@ -1759,7 +1760,7 @@ // If we modeled a void return as one result, then it would be possible to // attach an MLIR result attribute to it, and it isn't clear what semantics we // would assign to that. - if (getType().getFunctionResultType().isVoidTy()) + if (getType().getReturnType().isa()) return 0; return 1; } @@ -1788,7 +1789,7 @@ if (op.isVarArg()) return op.emitOpError("only external functions can be variadic"); - unsigned numArguments = op.getType().getFunctionNumParams(); + unsigned numArguments = op.getType().getNumParams(); Block &entryBlock = op.front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); @@ -1796,7 +1797,7 @@ if (!argLLVMType) return op.emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (op.getType().getFunctionParamType(i) != argLLVMType) + if (op.getType().getParamType(i) != argLLVMType) return op.emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -1896,7 +1897,8 @@ parseAtomicOrdering(parser, result, "ordering") || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || - parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || + parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), + result.operands) || parser.resolveOperand(val, type, result.operands)) return failure(); @@ -1905,9 +1907,9 @@ } static LogicalResult verify(AtomicRMWOp op) { - auto ptrType = op.ptr().getType().cast(); + auto ptrType = op.ptr().getType().cast(); auto valType = op.val().getType().cast(); - if (valType != ptrType.getPointerElementTy()) + if (valType != ptrType.getElementType()) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for operand #1"); auto resType = op.res().getType().cast(); @@ -1915,17 +1917,21 @@ return op.emitOpError( "expected LLVM IR result type to match type for operand #1"); if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { - if (!valType.isFloatingPointTy()) + if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) return op.emitOpError("expected LLVM IR floating point type"); } else if (op.bin_op() == AtomicBinOp::xchg) { - if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && - !valType.isIntegerTy(32) && !valType.isIntegerTy(64) && - !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() && - !valType.isDoubleTy()) + auto intType = valType.dyn_cast(); + unsigned intBitWidth = intType ? intType.getBitWidth() : 0; + if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && + intBitWidth != 64 && !valType.isa() && + !valType.isa() && !valType.isa() && + !valType.isa()) return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { - if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && - !valType.isIntegerTy(32) && !valType.isIntegerTy(64)) + auto intType = valType.dyn_cast(); + unsigned intBitWidth = intType ? intType.getBitWidth() : 0; + if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && + intBitWidth != 64) return op.emitOpError("expected LLVM IR integer type"); } return success(); @@ -1958,7 +1964,8 @@ parseAtomicOrdering(parser, result, "failure_ordering") || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || - parser.resolveOperand(ptr, type.getPointerTo(), result.operands) || + parser.resolveOperand(ptr, LLVM::LLVMPointerType::get(type), + result.operands) || parser.resolveOperand(cmp, type, result.operands) || parser.resolveOperand(val, type, result.operands)) return failure(); @@ -1971,18 +1978,20 @@ } static LogicalResult verify(AtomicCmpXchgOp op) { - auto ptrType = op.ptr().getType().cast(); - if (!ptrType.isPointerTy()) + auto ptrType = op.ptr().getType().cast(); + if (!ptrType) return op.emitOpError("expected LLVM IR pointer type for operand #0"); auto cmpType = op.cmp().getType().cast(); auto valType = op.val().getType().cast(); - if (cmpType != ptrType.getPointerElementTy() || cmpType != valType) + if (cmpType != ptrType.getElementType() || cmpType != valType) return op.emitOpError("expected LLVM IR element type for operand #0 to " "match type for all other operands"); - if (!valType.isPointerTy() && !valType.isIntegerTy(8) && - !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && - !valType.isIntegerTy(64) && !valType.isBFloatTy() && - !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) + auto intType = valType.dyn_cast(); + unsigned intBitWidth = intType ? intType.getBitWidth() : 0; + if (!valType.isa() && intBitWidth != 8 && + intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && + !valType.isa() && !valType.isa() && + !valType.isa() && !valType.isa()) return op.emitOpError("unexpected LLVM IR type"); if (op.success_ordering() < AtomicOrdering::monotonic || op.failure_ordering() < AtomicOrdering::monotonic) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -36,129 +36,6 @@ return static_cast(Type::getDialect()); } -//----------------------------------------------------------------------------// -// Misc type utilities. - -llvm::TypeSize LLVMType::getPrimitiveSizeInBits() { - return llvm::TypeSwitch(*this) - .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(16); }) - .Case([](LLVMType) { return llvm::TypeSize::Fixed(32); }) - .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(64); }) - .Case([](LLVMIntegerType intTy) { - return llvm::TypeSize::Fixed(intTy.getBitWidth()); - }) - .Case([](LLVMType) { return llvm::TypeSize::Fixed(80); }) - .Case( - [](LLVMType) { return llvm::TypeSize::Fixed(128); }) - .Case([](LLVMVectorType t) { - llvm::TypeSize elementSize = - t.getElementType().getPrimitiveSizeInBits(); - llvm::ElementCount elementCount = t.getElementCount(); - assert(!elementSize.isScalable() && - "vector type should have fixed-width elements"); - return llvm::TypeSize(elementSize.getFixedSize() * - elementCount.getKnownMinValue(), - elementCount.isScalable()); - }) - .Default([](LLVMType ty) { - assert((ty.isa()) && - "unexpected missing support for primitive type"); - return llvm::TypeSize::Fixed(0); - }); -} - -//----------------------------------------------------------------------------// -// Integer type utilities. - -bool LLVMType::isIntegerTy(unsigned bitwidth) { - if (auto intType = dyn_cast()) - return intType.getBitWidth() == bitwidth; - return false; -} -unsigned LLVMType::getIntegerBitWidth() { - return cast().getBitWidth(); -} - -LLVMType LLVMType::getArrayElementType() { - return cast().getElementType(); -} - -//----------------------------------------------------------------------------// -// Array type utilities. - -unsigned LLVMType::getArrayNumElements() { - return cast().getNumElements(); -} - -bool LLVMType::isArrayTy() { return isa(); } - -//----------------------------------------------------------------------------// -// Vector type utilities. - -LLVMType LLVMType::getVectorElementType() { - return cast().getElementType(); -} - -unsigned LLVMType::getVectorNumElements() { - return cast().getNumElements(); -} -llvm::ElementCount LLVMType::getVectorElementCount() { - return cast().getElementCount(); -} - -bool LLVMType::isVectorTy() { return isa(); } - -//----------------------------------------------------------------------------// -// Function type utilities. - -LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { - return cast().getParamType(argIdx); -} - -unsigned LLVMType::getFunctionNumParams() { - return cast().getNumParams(); -} - -LLVMType LLVMType::getFunctionResultType() { - return cast().getReturnType(); -} - -bool LLVMType::isFunctionTy() { return isa(); } - -bool LLVMType::isFunctionVarArg() { - return cast().isVarArg(); -} - -//----------------------------------------------------------------------------// -// Pointer type utilities. - -LLVMType LLVMType::getPointerTo(unsigned addrSpace) { - return LLVMPointerType::get(*this, addrSpace); -} - -LLVMType LLVMType::getPointerElementTy() { - return cast().getElementType(); -} - -bool LLVMType::isPointerTy() { return isa(); } - -//----------------------------------------------------------------------------// -// Struct type utilities. - -LLVMType LLVMType::getStructElementType(unsigned i) { - return cast().getBody()[i]; -} - -unsigned LLVMType::getStructNumElements() { - return cast().getBody().size(); -} - -bool LLVMType::isStructTy() { return isa(); } - //----------------------------------------------------------------------------// // Utilities used to generate floating point types. @@ -193,6 +70,10 @@ return LLVMIntegerType::get(context, numBits); } +LLVMType LLVMType::getInt8PtrTy(MLIRContext *context) { + return LLVMPointerType::get(LLVMIntegerType::get(context, 8)); +} + //----------------------------------------------------------------------------// // Utilities used to generate other miscellaneous types. @@ -221,8 +102,6 @@ return LLVMVoidType::get(context); } -bool LLVMType::isVoidTy() { return isa(); } - //----------------------------------------------------------------------------// // Creation and setting of LLVM's identified struct types @@ -470,7 +349,7 @@ bool LLVMVectorType::isValidElementType(LLVMType type) { return type.isa() || - type.isFloatingPointTy(); + mlir::LLVM::isCompatibleFloatingPointType(type); } /// Support type casting functionality. @@ -536,3 +415,42 @@ unsigned LLVMScalableVectorType::getMinNumElements() { return getImpl()->numElements; } + +//===----------------------------------------------------------------------===// +// Utility functions. +//===----------------------------------------------------------------------===// + +llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { + assert(isCompatibleType(type) && + "expected a type compatible with the LLVM dialect"); + + return llvm::TypeSwitch(type) + .Case( + [](LLVMType) { return llvm::TypeSize::Fixed(16); }) + .Case([](LLVMType) { return llvm::TypeSize::Fixed(32); }) + .Case( + [](LLVMType) { return llvm::TypeSize::Fixed(64); }) + .Case([](LLVMIntegerType intTy) { + return llvm::TypeSize::Fixed(intTy.getBitWidth()); + }) + .Case([](LLVMType) { return llvm::TypeSize::Fixed(80); }) + .Case( + [](LLVMType) { return llvm::TypeSize::Fixed(128); }) + .Case([](LLVMVectorType t) { + llvm::TypeSize elementSize = + getPrimitiveTypeSizeInBits(t.getElementType()); + llvm::ElementCount elementCount = t.getElementCount(); + assert(!elementSize.isScalable() && + "vector type should have fixed-width elements"); + return llvm::TypeSize(elementSize.getFixedSize() * + elementCount.getKnownMinValue(), + elementCount.isScalable()); + }) + .Default([](Type ty) { + assert((ty.isa()) && + "unexpected missing support for primitive type"); + return llvm::TypeSize::Fixed(0); + }); +} diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -57,8 +57,9 @@ for (auto &attr : result.attributes) { if (attr.first != "return_value_and_is_valid") continue; - if (type.isStructTy() && type.getStructNumElements() > 0) - type = type.getStructElementType(0); + auto structType = type.dyn_cast(); + if (structType && !structType.getBody().empty()) + type = structType.getBody()[0]; break; } diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -196,19 +196,30 @@ Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32)) + auto resultType = mainFunction.getType() + .cast() + .getReturnType() + .dyn_cast(); + if (!resultType || resultType.getBitWidth() != 32) return make_string_error("only single llvm.i32 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64)) + auto resultType = mainFunction.getType() + .cast() + .getReturnType() + .dyn_cast(); + if (!resultType || resultType.getBitWidth() != 64) return make_string_error("only single llvm.i64 function result supported"); return Error::success(); } template <> Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { - if (!mainFunction.getType().getFunctionResultType().isFloatTy()) + if (!mainFunction.getType() + .cast() + .getReturnType() + .isa()) return make_string_error("only single llvm.f32 function result supported"); return Error::success(); } @@ -220,7 +231,7 @@ if (!mainFunction || mainFunction.isExternal()) return make_string_error("entry point not found"); - if (mainFunction.getType().getFunctionNumParams() != 0) + if (mainFunction.getType().cast().getNumParams() != 0) return make_string_error("function inputs not supported"); if (Error error = checkCompatibleReturnType(mainFunction)) diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -172,57 +172,57 @@ if (!type) return nullptr; - if (type.isIntegerTy()) - return b.getIntegerType(type.getIntegerBitWidth()); + if (auto intType = type.dyn_cast()) + return b.getIntegerType(intType.getBitWidth()); - if (type.isFloatTy()) + if (type.isa()) return b.getF32Type(); - if (type.isDoubleTy()) + if (type.isa()) return b.getF64Type(); // LLVM vectors can only contain scalars. - if (type.isVectorTy()) { - auto numElements = type.getVectorElementCount(); + if (auto vectorType = type.dyn_cast()) { + auto numElements = vectorType.getElementCount(); if (numElements.isScalable()) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; } - Type elementType = getStdTypeForAttr(type.getVectorElementType()); + Type elementType = getStdTypeForAttr(vectorType.getElementType()); if (!elementType) return nullptr; return VectorType::get(numElements.getKnownMinValue(), elementType); } // LLVM arrays can contain other arrays or vectors. - if (type.isArrayTy()) { + if (auto arrayType = type.dyn_cast()) { // Recover the nested array shape. SmallVector shape; - shape.push_back(type.getArrayNumElements()); - while (type.getArrayElementType().isArrayTy()) { - type = type.getArrayElementType(); - shape.push_back(type.getArrayNumElements()); + shape.push_back(arrayType.getNumElements()); + while (arrayType.getElementType().isa()) { + arrayType = arrayType.getElementType().cast(); + shape.push_back(arrayType.getNumElements()); } // If the innermost type is a vector, use the multi-dimensional vector as // attribute type. - if (type.getArrayElementType().isVectorTy()) { - LLVMType vectorType = type.getArrayElementType(); - auto numElements = vectorType.getVectorElementCount(); + if (auto vectorType = + arrayType.getElementType().dyn_cast()) { + auto numElements = vectorType.getElementCount(); if (numElements.isScalable()) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; } shape.push_back(numElements.getKnownMinValue()); - Type elementType = getStdTypeForAttr(vectorType.getVectorElementType()); + Type elementType = getStdTypeForAttr(vectorType.getElementType()); if (!elementType) return nullptr; return VectorType::get(shape, elementType); } // Otherwise use a tensor. - Type elementType = getStdTypeForAttr(type.getArrayElementType()); + Type elementType = getStdTypeForAttr(arrayType.getElementType()); if (!elementType) return nullptr; return RankedTensorType::get(shape, elementType); @@ -261,7 +261,7 @@ if (!attrType) return nullptr; - if (type.isIntegerTy()) { + if (type.isa()) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -269,7 +269,7 @@ return DenseElementsAttr::get(attrType, values); } - if (type.isFloatTy() || type.isDoubleTy()) { + if (type.isa() || type.isa()) { SmallVector values; values.reserve(cd->getNumElements()); for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) @@ -777,7 +777,8 @@ instMap.clear(); unknownInstMap.clear(); - LLVMType functionType = processType(f->getFunctionType()); + auto functionType = + processType(f->getFunctionType()).dyn_cast(); if (!functionType) return failure(); @@ -805,8 +806,8 @@ // Add function arguments to the entry block. for (auto kv : llvm::enumerate(f->args())) - instMap[&kv.value()] = blockList[0]->addArgument( - functionType.getFunctionParamType(kv.index())); + instMap[&kv.value()] = + blockList[0]->addArgument(functionType.getParamType(kv.index())); for (auto bbs : llvm::zip(*f, blockList)) { if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs)))) diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -969,7 +969,7 @@ // NB: Attribute already verified to be boolean, so check if we can indeed // attach the attribute to this argument, based on its type. auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.isPointerTy()) + if (!argTy.isa()) return func.emitError( "llvm.noalias attribute attached to LLVM non-pointer argument"); if (attr.getValue()) @@ -981,7 +981,7 @@ // NB: Attribute already verified to be int, so check if we can indeed // attach the attribute to this argument, based on its type. auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.isPointerTy()) + if (!argTy.isa()) return func.emitError( "llvm.align attribute attached to LLVM non-pointer argument"); llvmArg.addAttrs( diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -98,7 +98,7 @@ // ----- func @load_non_llvm_type(%foo : memref) { - // expected-error@+1 {{expected LLVM IR dialect type}} + // expected-error@+1 {{expected LLVM pointer type}} llvm.load %foo : memref } @@ -112,7 +112,7 @@ // ----- func @store_non_llvm_type(%foo : memref, %bar : !llvm.float) { - // expected-error@+1 {{expected LLVM IR dialect type}} + // expected-error@+1 {{expected LLVM pointer type}} llvm.store %bar, %foo : memref } @@ -267,7 +267,7 @@ // ----- func @insertvalue_wrong_nesting() { - // expected-error@+1 {{expected wrapped LLVM IR structure/array type}} + // expected-error@+1 {{expected LLVM IR structure/array type}} llvm.insertvalue %a, %b[0,0] : !llvm.struct<(i32)> } @@ -311,7 +311,7 @@ // ----- func @extractvalue_wrong_nesting() { - // expected-error@+1 {{expected wrapped LLVM IR structure/array type}} + // expected-error@+1 {{expected LLVM IR structure/array type}} llvm.extractvalue %b[0,0] : !llvm.struct<(i32)> }