Index: llvm/include/llvm/IR/Attributes.h =================================================================== --- llvm/include/llvm/IR/Attributes.h +++ llvm/include/llvm/IR/Attributes.h @@ -217,9 +217,12 @@ /// if not known). std::pair> getAllocSizeArgs() const; - /// Returns the argument numbers for the vscale_range attribute (or pair(1, 0) - /// if not known). - std::pair getVScaleRangeArgs() const; + /// Returns the minimum value for the vscale_range attribute. + unsigned getVScaleRangeMin() const; + + /// Returns the maximum value for the vscale_range attribute. If omitted, + /// this will be the minimum, or None for unbounded (0). + Optional getVScaleRangeMax() const; /// The Attribute is converted to a string of equivalent mnemonic. This /// is, presumably, for writing out the mnemonics for the assembly writer. @@ -349,7 +352,8 @@ Type *getInAllocaType() const; Type *getElementType() const; std::pair> getAllocSizeArgs() const; - std::pair getVScaleRangeArgs() const; + unsigned getVScaleRangeMin() const; + Optional getVScaleRangeMax() const; std::string getAsString(bool InAttrGrp = false) const; /// Return true if this attribute set belongs to the LLVMContext. @@ -1054,9 +1058,12 @@ /// doesn't exist, pair(0, 0) is returned. std::pair> getAllocSizeArgs() const; - /// Retrieve the vscale_range args, if the vscale_range attribute exists. If - /// it doesn't exist, pair(1, 0) is returned. - std::pair getVScaleRangeArgs() const; + /// Retrieve the minimum value of 'vscale_range'. + unsigned getVScaleRangeMin() const; + + /// Retrieve the maximum value of 'vscale_range'. If omitted, + /// this will be the minimum, or None for unbounded (0). + Optional getVScaleRangeMax() const; /// Add integer attribute with raw value (packed/encoded if necessary). AttrBuilder &addRawIntAttr(Attribute::AttrKind Kind, uint64_t Value); @@ -1098,7 +1105,8 @@ const Optional &NumElemsArg); /// This turns two ints into the form used internally in Attribute. - AttrBuilder &addVScaleRangeAttr(unsigned MinValue, unsigned MaxValue); + AttrBuilder &addVScaleRangeAttr(unsigned MinValue, + Optional MaxValue); /// Add a type attribute with the given type. AttrBuilder &addTypeAttr(Attribute::AttrKind Kind, Type *Ty); Index: llvm/lib/Analysis/InstructionSimplify.cpp =================================================================== --- llvm/lib/Analysis/InstructionSimplify.cpp +++ llvm/lib/Analysis/InstructionSimplify.cpp @@ -5858,9 +5858,9 @@ auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange); if (!Attr.isValid()) return nullptr; - unsigned VScaleMin, VScaleMax; - std::tie(VScaleMin, VScaleMax) = Attr.getVScaleRangeArgs(); - if (VScaleMin == VScaleMax && VScaleMax != 0) + unsigned VScaleMin = Attr.getVScaleRangeMin(); + Optional VScaleMax = Attr.getVScaleRangeMax(); + if (VScaleMax && VScaleMin == VScaleMax) return ConstantInt::get(F->getReturnType(), VScaleMin); return nullptr; } Index: llvm/lib/Analysis/ValueTracking.cpp =================================================================== --- llvm/lib/Analysis/ValueTracking.cpp +++ llvm/lib/Analysis/ValueTracking.cpp @@ -1701,23 +1701,25 @@ !II->getFunction()->hasFnAttribute(Attribute::VScaleRange)) break; - auto VScaleRange = II->getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs(); + auto Attr = II->getFunction()->getFnAttribute(Attribute::VScaleRange); + Optional VScaleMax = Attr.getVScaleRangeMax(); - if (VScaleRange.second == 0) + if (!VScaleMax) break; + unsigned VScaleMin = Attr.getVScaleRangeMin(); + // If vscale min = max then we know the exact value at compile time // and hence we know the exact bits. - if (VScaleRange.first == VScaleRange.second) { - Known.One = VScaleRange.first; - Known.Zero = VScaleRange.first; + if (VScaleMin == VScaleMax) { + Known.One = VScaleMin; + Known.Zero = VScaleMin; Known.Zero.flipAllBits(); break; } - unsigned FirstZeroHighBit = 32 - countLeadingZeros(VScaleRange.second); + unsigned FirstZeroHighBit = + 32 - countLeadingZeros(VScaleMax.getValue()); if (FirstZeroHighBit < BitWidth) Known.Zero.setBitsFrom(FirstZeroHighBit); Index: llvm/lib/AsmParser/LLParser.cpp =================================================================== --- llvm/lib/AsmParser/LLParser.cpp +++ llvm/lib/AsmParser/LLParser.cpp @@ -1306,7 +1306,8 @@ unsigned MinValue, MaxValue; if (parseVScaleRangeArguments(MinValue, MaxValue)) return true; - B.addVScaleRangeAttr(MinValue, MaxValue); + B.addVScaleRangeAttr(MinValue, + MaxValue > 0 ? MaxValue : Optional()); return false; } case Attribute::Dereferenceable: { Index: llvm/lib/IR/AttributeImpl.h =================================================================== --- llvm/lib/IR/AttributeImpl.h +++ llvm/lib/IR/AttributeImpl.h @@ -253,7 +253,8 @@ uint64_t getDereferenceableBytes() const; uint64_t getDereferenceableOrNullBytes() const; std::pair> getAllocSizeArgs() const; - std::pair getVScaleRangeArgs() const; + unsigned getVScaleRangeMin() const; + Optional getVScaleRangeMax() const; std::string getAsString(bool InAttrGrp) const; Type *getAttributeType(Attribute::AttrKind Kind) const; Index: llvm/lib/IR/Attributes.cpp =================================================================== --- llvm/lib/IR/Attributes.cpp +++ llvm/lib/IR/Attributes.cpp @@ -78,15 +78,18 @@ return std::make_pair(ElemSizeArg, NumElemsArg); } -static uint64_t packVScaleRangeArgs(unsigned MinValue, unsigned MaxValue) { - return uint64_t(MinValue) << 32 | MaxValue; +static uint64_t packVScaleRangeArgs(unsigned MinValue, + Optional MaxValue) { + return uint64_t(MinValue) << 32 | MaxValue.getValueOr(0); } -static std::pair unpackVScaleRangeArgs(uint64_t Value) { +static std::pair> +unpackVScaleRangeArgs(uint64_t Value) { unsigned MaxValue = Value & std::numeric_limits::max(); unsigned MinValue = Value >> 32; - return std::make_pair(MinValue, MaxValue); + return std::make_pair(MinValue, + MaxValue > 0 ? MaxValue : Optional()); } Attribute Attribute::get(LLVMContext &Context, Attribute::AttrKind Kind, @@ -354,10 +357,16 @@ return unpackAllocSizeArgs(pImpl->getValueAsInt()); } -std::pair Attribute::getVScaleRangeArgs() const { +unsigned Attribute::getVScaleRangeMin() const { assert(hasAttribute(Attribute::VScaleRange) && "Trying to get vscale args from non-vscale attribute"); - return unpackVScaleRangeArgs(pImpl->getValueAsInt()); + return unpackVScaleRangeArgs(pImpl->getValueAsInt()).first; +} + +Optional Attribute::getVScaleRangeMax() const { + assert(hasAttribute(Attribute::VScaleRange) && + "Trying to get vscale args from non-vscale attribute"); + return unpackVScaleRangeArgs(pImpl->getValueAsInt()).second; } std::string Attribute::getAsString(bool InAttrGrp) const { @@ -428,13 +437,13 @@ } if (hasAttribute(Attribute::VScaleRange)) { - unsigned MinValue, MaxValue; - std::tie(MinValue, MaxValue) = getVScaleRangeArgs(); + unsigned MinValue = getVScaleRangeMin(); + Optional MaxValue = getVScaleRangeMax(); std::string Result = "vscale_range("; Result += utostr(MinValue); Result += ','; - Result += utostr(MaxValue); + Result += utostr(MaxValue.getValueOr(0)); Result += ')'; return Result; } @@ -717,9 +726,12 @@ : std::pair>(0, 0); } -std::pair AttributeSet::getVScaleRangeArgs() const { - return SetNode ? SetNode->getVScaleRangeArgs() - : std::pair(1, 0); +unsigned AttributeSet::getVScaleRangeMin() const { + return SetNode ? SetNode->getVScaleRangeMin() : 1; +} + +Optional AttributeSet::getVScaleRangeMax() const { + return SetNode ? SetNode->getVScaleRangeMax() : None; } std::string AttributeSet::getAsString(bool InAttrGrp) const { @@ -897,10 +909,16 @@ return std::make_pair(0, 0); } -std::pair AttributeSetNode::getVScaleRangeArgs() const { +unsigned AttributeSetNode::getVScaleRangeMin() const { if (auto A = findEnumAttribute(Attribute::VScaleRange)) - return A->getVScaleRangeArgs(); - return std::make_pair(1, 0); + return A->getVScaleRangeMin(); + return 1; +} + +Optional AttributeSetNode::getVScaleRangeMax() const { + if (auto A = findEnumAttribute(Attribute::VScaleRange)) + return A->getVScaleRangeMax(); + return None; } std::string AttributeSetNode::getAsString(bool InAttrGrp) const { @@ -1623,8 +1641,12 @@ return unpackAllocSizeArgs(getRawIntAttr(Attribute::AllocSize)); } -std::pair AttrBuilder::getVScaleRangeArgs() const { - return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)); +unsigned AttrBuilder::getVScaleRangeMin() const { + return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)).first; +} + +Optional AttrBuilder::getVScaleRangeMax() const { + return unpackVScaleRangeArgs(getRawIntAttr(Attribute::VScaleRange)).second; } AttrBuilder &AttrBuilder::addAlignmentAttr(MaybeAlign Align) { @@ -1669,7 +1691,7 @@ } AttrBuilder &AttrBuilder::addVScaleRangeAttr(unsigned MinValue, - unsigned MaxValue) { + Optional MaxValue) { return addVScaleRangeAttrFromRawRepr(packVScaleRangeArgs(MinValue, MaxValue)); } Index: llvm/lib/IR/Verifier.cpp =================================================================== --- llvm/lib/IR/Verifier.cpp +++ llvm/lib/IR/Verifier.cpp @@ -2057,13 +2057,12 @@ } if (Attrs.hasFnAttr(Attribute::VScaleRange)) { - std::pair Args = - Attrs.getFnAttrs().getVScaleRangeArgs(); - - if (Args.first == 0) + unsigned VScaleMin = Attrs.getFnAttrs().getVScaleRangeMin(); + if (VScaleMin == 0) CheckFailed("'vscale_range' minimum must be greater than 0", V); - if (Args.first > Args.second && Args.second != 0) + Optional VScaleMax = Attrs.getFnAttrs().getVScaleRangeMax(); + if (VScaleMax && VScaleMin > VScaleMax) CheckFailed("'vscale_range' minimum cannot be greater than maximum", V); } Index: llvm/lib/Target/AArch64/AArch64TargetMachine.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -371,10 +371,9 @@ unsigned MaxSVEVectorSize = 0; Attribute VScaleRangeAttr = F.getFnAttribute(Attribute::VScaleRange); if (VScaleRangeAttr.isValid()) { - std::tie(MinSVEVectorSize, MaxSVEVectorSize) = - VScaleRangeAttr.getVScaleRangeArgs(); - MinSVEVectorSize *= 128; - MaxSVEVectorSize *= 128; + Optional VScaleMax = VScaleRangeAttr.getVScaleRangeMax(); + MinSVEVectorSize = VScaleRangeAttr.getVScaleRangeMin() * 128; + MaxSVEVectorSize = VScaleMax ? VScaleMax.getValue() * 128 : 0; } else { MinSVEVectorSize = SVEVectorBitsMinOpt; MaxSVEVectorSize = SVEVectorBitsMaxOpt; Index: llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp =================================================================== --- llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp +++ llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp @@ -287,10 +287,10 @@ if (!Attr.isValid()) return false; - unsigned MinVScale, MaxVScale; - std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs(); + unsigned MinVScale = Attr.getVScaleRangeMin(); + Optional MaxVScale = Attr.getVScaleRangeMax(); // The transform needs to know the exact runtime length of scalable vectors - if (MinVScale != MaxVScale || MinVScale == 0) + if (!MaxVScale || MinVScale != MaxVScale) return false; auto *PredType = @@ -351,10 +351,10 @@ if (!Attr.isValid()) return false; - unsigned MinVScale, MaxVScale; - std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs(); + unsigned MinVScale = Attr.getVScaleRangeMin(); + Optional MaxVScale = Attr.getVScaleRangeMax(); // The transform needs to know the exact runtime length of scalable vectors - if (MinVScale != MaxVScale || MinVScale == 0) + if (!MaxVScale || MinVScale != MaxVScale) return false; auto *PredType = Index: llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp =================================================================== --- llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -965,13 +965,13 @@ if (match(Src, m_VScale(DL))) { if (Trunc.getFunction() && Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = Trunc.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (MaxVScale > 0 && Log2_32(MaxVScale) < DestWidth) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(Trunc, VScale); + Attribute Attr = + Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(MaxVScale.getValue()) < DestWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(Trunc, VScale); + } } } } @@ -1337,14 +1337,13 @@ if (match(Src, m_VScale(DL))) { if (CI.getFunction() && CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = CI.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); - if (MaxVScale > 0 && Log2_32(MaxVScale) < TypeWidth) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); + if (Log2_32(MaxVScale.getValue()) < TypeWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } } } } @@ -1608,13 +1607,12 @@ if (match(Src, m_VScale(DL))) { if (CI.getFunction() && CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - unsigned MaxVScale = CI.getFunction() - ->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (MaxVScale > 0 && Log2_32(MaxVScale) < (SrcBitSize - 1)) { - Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (Optional MaxVScale = Attr.getVScaleRangeMax()) { + if (Log2_32(MaxVScale.getValue()) < (SrcBitSize - 1)) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } } } } Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5651,11 +5651,9 @@ // Limit MaxScalableVF by the maximum safe dependence distance. Optional MaxVScale = TTI.getMaxVScale(); if (!MaxVScale && TheFunction->hasFnAttribute(Attribute::VScaleRange)) { - unsigned VScaleMax = TheFunction->getFnAttribute(Attribute::VScaleRange) - .getVScaleRangeArgs() - .second; - if (VScaleMax > 0) - MaxVScale = VScaleMax; + Attribute Attr = TheFunction->getFnAttribute(Attribute::VScaleRange); + if (Optional VScaleMax = Attr.getVScaleRangeMax()) + MaxVScale = VScaleMax.getValue(); } MaxScalableVF = ElementCount::getScalable( MaxVScale ? (MaxSafeElements / MaxVScale.getValue()) : 0);