Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1320,6 +1320,9 @@
   bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes, Align Alignment,
                                     unsigned AddrSpace) const;
 
+  /// \returns True if it is legal to vectorize the given element type.
+  bool isLegalToVectorizeElementType(Type *Ty, bool VectorIsScalable) const;
+
   /// \returns True if it is legal to vectorize the given reduction kind.
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                    ElementCount VF) const;
@@ -1704,6 +1707,8 @@
   virtual bool isLegalToVectorizeStoreChain(unsigned ChainSizeInBytes,
                                             Align Alignment,
                                             unsigned AddrSpace) const = 0;
+  virtual bool isLegalToVectorizeElementType(Type *Ty,
+                                             bool VectorIsScalable) const = 0;
   virtual bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                            ElementCount VF) const = 0;
   virtual unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
@@ -2253,6 +2258,9 @@
     return Impl.isLegalToVectorizeStoreChain(ChainSizeInBytes, Alignment,
                                              AddrSpace);
   }
+  bool isLegalToVectorizeElementType(Type *Ty, bool VectorIsScalable) const override {
+    return Impl.isLegalToVectorizeElementType(Ty, VectorIsScalable);
+  }
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                    ElementCount VF) const override {
     return Impl.isLegalToVectorizeReduction(RdxDesc, VF);
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -715,6 +715,10 @@
     return true;
   }
 
+  bool isLegalToVectorizeElementType(Type *Ty, bool VectorIsScalable) const {
+    return true;
+  }
+
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                   ElementCount VF) const {
     return true;
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -997,6 +997,11 @@
                                                AddrSpace);
 }
 
+bool TargetTransformInfo::isLegalToVectorizeElementType(Type *Ty,
+                                                        bool VectorIsScalable) const {
+  return TTIImpl->isLegalToVectorizeElementType(Ty, VectorIsScalable);
+}
+
 bool TargetTransformInfo::isLegalToVectorizeReduction(
     RecurrenceDescriptor RdxDesc, ElementCount VF) const {
   return TTIImpl->isLegalToVectorizeReduction(RdxDesc, VF);
Index: llvm/lib/CodeGen/TargetLoweringBase.cpp
===================================================================
--- llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -995,7 +995,7 @@
     // <4 x i140> -> <2 x i140>
     if (LK.first == TypeExpandInteger) {
       if (VT.getVectorElementCount() == ElementCount::getScalable(1))
-        report_fatal_error("Cannot legalize this scalable vector");
+        return LegalizeKind(TypeScalarizeScalableVector, EltVT);
       return LegalizeKind(TypeSplitVector,
                           VT.getHalfNumVectorElementsVT(Context));
     }
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -288,6 +288,8 @@
 
   bool supportsScalableVectors() const { return ST->hasSVE(); }
 
+  bool isLegalToVectorizeElementType(Type *Ty, bool VectorIsScalable) const;
+
   bool isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                   ElementCount VF) const;
 
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1160,6 +1160,13 @@
   if (!isa<ScalableVectorType>(Src))
     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
                                         CostKind);
+
+  VectorType *VTy = dyn_cast<VectorType>(Src);
+  if (VTy &&
+      !isLegalToVectorizeElementType(VTy->getScalarType(),
+                                     VTy->getElementCount().isScalable()))
+    return InstructionCost::getInvalid();
+
   auto LT = TLI->getTypeLegalizationCost(DL, Src);
   return LT.first * 2;
 }
@@ -1171,6 +1178,13 @@
   if (!isa<ScalableVectorType>(DataTy))
     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                          Alignment, CostKind, I);
+
+  VectorType *VTy = dyn_cast<VectorType>(DataTy);
+  if (VTy &&
+      !isLegalToVectorizeElementType(VTy->getScalarType(),
+                                     VTy->getElementCount().isScalable()))
+    return InstructionCost::getInvalid();
+
   auto *VT = cast<VectorType>(DataTy);
   auto LT = TLI->getTypeLegalizationCost(DL, DataTy);
   ElementCount LegalVF = LT.second.getVectorElementCount();
@@ -1198,6 +1212,13 @@
     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                   CostKind);
 
+  // Return an invalid cost if Ty is unsupported
+  VectorType *VTy = dyn_cast<VectorType>(Ty);
+  if (VTy &&
+      !isLegalToVectorizeElementType(Ty->getScalarType(),
+                                     VTy->getElementCount().isScalable()))
+    return InstructionCost::getInvalid();
+
   auto LT = TLI->getTypeLegalizationCost(DL, Ty);
 
   // TODO: consider latency as well for TCK_SizeAndLatency.
@@ -1501,6 +1522,17 @@
   return Considerable;
 }
 
+bool AArch64TTIImpl::isLegalToVectorizeElementType(Type *Ty,
+                                                   bool VectorIsScalable) const {
+  if (Ty->isIntegerTy(1))
+    return true;
+
+  if (VectorIsScalable)
+    return isLegalElementTypeForSVE(Ty);
+
+  return true;
+}
+
 bool AArch64TTIImpl::isLegalToVectorizeReduction(RecurrenceDescriptor RdxDesc,
                                                  ElementCount VF) const {
   if (!VF.isScalable())
Index: llvm/test/Analysis/CostModel/AArch64/sve-illegal-types.ll
===================================================================
--- /dev/null
+++ llvm/test/Analysis/CostModel/AArch64/sve-illegal-types.ll
@@ -0,0 +1,49 @@
+; RUN: opt -cost-model -analyze -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+define <vscale x 1 x i128> @load_nxvi128(<vscale x 1 x i128>* %val) {
+; CHECK-LABEL: 'load_nxvi128'
+; CHECK-NEXT: Invalid cost for instruction: %load = load <vscale x 1 x i128>, <vscale x 1 x i128>* %val
+  %load = load <vscale x 1 x i128>, <vscale x 1 x i128>* %val
+  ret <vscale x 1 x i128> %load
+}
+
+define void @store_nxvi128(<vscale x 1 x i128>* %ptrs, <vscale x 1 x i128> %val) {
+; CHECK-LABEL: 'store_nxvi128'
+; CHECK-NEXT: Invalid cost for instruction: store <vscale x 1 x i128> %val, <vscale x 1 x i128>* %ptrs
+  store <vscale x 1 x i128> %val, <vscale x 1 x i128>* %ptrs
+  ret void
+}
+
+define <vscale x 4 x fp128> @masked_load_nxvfp128(<vscale x 4 x fp128>* %val, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru) {
+; CHECK-LABEL: 'masked_load_nxvfp128'
+; CHECK-NEXT: Invalid cost for instruction: %mload = call <vscale x 4 x fp128> @llvm.masked.load.nxv4f128.p0nxv4f128(<vscale x 4 x fp128>* %val, i32 8, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru)
+  %mload = call <vscale x 4 x fp128> @llvm.masked.load.nxv4f128(<vscale x 4 x fp128>* %val, i32 8, <vscale x 4 x i1> %mask, <vscale x 4 x fp128> %passthru)
+  ret <vscale x 4 x fp128> %mload
+}
+
+define void @masked_store_nxvfp128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: 'masked_store_nxvfp128'
+; CHECK-NEXT: Invalid cost for instruction: call void @llvm.masked.store.nxv4f128.p0nxv4f128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, i32 8, <vscale x 4 x i1> %mask)
+  call void @llvm.masked.store.nxv4f128.p0nxv4f128(<vscale x 4 x fp128> %val, <vscale x 4 x fp128>* %ptrs, i32 8, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define <vscale x 4 x i128> @masked_gather_nxv4i128(<vscale x 4 x i128*> %ld, <vscale x 4 x i1> %masks, <vscale x 4 x i128> %passthru) {
+; CHECK-LABEL: 'masked_gather_nxv4i128'
+; CHECK-NEXT: Invalid cost for instruction: %mgather = call <vscale x 4 x i128> @llvm.masked.gather.nxv4i128.nxv4p0i128(<vscale x 4 x i128*> %ld, i32 0, <vscale x 4 x i1> %masks, <vscale x 4 x i128> %passthru)
+  %mgather = call <vscale x 4 x i128> @llvm.masked.gather.nxv4i128(<vscale x 4 x i128*> %ld, i32 0, <vscale x 4 x i1> %masks, <vscale x 4 x i128> %passthru)
+  ret <vscale x 4 x i128> %mgather
+}
+
+define void @masked_scatter_nxv4i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, <vscale x 4 x i1> %masks) {
+; CHECK-LABEL: 'masked_scatter_nxv4i128'
+; CHECK-NEXT: Invalid cost for instruction: call void @llvm.masked.scatter.nxv4i128.nxv4p0i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, i32 0, <vscale x 4 x i1> %masks)
+  call void @llvm.masked.scatter.nxv4i128(<vscale x 4 x i128> %val, <vscale x 4 x i128*> %ptrs, i32 0, <vscale x 4 x i1> %masks)
+  ret void
+}
+
+declare <vscale x 4 x fp128> @llvm.masked.load.nxv4f128(<vscale x 4 x fp128>*, i32, <vscale x 4 x i1>, <vscale x 4 x fp128>)
+declare <vscale x 4 x i128> @llvm.masked.gather.nxv4i128(<vscale x 4 x i128*>, i32, <vscale x 4 x i1>, <vscale x 4 x i128>)
+
+declare void @llvm.masked.store.nxv4f128.p0nxv4f128(<vscale x 4 x fp128>, <vscale x 4 x fp128>*, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4i128(<vscale x 4 x i128>, <vscale x 4 x i128*>, i32, <vscale x 4 x i1>)