Index: llvm/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/include/llvm/Analysis/ScalarEvolution.h +++ llvm/include/llvm/Analysis/ScalarEvolution.h @@ -23,6 +23,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/BitmaskEnum.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/FoldingSet.h" @@ -113,7 +114,8 @@ FlagNW = (1 << 0), // No self-wrap. FlagNUW = (1 << 1), // No unsigned wrap. FlagNSW = (1 << 2), // No signed wrap. - NoWrapMask = (1 << 3) - 1 + NoWrapMask = (1 << 3) - 1, + LLVM_MARK_AS_BITMASK_ENUM(/*LargestEnumerator=*/NoWrapMask), }; explicit SCEV(const FoldingSetNodeIDRef ID, unsigned SCEVTy) @@ -312,37 +314,10 @@ IncrementNUSW = (1 << 0), // No unsigned with signed increment wrap. IncrementNSSW = (1 << 1), // No signed with signed increment wrap // (equivalent with SCEV::NSW) - IncrementNoWrapMask = (1 << 2) - 1 + IncrementNoWrapMask = IncrementNUSW | IncrementNSSW, + LLVM_MARK_AS_BITMASK_ENUM(/*LargestEnumerator=*/IncrementNoWrapMask), }; - /// Convenient IncrementWrapFlags manipulation methods. - LLVM_NODISCARD static SCEVWrapPredicate::IncrementWrapFlags - clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, - SCEVWrapPredicate::IncrementWrapFlags OffFlags) { - assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); - assert((OffFlags & IncrementNoWrapMask) == OffFlags && - "Invalid flags value!"); - return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & ~OffFlags); - } - - LLVM_NODISCARD static SCEVWrapPredicate::IncrementWrapFlags - maskFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, int Mask) { - assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); - assert((Mask & IncrementNoWrapMask) == Mask && "Invalid mask value!"); - - return (SCEVWrapPredicate::IncrementWrapFlags)(Flags & Mask); - } - - LLVM_NODISCARD static SCEVWrapPredicate::IncrementWrapFlags - setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, - SCEVWrapPredicate::IncrementWrapFlags OnFlags) { - assert((Flags & IncrementNoWrapMask) == Flags && "Invalid flags value!"); - assert((OnFlags & IncrementNoWrapMask) == OnFlags && - "Invalid flags value!"); - - return (SCEVWrapPredicate::IncrementWrapFlags)(Flags | OnFlags); - } - /// Returns the set of SCEVWrapPredicate no wrap flags implied by a /// SCEVAddRecExpr. LLVM_NODISCARD static SCEVWrapPredicate::IncrementWrapFlags @@ -467,21 +442,6 @@ ProperlyDominatesBlock ///< The SCEV properly dominates the block. }; - /// Convenient NoWrapFlags manipulation that hides enum casts and is - /// visible in the ScalarEvolution name space. - LLVM_NODISCARD static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, - int Mask) { - return (SCEV::NoWrapFlags)(Flags & Mask); - } - LLVM_NODISCARD static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, - SCEV::NoWrapFlags OnFlags) { - return (SCEV::NoWrapFlags)(Flags | OnFlags); - } - LLVM_NODISCARD static SCEV::NoWrapFlags - clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags) { - return (SCEV::NoWrapFlags)(Flags & ~OffFlags); - } - ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI); ScalarEvolution(ScalarEvolution &&Arg); Index: llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h =================================================================== --- llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h +++ llvm/include/llvm/Analysis/ScalarEvolutionExpressions.h @@ -331,7 +331,7 @@ /// to make it easier to propagate flags. void setNoWrapFlags(NoWrapFlags Flags) { if (Flags & (FlagNUW | FlagNSW)) - Flags = ScalarEvolution::setFlags(Flags, FlagNW); + Flags |= FlagNW; SubclassData |= Flags; } Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -1428,8 +1428,7 @@ // `Step`: // 1. NSW/NUW flags on the step increment. - auto PreStartFlags = - ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW); + auto PreStartFlags = SA->getNoWrapFlags() & SCEV::FlagNUW; const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags); const SCEVAddRecExpr *PreAR = dyn_cast( SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap)); @@ -2218,9 +2217,8 @@ (void)CanAnalyze; assert(CanAnalyze && "don't call from other places!"); - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = - ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = Flags & SignOrUnsignMask; // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. auto IsKnownNonNegative = [&](const SCEV *S) { @@ -2228,10 +2226,9 @@ }; if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative)) - Flags = - ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); + Flags |= SignOrUnsignMask; - SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask); + SignOrUnsignWrap = Flags & SignOrUnsignMask; if (SignOrUnsignWrap != SignOrUnsignMask && (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 && @@ -2255,7 +2252,7 @@ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Opcode, C, OBO::NoSignedWrap); if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + Flags |= SCEV::FlagNSW; } // (A C) --> (A C) if the op doesn't unsign overflow. @@ -2263,7 +2260,7 @@ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, C, OBO::NoUnsignedWrap); if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + Flags |= SCEV::FlagNUW; } } @@ -2590,7 +2587,7 @@ // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. // Always propagate NW. - Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW)); + Flags = AddRec->getNoWrapFlags(Flags | SCEV::FlagNW); const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. @@ -2756,7 +2753,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl &Ops, SCEV::NoWrapFlags Flags, unsigned Depth) { - assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && + assert(Flags == (Flags & (SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); if (Ops.size() == 1) return Ops[0]; @@ -2905,7 +2902,7 @@ // // No self-wrap cannot be guaranteed after changing the step size, but // will be inferred if either NUW or NSW is true. - Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW)); + Flags = AddRec->getNoWrapFlags(Flags & ~SCEV::FlagNW); const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags); // If all of the other operands were loop invariant, we are done. @@ -3244,7 +3241,7 @@ if (const SCEVAddRecExpr *StepChrec = dyn_cast(Step)) if (StepChrec->getLoop() == L) { Operands.append(StepChrec->op_begin(), StepChrec->op_end()); - return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW)); + return getAddRecExpr(Operands, L, Flags & SCEV::FlagNW); } Operands.push_back(Step); @@ -3302,7 +3299,7 @@ // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the // inner recurrence has the same property. SCEV::NoWrapFlags OuterFlags = - maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags()); + Flags & (SCEV::FlagNW | NestedAR->getNoWrapFlags()); NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags); AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) { @@ -3315,7 +3312,7 @@ // The inner recurrence keeps its NW flag but only keeps NUW/NSW if // the outer recurrence has the same property. SCEV::NoWrapFlags InnerFlags = - maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags); + NestedAR->getNoWrapFlags() & (SCEV::FlagNW | Flags); return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags); } } @@ -3900,7 +3897,7 @@ auto AddFlags = SCEV::FlagAnyWrap; const bool RHSIsNotMinSigned = !getSignedRangeMin(RHS).isMinSignedValue(); - if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) { + if ((Flags & SCEV::FlagNSW) == SCEV::FlagNSW) { // Let M be the minimum representable signed value. Then (-1)*RHS // signed-wraps if and only if RHS is M. That can happen even for // a NSW subtraction because e.g. (-1)*M signed-wraps even though @@ -4340,7 +4337,7 @@ auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, IncRange, OBO::NoSignedWrap); if (NSWRegion.contains(AddRecRange)) - Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); + Result |= SCEV::FlagNSW; } if (!AR->hasNoUnsignedWrap()) { @@ -4350,7 +4347,7 @@ auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( Instruction::Add, IncRange, OBO::NoUnsignedWrap); if (NUWRegion.contains(AddRecRange)) - Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); + Result |= SCEV::FlagNUW; } return Result; @@ -4906,9 +4903,9 @@ SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BO->IsNUW) - Flags = setFlags(Flags, SCEV::FlagNUW); + Flags |= SCEV::FlagNUW; if (BO->IsNSW) - Flags = setFlags(Flags, SCEV::FlagNSW); + Flags |= SCEV::FlagNSW; const SCEV *StartVal = getSCEV(StartValueV); const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); @@ -5004,9 +5001,9 @@ if (auto BO = MatchBinaryOp(BEValueV, DT)) { if (BO->Opcode == Instruction::Add && BO->LHS == PN) { if (BO->IsNUW) - Flags = setFlags(Flags, SCEV::FlagNUW); + Flags |= SCEV::FlagNUW; if (BO->IsNSW) - Flags = setFlags(Flags, SCEV::FlagNSW); + Flags |= SCEV::FlagNSW; } } else if (GEPOperator *GEP = dyn_cast(BEValueV)) { // If the increment is an inbounds GEP, then we know the address @@ -5016,11 +5013,11 @@ // pointer. We can guarantee that no unsigned wrap occurs if the // indices form a positive value. if (GEP->isInBounds() && GEP->getOperand(0) == PN) { - Flags = setFlags(Flags, SCEV::FlagNW); + Flags |= SCEV::FlagNW; const SCEV *Ptr = getSCEV(GEP->getPointerOperand()); if (isKnownPositive(getMinusSCEV(getSCEV(GEP), Ptr))) - Flags = setFlags(Flags, SCEV::FlagNUW); + Flags |= SCEV::FlagNUW; } // We cannot transfer nuw and nsw flags from subtraction @@ -5868,9 +5865,9 @@ // Return early if there are no flags to propagate to the SCEV. SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BinOp->hasNoUnsignedWrap()) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); + Flags |= SCEV::FlagNUW; if (BinOp->hasNoSignedWrap()) - Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); + Flags |= SCEV::FlagNSW; if (Flags == SCEV::FlagAnyWrap) return SCEV::FlagAnyWrap; @@ -11936,15 +11933,15 @@ bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const { const auto *Op = dyn_cast(N); - return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags; + return Op && Op->AR == AR && (Flags | Op->Flags) == Flags; } bool SCEVWrapPredicate::isAlwaysTrue() const { SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags(); IncrementWrapFlags IFlags = Flags; - if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags) - IFlags = clearFlags(IFlags, IncrementNSSW); + if ((ScevFlags | SCEV::FlagNSW) == ScevFlags) + IFlags &= ~IncrementNSSW; return IFlags == IncrementAnyWrap; } @@ -11965,15 +11962,15 @@ SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags(); // We can safely transfer the NSW flag as NSSW. - if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags) + if ((StaticFlags | SCEV::FlagNSW) == StaticFlags) ImpliedFlags = IncrementNSSW; - if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) { + if ((StaticFlags | SCEV::FlagNUW) == StaticFlags) { // If the increment is positive, the SCEV NUW flag will also imply the // WrapPredicate NUSW flag. if (const auto *Step = dyn_cast(AR->getStepRecurrence(SE))) if (Step->getValue()->getValue().isNonNegative()) - ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW); + ImpliedFlags |= IncrementNUSW; } return ImpliedFlags; @@ -12096,12 +12093,12 @@ auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); // Clear the statically implied flags. - Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); + Flags &= ~ImpliedFlags; addPredicate(*SE.getWrapPredicate(AR, Flags)); auto II = FlagsMap.insert({V, Flags}); if (!II.second) - II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); + II.first->second |= Flags; } bool PredicatedScalarEvolution::hasNoOverflow( @@ -12109,13 +12106,12 @@ const SCEV *Expr = getSCEV(V); const auto *AR = cast(Expr); - Flags = SCEVWrapPredicate::clearFlags( - Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); + Flags &= ~SCEVWrapPredicate::getImpliedFlags(AR, SE); auto II = FlagsMap.find(V); if (II != FlagsMap.end()) - Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); + Flags &= ~II->second; return Flags == SCEVWrapPredicate::IncrementAnyWrap; }