Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1320,6 +1320,9 @@
   bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, Align Alignment,
                                     unsigned AddrSpace) const;
 
+  /// \returns True if it is legal to vectorize the given element type.
+  bool isElementTypeLegalForScalableVector(Type *Ty) const;
+
   /// \returns True if it is legal to vectorize the given reduction kind.
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                    ElementCount VF) const;
@@ -1704,6 +1707,7 @@
   virtual bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes,
                                             Align Alignment,
                                             unsigned AddrSpace) const = 0;
+  virtual bool isElementTypeLegalForScalableVector(Type *Ty) const = 0;
   virtual bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                            ElementCount VF) const = 0;
   virtual unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
@@ -2253,6 +2257,9 @@
     return Impl.isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment,
                                              AddrSpace);
   }
+  bool isElementTypeLegalForScalableVector(Type *Ty) const override {
+    return Impl.isElementTypeLegalForScalableVector(Ty);
+  }
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                    ElementCount VF) const override {
     return Impl.isLegalToVectorizeReduction(RdxDesc, VF);
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -720,6 +720,8 @@
     return true;
   }
 
+  bool isElementTypeLegalForScalableVector(Type *Ty) const { return true; }
+
   unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
                                unsigned ChainSizeInBytes,
                                VectorType *VecTy) const {
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -997,6 +997,10 @@
                                                AddrSpace);
 }
 
+bool TargetTransformInfo::isElementTypeLegalForScalableVector(Type *Ty) const {
+  return TTIImpl->isElementTypeLegalForScalableVector(Ty);
+}
+
 bool TargetTransformInfo::isLegalToVectorizeReduction(
     RecurrenceDescriptor RdxDesc, ElementCount VF) const {
   return TTIImpl->isLegalToVectorizeReduction(RdxDesc, VF);
Index: llvm/lib/CodeGen/TargetLoweringBase.cpp
===================================================================
--- llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -995,7 +995,7 @@
     // <4 x i140> -> <2 x i140>
     if (LK.first == TypeExpandInteger) {
       if (VT.getVectorElementCount() == ElementCount::getScalable(1))
-        report_fatal_error("Cannot legalize this scalable vector");
+        return LegalizeKind(TypeScalarizeScalableVector, EltVT);
       return LegalizeKind(TypeSplitVector,
                           VT.getHalfNumVectorElementsVT(Context));
     }
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -288,6 +288,8 @@
 
   bool supportsScalableVectors() const { return ST->hasSVE(); }
 
+  bool isElementTypeLegalForScalableVector(Type *Ty) const;
+
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                    ElementCount VF) const;
 
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1160,6 +1160,10 @@
   if (!isa<ScalableVectorType>(Src))
     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
                                         CostKind);
+
+  if (!isElementTypeLegalForScalableVector(Src->getScalarType()))
+    return InstructionCost::getInvalid();
+
   auto LT = TLI->getTypeLegalizationCost(DL, Src);
   return LT.first * 2;
 }
@@ -1171,6 +1175,10 @@
   if (!isa<ScalableVectorType>(DataTy))
     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                          Alignment, CostKind, I);
+
+  if (!isElementTypeLegalForScalableVector(DataTy->getScalarType()))
+    return InstructionCost::getInvalid();
+
   auto *VT = cast<VectorType>(DataTy);
   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
   ElementCount LegalVF = LT.second.getVectorElementCount();
@@ -1198,6 +1206,11 @@
     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                   CostKind);
 
+  // Return an invalid cost if Ty is unsupported
+  VectorType *VTy = dyn_cast<VectorType>(Ty);
+  if (VTy && !isElementTypeLegalForScalableVector(VTy->getScalarType()))
+    return InstructionCost::getInvalid();
+
   auto LT = TLI->getTypeLegalizationCost(DL, Ty);
 
   // TODO: consider latency as well for TCK_SizeAndLatency.
@@ -1501,6 +1514,10 @@
   return Considerable;
 }
 
+bool AArch64TTIImpl::isElementTypeLegalForScalableVector(Type *Ty) const {
+  return Ty->isIntegerTy(1) || isLegalElementTypeForSVE(Ty);
+}
+
 bool AArch64TTIImpl::isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                                 ElementCount VF) const {
   if (!VF.isScalable())
Index: llvm/test/Analysis/CostModel/AArch64/sve-illegal-types.ll
===================================================================
--- /dev/null
+++ llvm/test/Analysis/CostModel/AArch64/sve-illegal-types.ll
@@ -0,0 +1,49 @@
+; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 4 x i128> @load_nxvi128(<vscale x 4 x i128>* %val) {
+; CHECK-LABEL: 'load_nxvi128'
+; CHECK-NEXT: Invalid cost for instruction: %load = load <vscale x 4 x i128>, <vscale x 4 x i128>* %val
+  %load = load <vscale x 4 x i128>, <vscale x 4 x i128>* %val
+  ret <vscale x 4 x i128> %load
+}
+
+define void @store_nxvi128(<vscale x 4 x i128>* %ptrs, <vscale x 4 x i128> %val) {
+; CHECK-LABEL: 'store_nxvi128'
+; CHECK-NEXT: Invalid cost for instruction: store <vscale x 4 x i128> %val, <vscale x 4 x i128>* %ptrs
+  store <vscale x 4 x i128> %val, <vscale x 4 x i128>* %ptrs
+  ret void
+}
+
+define <vscale x 4 x fp128> @masked_load_nxvfp128(<vscale x 4 x fp128>* %val, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru) {
+; CHECK-LABEL: 'masked_load_nxvfp128'
+; CHECK-NEXT: Invalid cost for instruction: %mload = call <vscale x 4 x fp128> @llvm.masked.load.nxv4f128.p0nxv4f128(<vscale x 4 x fp128>* %val, i32 8, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru)
+  %mload = call <vscale x 4 x fp128> @llvm.masked.load.nxv4f128(<vscale x 4 x fp128>* %val, i32 8, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru)
+  ret <vscale x 4 x fp128> %mload
+}
+
+define void @masked_store_nxvfp128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: 'masked_store_nxvfp128'
+; CHECK-NEXT: Invalid cost for instruction: call void @llvm.masked.store.nxv4f128.p0nxv4f128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, i32 8, <vscale x 4 x i1> %mask)
+  call void @llvm.masked.store.nxv4f128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, i32 8, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define <vscale x 2 x i128> @masked_gather_nxv2i128(<vscale x 2 x i128*> %ld, <vscale x 2 x i1> %masks, <vscale x 2 x i128> %passthru) {
+; CHECK-LABEL: 'masked_gather_nxv2i128'
+; CHECK-NEXT: Invalid cost for instruction: %mgather = call <vscale x 2 x i128> @llvm.masked.gather.nxv2i128.nxv2p0i128(<vscale x 2 x i128*> %ld, i32 0, <vscale x 2 x i1> %masks, <vscale x 2 x i128> %passthru)
+  %mgather = call <vscale x 2 x i128> @llvm.masked.gather.nxv2i128(<vscale x 2 x i128*> %ld, i32 0, <vscale x 2 x i1> %masks, <vscale x 2 x i128> %passthru)
+  ret <vscale x 2 x i128> %mgather
+}
+
+define void @masked_scatter_nxv4i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, <vscale x 4 x i1> %masks) {
+; CHECK-LABEL: 'masked_scatter_nxv4i128'
+; CHECK-NEXT: Invalid cost for instruction: call void @llvm.masked.scatter.nxv4i128.nxv4p0i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, i32 0, <vscale x 4 x i1> %masks)
+  call void @llvm.masked.scatter.nxv4i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, i32 0, <vscale x 4 x i1> %masks)
+  ret void
+}
+
+declare <vscale x 4 x fp128> @llvm.masked.load.nxv4f128(<vscale x 4 x fp128>*, i32, <vscale x 4 x i1>, <vscale x 4 x fp128>)
+declare <vscale x 2 x i128> @llvm.masked.gather.nxv2i128(<vscale x 2 x i128*>, i32, <vscale x 2 x i1>, <vscale x 2 x i128>)
+
+declare void @llvm.masked.store.nxv4f128(<vscale x 4 x fp128>, <vscale x 4 x fp128>*, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4i128(<vscale x 4 x i128>, <vscale x 4 x i128*>, i32, <vscale x 4 x i1>)