diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -47,6 +47,14 @@ struct LLVMDialectImpl; } // namespace detail +class LLVMType; + +/// Converts an MLIR LLVM dialect type to an LLVM IR type. Note that this +/// function exists exclusively for the purpose of the gradual transition to +/// first-party modeling of LLVM types. It should not be used outside +/// MLIR-to-LLVM translation, except to query llvm::DataLayout for a type. +llvm::Type *convertLLVMType(LLVMType type); + class LLVMType : public mlir::Type::TypeBase { public: @@ -59,26 +67,32 @@ static bool kindof(unsigned kind) { return kind == LLVM_TYPE; } LLVMDialect &getDialect(); - llvm::Type *getUnderlyingType() const; /// Utilities to identify types. bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); } bool isHalfTy() { return getUnderlyingType()->isHalfTy(); } bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } - bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } - bool isIntegerTy(unsigned bitwidth) { - return getUnderlyingType()->isIntegerTy(bitwidth); - } + bool isFloatingPointTy() { return getUnderlyingType()->isFloatingPointTy(); } /// Array type utilities. LLVMType getArrayElementType(); unsigned getArrayNumElements(); bool isArrayTy(); + /// Integer type utilities. + unsigned getIntegerBitWidth() { + return getUnderlyingType()->getIntegerBitWidth(); + } + bool isIntegerTy() { return getUnderlyingType()->isIntegerTy(); } + bool isIntegerTy(unsigned bitwidth) { + return getUnderlyingType()->isIntegerTy(bitwidth); + } + /// Vector type utilities. LLVMType getVectorElementType(); unsigned getVectorNumElements(); + llvm::ElementCount getVectorElementCount(); bool isVectorTy(); /// Function type utilities.
@@ -86,11 +100,13 @@ unsigned getFunctionNumParams(); LLVMType getFunctionResultType(); bool isFunctionTy(); + bool isFunctionVarArg(); /// Pointer type utilities. LLVMType getPointerTo(unsigned addrSpace = 0); LLVMType getPointerElementTy(); bool isPointerTy(); + static bool isValidPointerElementType(LLVMType type); /// Struct type utilities. LLVMType getStructElementType(unsigned i); @@ -194,6 +210,14 @@ private: friend LLVMDialect; + friend llvm::Type *convertLLVMType(LLVMType type); + + /// Get the underlying LLVM IR type. + llvm::Type *getUnderlyingType() const; + + /// Get the underlying LLVM IR types for the given array of types. + static void getUnderlyingTypes(ArrayRef types, + SmallVectorImpl &result); /// Get an LLVMType with a pre-existing llvm type. static LLVMType get(MLIRContext *context, llvm::Type *llvmType); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -134,11 +134,9 @@ // or result in the operation. 
def LLVM_IntrPatterns { string operand = - [{opInst.getOperand($0).getType() - .cast().getUnderlyingType()}]; + [{convertType(opInst.getOperand($0).getType().cast())}]; string result = - [{opInst.getResult($0).getType() - .cast().getUnderlyingType()}]; + [{convertType(opInst.getResult($0).getType().cast())}]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -61,9 +61,8 @@ [{ auto llvmType = resultType.dyn_cast(); (void)llvmType; assert(llvmType && "result must be an LLVM type"); - assert(llvmType.getUnderlyingType() && - llvmType.getUnderlyingType()->isVoidTy() && - "for zero-result operands, only 'void' is accepted as result type"); + assert(llvmType.isVoidTy() && + "for zero-result operands, only 'void' is accepted as result type"); build(builder, result, operands, attributes); }]>; @@ -477,7 +476,7 @@ let verifier = [{ auto wrappedVectorType1 = v1().getType().cast(); auto wrappedVectorType2 = v2().getType().cast(); - if (!wrappedVectorType2.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType2.isVectorTy()) return emitOpError("expected LLVM IR Dialect vector type for operand #2"); if (wrappedVectorType1.getVectorElementType() != wrappedVectorType2.getVectorElementType()) @@ -765,10 +764,10 @@ .getValue().cast(); } bool isVarArg() { - return getType().getUnderlyingType()->isFunctionVarArg(); + return getType().isFunctionVarArg(); } // Hook for OpTrait::FunctionLike, returns the number of function arguments. // Depends on the type attribute being correct as checked by verifyType.
unsigned getNumFuncArguments(); diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -139,7 +139,7 @@ // Vector buffer load/store intrinsics def ROCDL_MubufLoadOp : - ROCDL_Op<"buffer.load">, + ROCDL_Op<"buffer.load">, Results<(outs LLVM_Type:$res)>, Arguments<(ins LLVM_Type:$rsrc, LLVM_Type:$vindex, @@ -160,7 +160,7 @@ } def ROCDL_MubufStoreOp : - ROCDL_Op<"buffer.store">, + ROCDL_Op<"buffer.store">, Arguments<(ins LLVM_Type:$vdata, LLVM_Type:$rsrc, LLVM_Type:$vindex, @@ -168,14 +168,13 @@ LLVM_Type:$glc, LLVM_Type:$slc)>{ string llvmBuilder = [{ - auto vdataType = op.vdata().getType().cast() - .getUnderlyingType(); + auto vdataType = convertType(op.vdata().getType().cast()); createIntrinsicCall(builder, - llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, + llvm::Intrinsic::amdgcn_buffer_store, {$vdata, $rsrc, $vindex, $offset, $glc, $slc}, {vdataType}); }]; let parser = [{ return parseROCDLMubufStoreOp(parser, result); }]; - let printer = [{ + let printer = [{ Operation *op = this->getOperation(); p << op->getName() << " " << op->getOperands() << " : " << vdata().getType(); diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -89,6 +89,10 @@ llvm::IRBuilder<> &builder); virtual LogicalResult convertOmpParallel(Operation &op, llvm::IRBuilder<> &builder); + + /// Converts the type from MLIR LLVM dialect to LLVM. + llvm::Type *convertType(LLVMType type); + static std::unique_ptr prepareLLVMModule(Operation *m); /// A helper to look up remapped operands in the value remapping table. 
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -64,10 +64,8 @@ /// Returns the bit width of LLVMType integer or vector. static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { - return type.isVectorTy() ? type.getVectorElementType() - .getUnderlyingType() - ->getIntegerBitWidth() - : type.getUnderlyingType()->getIntegerBitWidth(); + return type.isVectorTy() ? type.getVectorElementType().getIntegerBitWidth() + : type.getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2248,10 +2248,8 @@ op, operands, typeConverter, [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get( - {cast(llvmVectorTy.getUnderlyingType()) - ->getNumElements()}, - floatType), + mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, + floatType), floatOne); auto one = rewriter.create(loc, llvmVectorTy, splatAttr); @@ -2511,8 +2509,8 @@ this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); - unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); - unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); + unsigned targetBits = targetType.getIntegerBitWidth(); + unsigned sourceBits = sourceType.getIntegerBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, transformed.in()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- 
a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -127,7 +127,7 @@ auto dataLayout = typeConverter.getDialect()->getLLVMModule().getDataLayout(); align = dataLayout.getPrefTypeAlignment( - elementTy.cast().getUnderlyingType()); + LLVM::convertLLVMType(elementTy.cast())); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -105,11 +105,9 @@ auto argType = type.dyn_cast(); if (!argType) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"); - if (argType.getUnderlyingType()->isVectorTy()) - resultType = LLVMType::getVectorTy( - resultType, - llvm::cast(argType.getUnderlyingType()) - ->getNumElements()); + if (argType.isVectorTy()) + resultType = + LLVMType::getVectorTy(resultType, argType.getVectorNumElements()); result.addTypes({resultType}); return success(); @@ -214,7 +212,7 @@ if (!llvmTy) return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type"), nullptr; - if (!llvmTy.getUnderlyingType()->isPointerTy()) + if (!llvmTy.isPointerTy()) return parser.emitError(trailingTypeLoc, "expected LLVM pointer type"), nullptr; return llvmTy.getPointerElementTy(); @@ -683,8 +681,7 @@ parser.resolveOperand(position, positionType, result.operands)) return failure(); auto wrappedVectorType = type.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); result.addTypes(wrappedVectorType.getVectorElementType()); @@ -725,16 +722,15 @@ "expected an array of integer literals"), nullptr; int position = positionElementAttr.getInt(); - auto *llvmContainerType = wrappedContainerType.getUnderlyingType(); - if 
(llvmContainerType->isArrayTy()) { + if (wrappedContainerType.isArrayTy()) { if (position < 0 || static_cast(position) >= - llvmContainerType->getArrayNumElements()) + wrappedContainerType.getArrayNumElements()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; wrappedContainerType = wrappedContainerType.getArrayElementType(); - } else if (llvmContainerType->isStructTy()) { + } else if (wrappedContainerType.isStructTy()) { if (position < 0 || static_cast(position) >= - llvmContainerType->getStructNumElements()) + wrappedContainerType.getStructNumElements()) return parser.emitError(attributeLoc, "position out of bounds"), nullptr; wrappedContainerType = @@ -803,8 +799,7 @@ return failure(); auto wrappedVectorType = vectorType.dyn_cast(); - if (!wrappedVectorType || - !wrappedVectorType.getUnderlyingType()->isVectorTy()) + if (!wrappedVectorType || !wrappedVectorType.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); auto valueType = wrappedVectorType.getVectorElementType(); @@ -1125,7 +1120,7 @@ } static LogicalResult verify(GlobalOp op) { - if (!llvm::PointerType::isValidElementType(op.getType().getUnderlyingType())) + if (!LLVMType::isValidPointerElementType(op.getType())) return op.emitOpError( "expects type to be a valid element type for an LLVM pointer"); if (op.getParentOp() && !satisfiesLLVMModule(op.getParentOp())) @@ -1133,8 +1128,7 @@ if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { auto type = op.getType(); - if (!type.getUnderlyingType()->isArrayTy() || - !type.getArrayElementType().getUnderlyingType()->isIntegerTy(8) || + if (!type.isArrayTy() || !type.getArrayElementType().isIntegerTy(8) || type.getArrayNumElements() != strAttr.getValue().size()) return op.emitOpError( "requires an i8 array type of the length equal to that of the string " @@ -1197,8 +1191,7 @@ parser.resolveOperand(v2, typeV2, result.operands)) return failure(); auto wrappedContainerType1 = 
typeV1.dyn_cast(); - if (!wrappedContainerType1 || - !wrappedContainerType1.getUnderlyingType()->isVectorTy()) + if (!wrappedContainerType1 || !wrappedContainerType1.isVectorTy()) return parser.emitError( loc, "expected LLVM IR dialect vector type for operand #1"); auto vType = LLVMType::getVectorTy( @@ -1239,7 +1232,7 @@ if (argAttrs.empty()) return; - unsigned numInputs = type.getUnderlyingType()->getFunctionNumParams(); + unsigned numInputs = type.getFunctionNumParams(); assert(numInputs == argAttrs.size() && "expected as many argument attribute lists as arguments"); SmallString<8> argAttrName; @@ -1374,7 +1367,7 @@ // getNumArguments hook not failing. LogicalResult LLVMFuncOp::verifyType() { auto llvmType = getTypeAttr().getValue().dyn_cast_or_null(); - if (!llvmType || !llvmType.getUnderlyingType()->isFunctionTy()) + if (!llvmType || !llvmType.isFunctionTy()) return emitOpError("requires '" + getTypeAttrName() + "' attribute of wrapped LLVM function type"); @@ -1384,7 +1377,7 @@ // Hook for OpTrait::FunctionLike, returns the number of function arguments. // Depends on the type attribute being correct as checked by verifyType unsigned LLVMFuncOp::getNumFuncArguments() { - return getType().getUnderlyingType()->getFunctionNumParams(); + return getType().getFunctionNumParams(); } // Hook for OpTrait::FunctionLike, returns the number of function results. 
@@ -1424,8 +1417,7 @@ if (op.isVarArg()) return op.emitOpError("only external functions can be variadic"); - auto *funcType = cast(op.getType().getUnderlyingType()); - unsigned numArguments = funcType->getNumParams(); + unsigned numArguments = op.getType().getFunctionNumParams(); Block &entryBlock = op.front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); @@ -1433,7 +1425,7 @@ if (!argLLVMType) return op.emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (funcType->getParamType(i) != argLLVMType.getUnderlyingType()) + if (op.getType().getFunctionParamType(i) != argLLVMType) return op.emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -1566,7 +1558,7 @@ return op.emitOpError( "expected LLVM IR result type to match type for operand #1"); if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) { - if (!valType.getUnderlyingType()->isFloatingPointTy()) + if (!valType.isFloatingPointTy()) return op.emitOpError("expected LLVM IR floating point type"); } else if (op.bin_op() == AtomicBinOp::xchg) { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && @@ -1842,6 +1834,13 @@ return getImpl()->underlyingType; } +void LLVMType::getUnderlyingTypes(ArrayRef types, + SmallVectorImpl &result) { + result.reserve(result.size() + types.size()); + for (LLVMType ty : types) + result.push_back(ty.getUnderlyingType()); +} + /// Array type utilities. LLVMType LLVMType::getArrayElementType() { return get(getContext(), getUnderlyingType()->getArrayElementType()); @@ -1861,6 +1860,9 @@ return llvm::cast(getUnderlyingType()) ->getNumElements(); } +llvm::ElementCount LLVMType::getVectorElementCount() { + return llvm::cast(getUnderlyingType())->getElementCount(); +} bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); } /// Function type utilities. 
@@ -1876,6 +1878,9 @@ llvm::cast(getUnderlyingType())->getReturnType()); } bool LLVMType::isFunctionTy() { return getUnderlyingType()->isFunctionTy(); } +bool LLVMType::isFunctionVarArg() { + return getUnderlyingType()->isFunctionVarArg(); +} /// Pointer type utilities. LLVMType LLVMType::getPointerTo(unsigned addrSpace) { @@ -1888,6 +1893,9 @@ return get(getContext(), getUnderlyingType()->getPointerElementType()); } bool LLVMType::isPointerTy() { return getUnderlyingType()->isPointerTy(); } +bool LLVMType::isValidPointerElementType(LLVMType type) { + return llvm::PointerType::isValidElementType(type.getUnderlyingType()); +} /// Struct type utilities. LLVMType LLVMType::getStructElementType(unsigned i) { @@ -1974,18 +1982,12 @@ isPacked); }); } -inline static SmallVector -toUnderlyingTypes(ArrayRef elements) { - SmallVector llvmElements; - for (auto elt : elements) - llvmElements.push_back(elt.getUnderlyingType()); - return llvmElements; -} LLVMType LLVMType::createStructTy(LLVMDialect *dialect, ArrayRef elements, Optional name, bool isPacked) { StringRef sr = name.hasValue() ? 
*name : ""; - SmallVector llvmElements(toUnderlyingTypes(elements)); + SmallVector llvmElements; + getUnderlyingTypes(elements, llvmElements); return getLocked(dialect, [=] { auto *rv = llvm::StructType::create(dialect->getLLVMContext(), sr); if (!llvmElements.empty()) @@ -1997,7 +1999,8 @@ ArrayRef elements, bool isPacked) { llvm::StructType *st = llvm::cast(structType.getUnderlyingType()); - SmallVector llvmElements(toUnderlyingTypes(elements)); + SmallVector llvmElements; + getUnderlyingTypes(elements, llvmElements); return getLocked(&structType.getDialect(), [=] { st->setBody(llvmElements, isPacked); return st; @@ -2017,6 +2020,10 @@ bool LLVMType::isVoidTy() { return getUnderlyingType()->isVoidTy(); } +llvm::Type *mlir::LLVM::convertLLVMType(LLVMType type) { + return type.getUnderlyingType(); +} + //===----------------------------------------------------------------------===// // Utility functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -234,18 +234,17 @@ return nullptr; if (type.isIntegerTy()) - return b.getIntegerType(type.getUnderlyingType()->getIntegerBitWidth()); + return b.getIntegerType(type.getIntegerBitWidth()); - if (type.getUnderlyingType()->isFloatTy()) + if (type.isFloatTy()) return b.getF32Type(); - if (type.getUnderlyingType()->isDoubleTy()) + if (type.isDoubleTy()) return b.getF64Type(); // LLVM vectors can only contain scalars. if (type.isVectorTy()) { - auto numElements = llvm::cast(type.getUnderlyingType()) - ->getElementCount(); + auto numElements = type.getVectorElementCount(); if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; @@ -270,9 +269,7 @@ // attribute type. 
if (type.getArrayElementType().isVectorTy()) { LLVMType vectorType = type.getArrayElementType(); - auto numElements = - llvm::cast(vectorType.getUnderlyingType()) - ->getElementCount(); + auto numElements = vectorType.getVectorElementCount(); if (numElements.Scalable) { emitError(unknownLoc) << "scalable vectors not supported"; return nullptr; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -574,7 +574,7 @@ } if (auto lpOp = dyn_cast(opInst)) { - llvm::Type *ty = lpOp.getType().dyn_cast().getUnderlyingType(); + llvm::Type *ty = convertType(lpOp.getType().cast()); llvm::LandingPadInst *lpi = builder.CreateLandingPad(ty, lpOp.getNumOperands()); @@ -661,7 +661,7 @@ if (!wrappedType) return emitError(bb.front().getLoc(), "block argument does not have an LLVM type"); - llvm::Type *type = wrappedType.getUnderlyingType(); + llvm::Type *type = convertType(wrappedType); llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors); valueMapping[arg] = phi; } @@ -687,7 +687,7 @@ llvm::sys::SmartScopedLock scopedLock( llvmDialect->getLLVMContextMutex()); for (auto op : getModuleBody(mlirModule).getOps()) { - llvm::Type *type = op.getType().getUnderlyingType(); + llvm::Type *type = convertType(op.getType()); llvm::Constant *cst = llvm::UndefValue::get(type); if (op.getValueOrNull()) { // String attributes are treated separately because they cannot appear as @@ -826,7 +826,7 @@ // NB: Attribute already verified to be boolean, so check if we can indeed // attach the attribute to this argument, based on its type. 
auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.getUnderlyingType()->isPointerTy()) + if (!argTy.isPointerTy()) return func.emitError( "llvm.noalias attribute attached to LLVM non-pointer argument"); if (attr.getValue()) @@ -837,7 +837,7 @@ // NB: Attribute already verified to be int, so check if we can indeed // attach the attribute to this argument, based on its type. auto argTy = mlirArg.getType().dyn_cast(); - if (!argTy.getUnderlyingType()->isPointerTy()) + if (!argTy.isPointerTy()) return func.emitError( "llvm.align attribute attached to LLVM non-pointer argument"); llvmArg.addAttrs( @@ -896,7 +896,7 @@ for (auto function : getModuleBody(mlirModule).getOps()) { llvm::FunctionCallee llvmFuncCst = llvmModule->getOrInsertFunction( function.getName(), - cast(function.getType().getUnderlyingType())); + cast(convertType(function.getType()))); llvm::Function *llvmFunc = cast(llvmFuncCst.getCallee()); llvmFunc->setLinkage(convertLinkageToLLVM(function.linkage())); functionMapping[function.getName()] = llvmFunc; @@ -928,6 +928,10 @@ return success(); } +llvm::Type *ModuleTranslation::convertType(LLVMType type) { + return LLVM::convertLLVMType(type); +} + /// A helper to look up remapped operands in the value remapping table.` SmallVector ModuleTranslation::lookupValues(ValueRange values) { diff --git a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRConversionGen.cpp @@ -135,8 +135,7 @@ } else if (isResultName(op, name)) { bs << formatv("valueMapping[op.{0}()]", name); } else if (name == "_resultType") { - bs << "op.getResult().getType().cast()." - "getUnderlyingType()"; + bs << "convertType(op.getResult().getType().cast())"; } else if (name == "_hasResult") { bs << "opInst.getNumResults() == 1"; } else if (name == "_location") {