Return llvm::ScalableVectorType from the clang SVE type helpers.

getSVEPredType, getSVEType and getSVEVectorForElementType now return
llvm::ScalableVectorType * and construct their results with
ScalableVectorType::get(EltTy, N) rather than VectorType::get(EltTy, {N, true}).
Callers that previously rebuilt a vector type from getElementCount() now pass
the ScalableVectorType * they already hold as the second argument of
ScalableVectorType::get.

NOTE(review): this patch text is collapsed onto a few physical lines (hunk
lines share a line and blank context lines are not visible); it must be
re-split into one diff line per source line before git apply / patch can
consume it.

diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -7542,66 +7542,68 @@ // Return the llvm predicate vector type corresponding to the specified element // TypeFlags. -llvm::VectorType* CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) { +llvm::ScalableVectorType * +CodeGenFunction::getSVEPredType(SVETypeFlags TypeFlags) { switch (TypeFlags.getEltType()) { default: llvm_unreachable("Unhandled SVETypeFlag!"); case SVETypeFlags::EltTyInt8: - return llvm::VectorType::get(Builder.getInt1Ty(), { 16, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16); case SVETypeFlags::EltTyInt16: - return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8); case SVETypeFlags::EltTyInt32: - return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4); case SVETypeFlags::EltTyInt64: - return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2); case SVETypeFlags::EltTyFloat16: - return llvm::VectorType::get(Builder.getInt1Ty(), { 8, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8); case SVETypeFlags::EltTyFloat32: - return llvm::VectorType::get(Builder.getInt1Ty(), { 4, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4); case SVETypeFlags::EltTyFloat64: - return llvm::VectorType::get(Builder.getInt1Ty(), { 2, true }); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2); } } // Return the llvm vector type corresponding to the specified element TypeFlags. 
-llvm::VectorType *CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) { +llvm::ScalableVectorType * +CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) { switch (TypeFlags.getEltType()) { default: llvm_unreachable("Invalid SVETypeFlag!"); case SVETypeFlags::EltTyInt8: - return llvm::VectorType::get(Builder.getInt8Ty(), {16, true}); + return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16); case SVETypeFlags::EltTyInt16: - return llvm::VectorType::get(Builder.getInt16Ty(), {8, true}); + return llvm::ScalableVectorType::get(Builder.getInt16Ty(), 8); case SVETypeFlags::EltTyInt32: - return llvm::VectorType::get(Builder.getInt32Ty(), {4, true}); + return llvm::ScalableVectorType::get(Builder.getInt32Ty(), 4); case SVETypeFlags::EltTyInt64: - return llvm::VectorType::get(Builder.getInt64Ty(), {2, true}); + return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2); case SVETypeFlags::EltTyFloat16: - return llvm::VectorType::get(Builder.getHalfTy(), {8, true}); + return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8); case SVETypeFlags::EltTyFloat32: - return llvm::VectorType::get(Builder.getFloatTy(), {4, true}); + return llvm::ScalableVectorType::get(Builder.getFloatTy(), 4); case SVETypeFlags::EltTyFloat64: - return llvm::VectorType::get(Builder.getDoubleTy(), {2, true}); + return llvm::ScalableVectorType::get(Builder.getDoubleTy(), 2); case SVETypeFlags::EltTyBool8: - return llvm::VectorType::get(Builder.getInt1Ty(), {16, true}); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 16); case SVETypeFlags::EltTyBool16: - return llvm::VectorType::get(Builder.getInt1Ty(), {8, true}); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 8); case SVETypeFlags::EltTyBool32: - return llvm::VectorType::get(Builder.getInt1Ty(), {4, true}); + return llvm::ScalableVectorType::get(Builder.getInt1Ty(), 4); case SVETypeFlags::EltTyBool64: - return llvm::VectorType::get(Builder.getInt1Ty(), {2, true}); + return 
llvm::ScalableVectorType::get(Builder.getInt1Ty(), 2); } } constexpr unsigned SVEBitsPerBlock = 128; -static llvm::VectorType* getSVEVectorForElementType(llvm::Type *EltTy) { +static llvm::ScalableVectorType *getSVEVectorForElementType(llvm::Type *EltTy) { unsigned NumElts = SVEBitsPerBlock / EltTy->getScalarSizeInBits(); - return llvm::VectorType::get(EltTy, { NumElts, true }); + return llvm::ScalableVectorType::get(EltTy, NumElts); } // Reinterpret the input predicate so that it can be used to correctly isolate @@ -7640,8 +7642,8 @@ SmallVectorImpl<Value *> &Ops, unsigned IntID) { auto *ResultTy = getSVEType(TypeFlags); - auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags), - ResultTy->getElementCount()); + auto *OverloadedTy = + llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), ResultTy); // At the ACLE level there's only one predicate type, svbool_t, which is // mapped to <n x 16 x i1>. However, this might be incompatible with the @@ -7692,8 +7694,8 @@ SmallVectorImpl<Value *> &Ops, unsigned IntID) { auto *SrcDataTy = getSVEType(TypeFlags); - auto *OverloadedTy = llvm::VectorType::get(SVEBuiltinMemEltTy(TypeFlags), - SrcDataTy->getElementCount()); + auto *OverloadedTy = + llvm::ScalableVectorType::get(SVEBuiltinMemEltTy(TypeFlags), SrcDataTy); // In ACLE the source data is passed in the last argument, whereas in LLVM IR // it's the first argument. Move it accordingly. @@ -7748,7 +7750,7 @@ unsigned BuiltinID) { auto *MemEltTy = SVEBuiltinMemEltTy(TypeFlags); auto *VectorTy = getSVEVectorForElementType(MemEltTy); - auto *MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount()); + auto *MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy); Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy); Value *BasePtr = Ops[1]; @@ -7778,8 +7780,8 @@ // The vector type that is returned may be different from the // eventual type loaded from memory.
- auto VectorTy = cast<llvm::VectorType>(ReturnTy); - auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount()); + auto *VectorTy = cast<llvm::ScalableVectorType>(ReturnTy); + auto *MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy); Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy); Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo()); @@ -7803,8 +7805,8 @@ // The vector type that is stored may be different from the // eventual type stored to memory. - auto VectorTy = cast<llvm::VectorType>(Ops.back()->getType()); - auto MemoryTy = llvm::VectorType::get(MemEltTy, VectorTy->getElementCount()); + auto *VectorTy = cast<llvm::ScalableVectorType>(Ops.back()->getType()); + auto *MemoryTy = llvm::ScalableVectorType::get(MemEltTy, VectorTy); Value *Predicate = EmitSVEPredicateCast(Ops[0], MemoryTy); Value *BasePtr = Builder.CreateBitCast(Ops[1], MemoryTy->getPointerTo());
CodeGenFunction.h: update the getSVEType/getSVEPredType declarations to return
llvm::ScalableVectorType *, matching the CGBuiltin.cpp definitions. The
EmitSVEPredicateCast parameter stays llvm::VectorType *; the ScalableVectorType
values passed to it presumably convert derived-to-base (confirm against
llvm/IR/DerivedTypes.h).
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3911,8 +3911,8 @@ SmallVector<llvm::Type *, 2> getSVEOverloadTypes(SVETypeFlags TypeFlags, ArrayRef<llvm::Value *> Ops); llvm::Type *getEltType(SVETypeFlags TypeFlags); - llvm::VectorType *getSVEType(const SVETypeFlags &TypeFlags); - llvm::VectorType *getSVEPredType(SVETypeFlags TypeFlags); + llvm::ScalableVectorType *getSVEType(const SVETypeFlags &TypeFlags); + llvm::ScalableVectorType *getSVEPredType(SVETypeFlags TypeFlags); llvm::Value *EmitSVEDupX(llvm::Value *Scalar); llvm::Value *EmitSVEPredicateCast(llvm::Value *Pred, llvm::VectorType *VTy); llvm::Value *EmitSVEGatherLoad(SVETypeFlags TypeFlags,