Migrate MLIR off the deprecated vector getters on llvm::Type.

Each use of llvm::Type::getVectorNumElements(), getVectorElementType(),
getVectorIsScalable() and getVectorElementCount() is replaced by an explicit
llvm::cast<llvm::VectorType>(...) followed by the matching VectorType accessor
(getNumElements(), getElementType(), isScalable(), getElementCount()).
llvm::cast asserts (in assertion-enabled builds) that the operand really is a
VectorType, so a non-vector type is still caught at the cast site.

Files touched:
  - StandardToLLVM.cpp: splat-constant shape in the vectorized lowering lambda.
  - LLVMDialect.cpp: vector result type in the compare-op parser, plus
    LLVMType::getVectorElementType() / getVectorNumElements().
  - ConvertFromLLVMIR.cpp: VectorTyID case of type import, and the two
    scalable-vector checks when building attribute types.

NOTE(review): this patch was recovered from a copy that had lost its line
breaks, indentation and template arguments. The template arguments
(<llvm::VectorType>, <LLVM::ConstantOp>) and the leading whitespace of context
lines were reconstructed; if the context does not match exactly, apply with
`git apply --ignore-whitespace`.

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1804,9 +1804,10 @@
         op, operands, typeConverter,
         [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
           auto splatAttr = SplatElementsAttr::get(
-              mlir::VectorType::get(
-                  {llvmVectorTy.getUnderlyingType()->getVectorNumElements()},
-                  floatType),
+              mlir::VectorType::get({(unsigned)cast<llvm::VectorType>(
+                                         llvmVectorTy.getUnderlyingType())
+                                         ->getNumElements()},
+                                    floatType),
               floatOne);
           auto one =
               rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -102,7 +102,8 @@
     return parser.emitError(trailingTypeLoc, "expected LLVM IR dialect type");
   if (argType.getUnderlyingType()->isVectorTy())
     resultType = LLVMType::getVectorTy(
-        resultType, argType.getUnderlyingType()->getVectorNumElements());
+        resultType, llvm::cast<llvm::VectorType>(argType.getUnderlyingType())
+                        ->getNumElements());
 
   result.addTypes({resultType});
   return success();
@@ -1772,10 +1773,12 @@
 
 /// Vector type utilities.
 LLVMType LLVMType::getVectorElementType() {
-  return get(getContext(), getUnderlyingType()->getVectorElementType());
+  return get(
+      getContext(),
+      llvm::cast<llvm::VectorType>(getUnderlyingType())->getElementType());
 }
 unsigned LLVMType::getVectorNumElements() {
-  return getUnderlyingType()->getVectorNumElements();
+  return llvm::cast<llvm::VectorType>(getUnderlyingType())->getNumElements();
 }
 bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
 
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -169,14 +169,15 @@
     return LLVMType::getArrayTy(elementType, type->getArrayNumElements());
   }
   case llvm::Type::VectorTyID: {
-    if (type->getVectorIsScalable()) {
+    auto *typeVTy = llvm::cast<llvm::VectorType>(type);
+    if (typeVTy->isScalable()) {
       emitError(unknownLoc) << "scalable vector types not supported";
       return nullptr;
     }
-    LLVMType elementType = processType(type->getVectorElementType());
+    LLVMType elementType = processType(typeVTy->getElementType());
     if (!elementType)
       return nullptr;
-    return LLVMType::getVectorTy(elementType, type->getVectorNumElements());
+    return LLVMType::getVectorTy(elementType, typeVTy->getNumElements());
   }
   case llvm::Type::VoidTyID:
     return LLVMType::getVoidTy(dialect);
@@ -243,7 +244,8 @@
 
   // LLVM vectors can only contain scalars.
   if (type.isVectorTy()) {
-    auto numElements = type.getUnderlyingType()->getVectorElementCount();
+    auto numElements = llvm::cast<llvm::VectorType>(type.getUnderlyingType())
+                           ->getElementCount();
     if (numElements.Scalable) {
       emitError(unknownLoc) << "scalable vectors not supported";
       return nullptr;
@@ -269,7 +271,8 @@
     if (type.getArrayElementType().isVectorTy()) {
       LLVMType vectorType = type.getArrayElementType();
       auto numElements =
-          vectorType.getUnderlyingType()->getVectorElementCount();
+          llvm::cast<llvm::VectorType>(vectorType.getUnderlyingType())
+              ->getElementCount();
       if (numElements.Scalable) {
         emitError(unknownLoc) << "scalable vectors not supported";
         return nullptr;