Use getVScaleRange() instead of open-coded vscale_range attribute parsing.

Three call sites stop reading Attribute::VScaleRange by hand and query
getVScaleRange(F, 64) instead:
  * InstructionSimplify.cpp: fold llvm.vscale when the range has a single
    element.
  * SelectionDAG.cpp: constant-fold the vscale node under ConstantFold when
    the range has a single element.
  * AArch64TargetMachine.cpp: derive MinSVEVectorSize/MaxSVEVectorSize from
    the range's unsigned min/max (times 128) when the attribute is present.

The "Sanitize user input in case of no asserts" code in
AArch64TargetMachine.cpp is deliberately left untouched: the rounding down to
a multiple of 128 is still needed for the SVEVectorBitsMinOpt /
SVEVectorBitsMaxOpt path, which does not go through getVScaleRange().

NOTE(review): the old AArch64 code set MaxSVEVectorSize to 0 when
vscale_range had no maximum. Please confirm what
CR.getUnsignedMax().getZExtValue() * 128 yields for that case and that it is
still treated as "no upper bound" by the code that follows.

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -6539,13 +6539,10 @@
   if (!NumOperands) {
     switch (IID) {
     case Intrinsic::vscale: {
-      auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange);
-      if (!Attr.isValid())
-        return nullptr;
-      unsigned VScaleMin = Attr.getVScaleRangeMin();
-      std::optional VScaleMax = Attr.getVScaleRangeMax();
-      if (VScaleMax && VScaleMin == VScaleMax)
-        return ConstantInt::get(F->getReturnType(), VScaleMin);
+      Type *RetTy = F->getReturnType();
+      ConstantRange CR = getVScaleRange(Call->getFunction(), 64);
+      if (const APInt *C = CR.getSingleElement())
+        return ConstantInt::get(RetTy, C->getZExtValue());
       return nullptr;
     }
     default:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1947,13 +1947,10 @@
 
   if (ConstantFold) {
     const MachineFunction &MF = getMachineFunction();
-    auto Attr = MF.getFunction().getFnAttribute(Attribute::VScaleRange);
-    if (Attr.isValid()) {
-      unsigned VScaleMin = Attr.getVScaleRangeMin();
-      if (std::optional VScaleMax = Attr.getVScaleRangeMax())
-        if (*VScaleMax == VScaleMin)
-          return getConstant(MulImm * VScaleMin, DL, VT);
-    }
+    const Function &F = MF.getFunction();
+    ConstantRange CR = getVScaleRange(&F, 64);
+    if (const APInt *C = CR.getSingleElement())
+      return getConstant(MulImm * C->getZExtValue(), DL, VT);
   }
 
   return getNode(ISD::VSCALE, DL, VT, getConstant(MulImm, DL, VT));
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -21,6 +21,7 @@
 #include "TargetInfo/AArch64TargetInfo.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/CodeGen/CFIFixup.h"
 #include "llvm/CodeGen/CSEConfigBase.h"
 #include "llvm/CodeGen/GlobalISel/CSEInfo.h"
@@ -398,11 +399,10 @@
 
   unsigned MinSVEVectorSize = 0;
   unsigned MaxSVEVectorSize = 0;
-  Attribute VScaleRangeAttr = F.getFnAttribute(Attribute::VScaleRange);
-  if (VScaleRangeAttr.isValid()) {
-    std::optional VScaleMax = VScaleRangeAttr.getVScaleRangeMax();
-    MinSVEVectorSize = VScaleRangeAttr.getVScaleRangeMin() * 128;
-    MaxSVEVectorSize = VScaleMax ? *VScaleMax * 128 : 0;
+  if (F.hasFnAttribute(Attribute::VScaleRange)) {
+    ConstantRange CR = getVScaleRange(&F, 64);
+    MinSVEVectorSize = CR.getUnsignedMin().getZExtValue() * 128;
+    MaxSVEVectorSize = CR.getUnsignedMax().getZExtValue() * 128;
   } else {
     MinSVEVectorSize = SVEVectorBitsMinOpt;
     MaxSVEVectorSize = SVEVectorBitsMaxOpt;