diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -907,6 +907,9 @@
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
+
+  bool useSVEForFixedLengthVectors() const;
+  bool useSVEForFixedLengthVectorVT(MVT VT) const;
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -184,6 +184,16 @@
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
 
+    if (useSVEForFixedLengthVectors()) {
+      for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+
+      for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
+        if (useSVEForFixedLengthVectorVT(VT))
+          addRegisterClass(VT, &AArch64::ZPRRegClass);
+    }
+
     for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
       setOperationAction(ISD::SADDSAT, VT, Legal);
       setOperationAction(ISD::UADDSAT, VT, Legal);
@@ -3474,6 +3484,41 @@
   }
 }
 
+bool AArch64TargetLowering::useSVEForFixedLengthVectors() const {
+  // Prefer NEON unless larger SVE registers are available.
+  return Subtarget->hasSVE() && Subtarget->getMinSVEVectorSizeInBits() >= 256;
+}
+
+bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(MVT VT) const {
+  assert(VT.isFixedLengthVector());
+  if (!useSVEForFixedLengthVectors())
+    return false;
+
+  // Fixed length predicates should be promoted to i8.
+  // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
+  if (VT.getVectorElementType() == MVT::i1)
+    return false;
+
+  // Don't use SVE for vectors we cannot scalarize if required.
+  if (!isTypeLegal(VT.getVectorElementType()))
+    return false;
+
+  // Ensure NEON MVTs only belong to a single register class.
+  if (VT.getSizeInBits() <= 128)
+    return false;
+
+  // Don't use SVE for types that don't fit.
+  if (VT.getSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
+    return false;
+
+  // TODO: Perhaps an artificial restriction, but worth having whilst getting
+  // the base fixed length SVE support in place.
+  if (!VT.isPow2VectorType())
+    return false;
+
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Calling Convention Implementation
 //===----------------------------------------------------------------------===//
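Note: the following standalone sketch is illustrative only and is not part of the patch. It mirrors the checks in useSVEForFixedLengthVectorVT using plain integers (the helper names and the 512-bit minimum are invented for the example, and the isTypeLegal element check is omitted because it needs TargetLowering state), to show which fixed-length types would land in the ZPR register class:

    #include <cstdio>

    // True when N is a non-zero power of two.
    static bool isPow2(unsigned N) { return N && !(N & (N - 1)); }

    // Mirrors useSVEForFixedLengthVectorVT for a vector of NumElts elements,
    // each EltBits wide, given the guaranteed minimum SVE register size.
    static bool useSVEForFixedLengthVector(unsigned EltBits, unsigned NumElts,
                                           unsigned MinSVEBits) {
      unsigned SizeInBits = EltBits * NumElts;
      if (MinSVEBits < 256)        // Prefer NEON unless SVE registers are wider.
        return false;
      if (EltBits == 1)            // Predicates are promoted to i8 instead.
        return false;
      if (SizeInBits <= 128)       // NEON-sized types keep their NEON class.
        return false;
      if (SizeInBits > MinSVEBits) // Must fit in the guaranteed register size.
        return false;
      if (!isPow2(NumElts))        // Artificial restriction noted in the TODO.
        return false;
      return true;
    }

    int main() {
      const unsigned MinSVEBits = 512; // as if -aarch64-sve-vector-bits-min=512
      printf("v4i32  -> %d\n", useSVEForFixedLengthVector(32, 4, MinSVEBits));  // 0 (128-bit, NEON)
      printf("v8i32  -> %d\n", useSVEForFixedLengthVector(32, 8, MinSVEBits));  // 1 (256-bit, SVE)
      printf("v16i32 -> %d\n", useSVEForFixedLengthVector(32, 16, MinSVEBits)); // 1 (512-bit, SVE)
      printf("v32i32 -> %d\n", useSVEForFixedLengthVector(32, 32, MinSVEBits)); // 0 (exceeds 512)
      return 0;
    }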
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.h b/llvm/lib/Target/AArch64/AArch64Subtarget.h
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -534,6 +534,12 @@
   }
 
   void mirFileLoaded(MachineFunction &MF) const override;
+
+  // Return the known range for the bit length of SVE data registers. A value
+  // of 0 means nothing is known about that particular limit beyond what's
+  // implied by the architecture.
+  unsigned getMaxSVEVectorSizeInBits() const;
+  unsigned getMinSVEVectorSizeInBits() const;
 };
 
 } // End llvm namespace
diff --git a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
--- a/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -47,6 +47,18 @@
     cl::desc("Call nonlazybind functions via direct GOT load"),
     cl::init(false), cl::Hidden);
 
+static cl::opt<unsigned> SVEVectorBitsMax(
+    "aarch64-sve-vector-bits-max",
+    cl::desc("Assume SVE vector registers are at most this big, "
+             "with zero meaning no maximum size is assumed."),
+    cl::init(0), cl::Hidden);
+
+static cl::opt<unsigned> SVEVectorBitsMin(
+    "aarch64-sve-vector-bits-min",
+    cl::desc("Assume SVE vector registers are at least this big, "
+             "with zero meaning no minimum size is assumed."),
+    cl::init(0), cl::Hidden);
+
 AArch64Subtarget &
 AArch64Subtarget::initializeSubtargetDependencies(StringRef FS,
                                                   StringRef CPUString) {
@@ -329,3 +341,25 @@
   if (!MFI.isMaxCallFrameSizeComputed())
     MFI.computeMaxCallFrameSize(MF);
 }
+
+unsigned AArch64Subtarget::getMaxSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMax % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return 0;
+  return (std::max(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
+
+unsigned AArch64Subtarget::getMinSVEVectorSizeInBits() const {
+  assert(HasSVE && "Tried to get SVE vector length without SVE support!");
+  assert(SVEVectorBitsMin % 128 == 0 &&
+         "SVE requires vector length in multiples of 128!");
+  assert((SVEVectorBitsMax >= SVEVectorBitsMin || SVEVectorBitsMax == 0) &&
+         "Minimum SVE vector size should not be larger than its maximum!");
+  if (SVEVectorBitsMax == 0)
+    return (SVEVectorBitsMin / 128) * 128;
+  return (std::min(SVEVectorBitsMin, SVEVectorBitsMax) / 128) * 128;
+}
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -98,6 +98,8 @@
 
   unsigned getRegisterBitWidth(bool Vector) const {
     if (Vector) {
+      if (ST->hasSVE())
+        return std::max(ST->getMinSVEVectorSizeInBits(), 128u);
       if (ST->hasNEON())
         return 128;
       return 0;
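Note: another illustrative sketch, not part of the patch, showing how the two command-line bounds combine in getMin/MaxSVEVectorSizeInBits (the example values are arbitrary; the real code's asserts already force multiples of 128, so the /128*128 rounding is purely defensive):

    #include <algorithm>
    #include <cstdio>

    // Zero means "no limit known" for either bound, matching the cl::opt defaults.
    static unsigned minSVEBits(unsigned Min, unsigned Max) {
      if (Max == 0)
        return (Min / 128) * 128;
      return (std::min(Min, Max) / 128) * 128;
    }

    static unsigned maxSVEBits(unsigned Min, unsigned Max) {
      if (Max == 0)
        return 0; // Nothing known beyond the architectural maximum.
      return (std::max(Min, Max) / 128) * 128;
    }

    int main() {
      // Only a minimum given: the maximum stays unknown (0).
      printf("min=%u max=%u\n", minSVEBits(384, 0), maxSVEBits(384, 0));     // min=384 max=0
      // Both bounds equal: the vector length is known exactly.
      printf("min=%u max=%u\n", minSVEBits(256, 256), maxSVEBits(256, 256)); // min=256 max=256
      return 0;
    }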
diff --git a/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
new file
--- /dev/null
+++ b/llvm/test/Analysis/CostModel/AArch64/sve-fixed-length.ll
@@ -0,0 +1,42 @@
+; RUN: opt < %s -cost-model -analyze | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=128 | FileCheck %s -D#VBITS=128
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=256 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=384 | FileCheck %s -D#VBITS=256
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=512 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=640 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=768 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=896 | FileCheck %s -D#VBITS=512
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1024 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1152 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1280 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1408 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1536 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1664 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1792 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=1920 | FileCheck %s -D#VBITS=1024
+; RUN: opt < %s -cost-model -analyze -aarch64-sve-vector-bits-min=2048 | FileCheck %s -D#VBITS=2048
+
+; VBITS represents the useful bit size of a vector register from the code
+; generator's point of view. It is clamped to power-of-2 values because
+; only power-of-2 vector lengths are considered legal, regardless of the
+; user-specified vector length.
+
+target triple = "aarch64-unknown-linux-gnu"
+
+; Ensure the cost of legalisation is removed as the vector length grows.
+define void @add() #0 {
+; CHECK-LABEL: Printing analysis 'Cost Model Analysis' for function 'add':
+; CHECK: cost of [[#div(127,VBITS)+1]] for instruction: %add128 = add <4 x i32> undef, undef
+; CHECK: cost of [[#div(255,VBITS)+1]] for instruction: %add256 = add <8 x i32> undef, undef
+; CHECK: cost of [[#div(511,VBITS)+1]] for instruction: %add512 = add <16 x i32> undef, undef
+; CHECK: cost of [[#div(1023,VBITS)+1]] for instruction: %add1024 = add <32 x i32> undef, undef
+; CHECK: cost of [[#div(2047,VBITS)+1]] for instruction: %add2048 = add <64 x i32> undef, undef
+  %add128 = add <4 x i32> undef, undef
+  %add256 = add <8 x i32> undef, undef
+  %add512 = add <16 x i32> undef, undef
+  %add1024 = add <32 x i32> undef, undef
+  %add2048 = add <64 x i32> undef, undef
+  ret void
+}
+
+attributes #0 = { "target-features"="+sve" }
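Note: as a worked reading of the CHECK expressions above (an illustration, not an extra test): the -aarch64-sve-vector-bits-min=384 run selects VBITS=256, the largest power of two not exceeding 384, so FileCheck's truncating integer div expands the expected costs to:

    div(127,256)+1  = 0+1 = 1   %add128  (<4 x i32>,  fits in one register)
    div(255,256)+1  = 0+1 = 1   %add256  (<8 x i32>,  fits in one register)
    div(511,256)+1  = 1+1 = 2   %add512  (<16 x i32>, split in two)
    div(1023,256)+1 = 3+1 = 4   %add1024 (<32 x i32>, split in four)
    div(2047,256)+1 = 7+1 = 8   %add2048 (<64 x i32>, split in eight)

In effect the cost is ceil(vector bits / VBITS): the legalisation overhead disappears once a whole vector fits in a single register, which is exactly what the test verifies as the assumed vector length grows.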