diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -28,6 +28,7 @@ #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/InstructionCost.h" +#include "llvm/Support/MachineValueType.h" #include #include @@ -732,6 +733,9 @@ /// Returns the estimated number of registers required to represent \p Ty. unsigned getRegUsageForType(Type *Ty) const; + /// Returns the MVT of the register type that \p Ty is legalized to. + MVT getRegMVTForType(Type *Ty) const; + /// Return true if switches should be turned into lookup tables for the /// target. bool shouldBuildLookupTables() const; @@ -1594,6 +1598,7 @@ virtual bool useAA() = 0; virtual bool isTypeLegal(Type *Ty) = 0; virtual unsigned getRegUsageForType(Type *Ty) = 0; + virtual MVT getRegMVTForType(Type *Ty) const = 0; virtual bool shouldBuildLookupTables() = 0; virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0; virtual bool shouldBuildRelLookupTables() = 0; @@ -2035,6 +2040,9 @@ unsigned getRegUsageForType(Type *Ty) override { return Impl.getRegUsageForType(Ty); } + MVT getRegMVTForType(Type *Ty) const override { + return Impl.getRegMVTForType(Ty); + } bool shouldBuildLookupTables() override { return Impl.shouldBuildLookupTables(); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -313,6 +313,7 @@ bool isTypeLegal(Type *Ty) const { return false; } unsigned getRegUsageForType(Type *Ty) const { return 1; } + MVT getRegMVTForType(Type *Ty) const { llvm_unreachable("unimplemented"); } bool shouldBuildLookupTables() const { return true; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h 
b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -387,6 +387,11 @@ return getTLI()->getNumRegisters(Ty->getContext(), ETy); } + MVT getRegMVTForType(Type *Ty) const { + EVT ETy = getTLI()->getValueType(DL, Ty); + return getTLI()->getRegisterType(Ty->getContext(), ETy); + } + InstructionCost getGEPCost(Type *PointeeType, const Value *Ptr, ArrayRef Operands, TTI::TargetCostKind CostKind) { diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -477,6 +477,10 @@ return TTIImpl->getRegUsageForType(Ty); } +MVT TargetTransformInfo::getRegMVTForType(Type *Ty) const { + return TTIImpl->getRegMVTForType(Ty); +} + bool TargetTransformInfo::shouldBuildLookupTables() const { return TTIImpl->shouldBuildLookupTables(); } diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -100,6 +100,8 @@ return 31; } + unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const; + InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind); diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -223,6 +223,15 @@ return TTI::PSK_Software; } +unsigned AArch64TTIImpl::getRegisterClassForType(bool Vector, Type *Ty) const { + // Callers may pass the scalar IR type of a value that will be widened + // (with Vector=true), so a non-vector type says nothing about the register + // file. Only classify by the legalized register type for IR vector types. + if (Ty == nullptr || !Ty->isVectorTy()) + return Vector ? 1 : 0; + return getRegMVTForType(Ty).isVector() ? 1 : 0; +} + InstructionCost AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, TTI::TargetCostKind CostKind) {