diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_LOOPACCESSANALYSIS_H
 
 #include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/IR/DiagnosticInfo.h"
@@ -338,13 +339,15 @@
 /// two groups.
 struct RuntimeCheckingPtrGroup {
   /// Create a new pointer checking group containing a single
-  /// pointer, with index \p Index in RtCheck.
-  RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck);
+  /// pointer, with index \p Index in RtCheck and using \p Fork (0 or 1) if
+  /// this represents a forked pointer.
+  RuntimeCheckingPtrGroup(unsigned Index, RuntimePointerChecking &RtCheck,
+                          unsigned Fork = 0);
 
   RuntimeCheckingPtrGroup(unsigned Index, const SCEV *Start, const SCEV *End,
                           unsigned AS)
       : High(End), Low(Start), AddressSpace(AS) {
-    Members.push_back(Index);
+    Members.insert(Index);
   }
 
   /// Tries to add the pointer recorded in RtCheck at index
@@ -352,7 +355,10 @@
   /// to a checking group if we will still be able to get
   /// the upper and lower bounds of the check. Returns true in case
   /// of success, false otherwise.
-  bool addPointer(unsigned Index, RuntimePointerChecking &RtCheck);
+  /// For forked pointers this will only add one fork at a time, determined
+  /// by \p Fork (0 or 1).
+  bool addPointer(unsigned Index, RuntimePointerChecking &RtCheck,
+                  unsigned Fork = 0);
   bool addPointer(unsigned Index, const SCEV *Start, const SCEV *End,
                   unsigned AS, ScalarEvolution &SE);
 
@@ -362,8 +368,10 @@
   /// The SCEV expression which represents the lower bound of all the
   /// pointers in this group.
   const SCEV *Low;
-  /// Indices of all the pointers that constitute this grouping.
-  SmallVector<unsigned, 2> Members;
+  /// Indices of all the pointers that constitute this grouping. Using a
+  /// SetVector to reject duplicates for forked pointers when both forks
+  /// are compatible with the same checking group.
+  SmallSetVector<unsigned, 2> Members;
   /// Address space of the involved pointers.
   unsigned AddressSpace;
 };
@@ -382,12 +390,12 @@
   struct PointerInfo {
     /// Holds the pointer value that we need to check.
    TrackingVH<Value> PointerValue;
-    /// Holds the smallest byte address accessed by the pointer throughout all
-    /// iterations of the loop.
-    const SCEV *Start;
-    /// Holds the largest byte address accessed by the pointer throughout all
-    /// iterations of the loop, plus 1.
-    const SCEV *End;
+    /// Holds the smallest byte address(es) accessed by the pointer throughout
+    /// all iterations of the loop.
+    SmallVector<const SCEV *, 2> Starts;
+    /// Holds the largest byte address(es) accessed by the pointer throughout
+    /// all iterations of the loop, plus 1.
+    SmallVector<const SCEV *, 2> Ends;
     /// Holds the information if this pointer is used for writing to memory.
     bool IsWritePtr;
     /// Holds the id of the set of pointers that could be dependent because of a
@@ -395,23 +403,30 @@
     unsigned DependencySetId;
     /// Holds the id of the disjoint alias set to which this pointer belongs.
    unsigned AliasSetId;
-    /// SCEV for the access.
-    const SCEV *Expr;
+    /// SCEV(s) for the access.
+    SmallVector<const SCEV *, 2> Exprs;
 
-    PointerInfo(Value *PointerValue, const SCEV *Start, const SCEV *End,
+    PointerInfo(Value *PointerValue, SmallVectorImpl<const SCEV *> &ScStarts,
+                SmallVectorImpl<const SCEV *> &ScEnds,
                 bool IsWritePtr, unsigned DependencySetId, unsigned AliasSetId,
-                const SCEV *Expr)
-        : PointerValue(PointerValue), Start(Start), End(End),
+                SmallVectorImpl<const SCEV *> &ScExprs)
+        : PointerValue(PointerValue),
          IsWritePtr(IsWritePtr), DependencySetId(DependencySetId),
-          AliasSetId(AliasSetId), Expr(Expr) {}
+          AliasSetId(AliasSetId) {
+      Starts.append(ScStarts.begin(), ScStarts.end());
+      Ends.append(ScEnds.begin(), ScEnds.end());
+      Exprs.append(ScExprs.begin(), ScExprs.end());
+    }
   };
 
-  RuntimePointerChecking(ScalarEvolution *SE) : Need(false), SE(SE) {}
+  RuntimePointerChecking(ScalarEvolution *SE, bool AllowForkedPtrs)
+      : Need(false), AllowForkedPtrs(AllowForkedPtrs), SE(SE) {}
 
   /// Reset the state of the pointer runtime information.
   void reset() {
     Need = false;
     Pointers.clear();
+    ForkedPtrs.clear();
     Checks.clear();
   }
 
@@ -454,12 +469,28 @@
                    const SmallVectorImpl<RuntimePointerCheck> &Checks,
                    unsigned Depth = 0) const;
 
+  /// Returns true if the pointer can be decomposed into two separate pointers
+  /// through a select.
+  bool isForkedPtr(const Value *Ptr) const {
+    return ForkedPtrs.count(Ptr) != 0;
+  }
+
   /// This flag indicates if we need to add the runtime check.
   bool Need;
 
+  /// This flag indicates that we allow diverging pointers over a select.
+  bool AllowForkedPtrs;
+
   /// Information about the pointers that may require checking.
   SmallVector<PointerInfo, 2> Pointers;
 
+  /// Mapping between a pointer Value used by memory operations and a pair
+  /// of SCEV expressions. This is used in cases where a 'select' instruction
+  /// is used to form part of an address, and we could form a single
+  /// SCEVAddRecExpr with either side but not with both, so we just calculate
+  /// and store the two expressions to determine whether checks are required.
+  DenseMap<const Value *, std::pair<const SCEV *, const SCEV *>> ForkedPtrs;
+
   /// Holds a partitioning of pointers into "check groups".
   SmallVector<RuntimeCheckingPtrGroup, 2> CheckingGroups;
 
diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -437,6 +437,11 @@
   }
 };
 
+/// Type for a possible forked pointer, where a single pointer could be
+/// one of two different pointers determined conditionally via a select
+/// instruction.
+using ForkedPointer = Optional<std::pair<const SCEV *, const SCEV *>>;
+
 /// The main scalar evolution driver. Because client code (intentionally)
 /// can't do much with the SCEV objects directly, they must ask this class
 /// for services.
@@ -719,6 +724,11 @@
   /// This is a convenience function which does getSCEVAtScope(getSCEV(V), L).
   const SCEV *getSCEVAtScope(Value *V, const Loop *L);
 
+  /// This function determines whether the given pointer value V is a forked
+  /// pointer; that is, could it have two possible values distinguished by
+  /// a select instruction.
+  ForkedPointer findForkedPointer(Value *V, const Loop *L);
+
   /// Test whether entry to the loop is protected by a conditional between LHS
   /// and RHS. This is used to help avoid max expressions in loop trip
   /// counts, and to eliminate casts.
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -129,6 +129,15 @@
     cl::desc("Enable conflict detection in loop-access analysis"),
     cl::init(true));
 
+/// Enables the detection of forked pointers when analyzing loops; these
+/// pointers could have two possible values at runtime based on a conditional
+/// select instruction, and we can analyze both possibilities to determine
+/// bounds.
+static cl::opt<bool> EnableForkedPointers(
+    "enable-forked-pointer-detection", cl::Hidden,
+    cl::desc("Enable detection of pointers forked by a select instruction"),
+    cl::init(false));
+
 bool VectorizerParams::isInterleaveForced() {
   return ::VectorizationInterleave.getNumOccurrences() > 0;
 }
 
@@ -169,12 +178,14 @@
 }
 
 RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
-    unsigned Index, RuntimePointerChecking &RtCheck)
-    : High(RtCheck.Pointers[Index].End), Low(RtCheck.Pointers[Index].Start),
+    unsigned Index, RuntimePointerChecking &RtCheck, unsigned Fork)
+    : High(RtCheck.Pointers[Index].Ends[Fork]),
+      Low(RtCheck.Pointers[Index].Starts[Fork]),
       AddressSpace(RtCheck.Pointers[Index]
                        .PointerValue->getType()
                        ->getPointerAddressSpace()) {
-  Members.push_back(Index);
+  assert((Fork == 0 || Fork == 1) && "Fork out of range for pointer checking");
+  Members.insert(Index);
 }
 
 /// Calculate Start and End points of memory access.
@@ -201,38 +212,52 @@
   const SCEV *ScStart;
   const SCEV *ScEnd;
 
-  if (SE->isLoopInvariant(Sc, Lp)) {
-    ScStart = ScEnd = Sc;
-  } else {
-    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
-    assert(AR && "Invalid addrec expression");
-    const SCEV *Ex = PSE.getBackedgeTakenCount();
-
-    ScStart = AR->getStart();
-    ScEnd = AR->evaluateAtIteration(Ex, *SE);
-    const SCEV *Step = AR->getStepRecurrence(*SE);
-
-    // For expressions with negative step, the upper bound is ScStart and the
-    // lower bound is ScEnd.
-    if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) {
-      if (CStep->getValue()->isNegative())
-        std::swap(ScStart, ScEnd);
+  // See if this was a ForkedPtr. If so, we want to add checks for both sides.
+  SmallVector<const SCEV *, 2> Scevs;
+  SmallVector<const SCEV *, 2> Starts;
+  SmallVector<const SCEV *, 2> Ends;
+  if (ForkedPtrs.count(Ptr)) {
+    Scevs.push_back(ForkedPtrs[Ptr].first);
+    Scevs.push_back(ForkedPtrs[Ptr].second);
+  } else
+    Scevs.push_back(Sc);
+
+  for (const SCEV *Sc : Scevs) {
+    if (SE->isLoopInvariant(Sc, Lp)) {
+      ScStart = ScEnd = Sc;
     } else {
-      // Fallback case: the step is not constant, but we can still
-      // get the upper and lower bounds of the interval by using min/max
-      // expressions.
-      ScStart = SE->getUMinExpr(ScStart, ScEnd);
-      ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd);
+      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc);
+      assert(AR && "Invalid addrec expression");
+      const SCEV *Ex = PSE.getBackedgeTakenCount();
+
+      ScStart = AR->getStart();
+      ScEnd = AR->evaluateAtIteration(Ex, *SE);
+      const SCEV *Step = AR->getStepRecurrence(*SE);
+
+      // For expressions with negative step, the upper bound is ScStart and the
+      // lower bound is ScEnd.
+      if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) {
+        if (CStep->getValue()->isNegative())
+          std::swap(ScStart, ScEnd);
+      } else {
+        // Fallback case: the step is not constant, but we can still
+        // get the upper and lower bounds of the interval by using min/max
+        // expressions.
+        ScStart = SE->getUMinExpr(ScStart, ScEnd);
+        ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd);
+      }
     }
+    // Add the size of the pointed element to ScEnd.
+    auto &DL = Lp->getHeader()->getModule()->getDataLayout();
+    Type *IdxTy = DL.getIndexType(Ptr->getType());
+    const SCEV *EltSizeSCEV =
+        SE->getStoreSizeOfExpr(IdxTy, Ptr->getType()->getPointerElementType());
+    ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
+    Starts.push_back(ScStart);
+    Ends.push_back(ScEnd);
   }
 
-  // Add the size of the pointed element to ScEnd.
-  auto &DL = Lp->getHeader()->getModule()->getDataLayout();
-  Type *IdxTy = DL.getIndexType(Ptr->getType());
-  const SCEV *EltSizeSCEV =
-      SE->getStoreSizeOfExpr(IdxTy, Ptr->getType()->getPointerElementType());
-  ScEnd = SE->getAddExpr(ScEnd, EltSizeSCEV);
+  Pointers.emplace_back(Ptr, Starts, Ends, WritePtr, DepSetId, ASId, Scevs);
 
-  Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
 }
 
 SmallVector<RuntimePointerCheck, 4>
@@ -282,9 +307,12 @@
 }
 
 bool RuntimeCheckingPtrGroup::addPointer(unsigned Index,
-                                         RuntimePointerChecking &RtCheck) {
+                                         RuntimePointerChecking &RtCheck,
+                                         unsigned Fork) {
+  assert((Fork == 0 || Fork == 1) && "Fork out of range for pointer checking");
   return addPointer(
-      Index, RtCheck.Pointers[Index].Start, RtCheck.Pointers[Index].End,
+      Index, RtCheck.Pointers[Index].Starts[Fork],
+      RtCheck.Pointers[Index].Ends[Fork],
      RtCheck.Pointers[Index].PointerValue->getType()->getPointerAddressSpace(),
       *RtCheck.SE);
 }
 
@@ -314,7 +342,7 @@
   if (Min1 != End)
     High = End;
 
-  Members.push_back(Index);
+  Members.insert(Index);
   return true;
 }
 
@@ -365,8 +393,15 @@
   // for correctness, because in this case we can have checking between
   // pointers to the same underlying object.
   if (!UseDependencies) {
-    for (unsigned I = 0; I < Pointers.size(); ++I)
-      CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this));
+    for (unsigned I = 0; I < Pointers.size(); ++I) {
+      auto &Group = CheckingGroups.emplace_back(I, *this);
+      // Try to add a fork to the same group first, otherwise establish
+      // a new one.
+      if (Pointers[I].Exprs.size() == 2)
+        if (!Group.addPointer(I, *this, /*Fork=*/1))
+          CheckingGroups.emplace_back(I, *this, /*Fork=*/1);
+    }
     return;
   }
 
@@ -406,7 +441,8 @@
     assert(PointerI != PositionMap.end() &&
           "pointer in equivalence class not found in PositionMap");
     unsigned Pointer = PointerI->second;
-    bool Merged = false;
+    bool Merged[2] = {false, false};
+    const unsigned NumExprs = Pointers[Pointer].Exprs.size();
     // Mark this pointer as seen.
     Seen.insert(Pointer);
 
@@ -422,17 +458,28 @@
 
       TotalComparisons++;
 
-      if (Group.addPointer(Pointer, *this)) {
-        Merged = true;
+      // Try to add both forks, if applicable.
+      for (unsigned J = 0; J < NumExprs; ++J)
+        if (!Merged[J])
+          Merged[J] = Group.addPointer(Pointer, *this, J);
+
+      if (Merged[0] && (Merged[1] || NumExprs == 1))
        break;
-      }
     }
 
-    if (!Merged)
-      // We couldn't add this pointer to any existing set or the threshold
-      // for the number of comparisons has been reached. Create a new group
-      // to hold the current pointer.
-      Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
+    // If we couldn't add the pointer expression(s) to any existing set or
+    // the threshold for the number of comparisons has been reached, create a
+    // new group to hold the current pointer expression(s).
+    RuntimeCheckingPtrGroup *Group = nullptr;
+    if (!Merged[0])
+      Group = &Groups.emplace_back(Pointer, *this);
+    if (!Merged[1] && NumExprs == 2) {
+      // Try merging forks first, as they may have the same base address
+      // but different offsets.
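+      // If the first fork already merged into some existing group, or the
+      // second fork cannot be added to the group just created for the first
+      // fork, the second fork gets a checking group of its own.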
+      if (!Group || !Group->addPointer(Pointer, *this, /*Fork=*/1))
+        Groups.emplace_back(Pointer, *this, /*Fork=*/1);
+    }
   }
 
   // We've computed the grouped checks for this partition.
@@ -499,8 +546,10 @@
     OS.indent(Depth + 4) << "(Low: " << *CG.Low << " High: " << *CG.High
                         << ")\n";
     for (unsigned J = 0; J < CG.Members.size(); ++J) {
-      OS.indent(Depth + 6) << "Member: " << *Pointers[CG.Members[J]].Expr
-                           << "\n";
+      const PointerInfo &Info = Pointers[CG.Members[J]];
+      OS.indent(Depth + 6) << "Member: " << *(Info.Exprs[0]) << "\n";
+      if (Info.Exprs.size() == 2)
+        OS.indent(Depth + 6) << "Member(Fork): " << *(Info.Exprs[1]) << "\n";
     }
   }
 }
@@ -634,11 +683,13 @@
 /// by adding run-time checks (overflow checks) if necessary.
 static bool hasComputableBounds(PredicatedScalarEvolution &PSE,
                                 const ValueToValueMap &Strides, Value *Ptr,
-                                Loop *L, bool Assume) {
+                                Loop *L, bool Assume,
+                                RuntimePointerChecking &RtCheck) {
   const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr);
+  ScalarEvolution *SE = PSE.getSE();
 
   // The bounds for loop-invariant pointer is trivial.
-  if (PSE.getSE()->isLoopInvariant(PtrScev, L))
+  if (SE->isLoopInvariant(PtrScev, L))
     return true;
 
   const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev);
@@ -646,19 +697,94 @@
   if (!AR && Assume)
     AR = PSE.getAsAddRec(Ptr);
 
-  if (!AR)
+  if (!AR) {
+    if (RtCheck.AllowForkedPtrs) {
+      // If we can't find a single SCEVAddRecExpr, then maybe we can find two.
+      // The findForkedPointer function tries walking backwards through IR
+      // until it finds operands that are either loop invariant or for which a
+      // SCEVAddRecExpr can be formed. If a select instruction is encountered
+      // during this walk then we split the walk and try to generate two valid
+      // SCEVAddRecExprs. Both of those SCEVs can then be considered when
+      // deciding whether runtime checks are required.
+      ForkedPointer FPtr = SE->findForkedPointer(Ptr, L);
+      if (FPtr) {
+        const SCEV *A = FPtr->first;
+        const SCEV *B = FPtr->second;
+        LLVM_DEBUG(dbgs() << "LAA: ForkedPtr found: " << *Ptr << "\n");
+        LLVM_DEBUG(dbgs() << "LAA: SCEV1: " << *A << "\n");
+        LLVM_DEBUG(dbgs() << "LAA: SCEV2: " << *B << "\n");
+        RtCheck.ForkedPtrs[Ptr] = *FPtr;
+        if (isa<SCEVAddRecExpr>(A) && cast<SCEVAddRecExpr>(A)->isAffine() &&
+            isa<SCEVAddRecExpr>(B) && cast<SCEVAddRecExpr>(B)->isAffine())
+          return true;
+        LLVM_DEBUG(
+            dbgs() << "LAA: Could not determine bounds for forked pointer\n");
+      }
+    }
     return false;
+  }
 
   return AR->isAffine();
 }
 
 /// Check whether a pointer address cannot wrap.
 static bool isNoWrap(PredicatedScalarEvolution &PSE,
-                     const ValueToValueMap &Strides, Value *Ptr, Loop *L) {
+                     const ValueToValueMap &Strides, Value *Ptr, Loop *L,
+                     RuntimePointerChecking &RtCheck) {
   const SCEV *PtrScev = PSE.getSCEV(Ptr);
   if (PSE.getSE()->isLoopInvariant(PtrScev, L))
     return true;
 
+  // The SCEV generated directly from a forked pointer is not an AddRecExpr,
+  // but an unknown SCEV. The PSE.hasNoOverflow method currently assumes
+  // that it must be an AddRecExpr and just casts. So we just bail out at
+  // this point, since we can't pass in the SCEVs one at a time --
+  // hasNoOverflow takes a Value* as a param instead of a SCEV*.
+  // TODO: Support overflow checking for ForkedPointers.
+  //
+  // We can, however, calculate an effective stride for each side of the fork
+  // and check if the stride is 1 for both; this is the other way we can
+  // assume a pointer doesn't wrap.
+  if (RtCheck.isForkedPtr(Ptr)) {
+    auto SCEVs = RtCheck.ForkedPtrs[Ptr];
+    const SCEVAddRecExpr *LSAR = cast<SCEVAddRecExpr>(SCEVs.first);
+    const SCEVAddRecExpr *RSAR = cast<SCEVAddRecExpr>(SCEVs.second);
+    const SCEV *LStep = LSAR->getStepRecurrence(*PSE.getSE());
+    const SCEV *RStep = RSAR->getStepRecurrence(*PSE.getSE());
+
+    if (LStep != RStep) {
+      LLVM_DEBUG(dbgs() << "LAA: Forkedptr with mismatched steps\n");
+      return false;
+    }
+
+    auto *PtrTy = dyn_cast<PointerType>(Ptr->getType());
+    auto &DL = L->getHeader()->getModule()->getDataLayout();
+    int64_t Size = DL.getTypeAllocSize(PtrTy->getElementType());
+
+    const SCEVConstant *C = dyn_cast<SCEVConstant>(LStep);
+    if (!C) {
+      LLVM_DEBUG(dbgs() << "LAA: Forkedptr with non-constant steps\n");
+      return false;
+    }
+    const APInt &APStepVal = C->getAPInt();
+    if (APStepVal.getBitWidth() > 64) {
+      LLVM_DEBUG(dbgs() << "LAA: Forkedptr step is too large\n");
+      return false;
+    }
+
+    int64_t StepVal = APStepVal.getSExtValue();
+
+    // Strided access.
+    int64_t Stride = StepVal / Size;
+    int64_t Rem = StepVal % Size;
+    if (Rem) {
+      LLVM_DEBUG(dbgs() << "LAA: Forkedptr step not multiple of elt size\n");
+      return false;
+    }
+
+    return Stride == 1;
+  }
+
   int64_t Stride = getPtrStride(PSE, Ptr, L, Strides);
   if (Stride == 1 || PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW))
     return true;
@@ -675,12 +801,12 @@
                                           bool Assume) {
   Value *Ptr = Access.getPointer();
 
-  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume))
+  if (!hasComputableBounds(PSE, StridesMap, Ptr, TheLoop, Assume, RtCheck))
     return false;
 
   // When we run after a failing dependency check we have to make sure
   // we don't have wrapping pointers.
-  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop)) {
+  if (ShouldCheckWrap && !isNoWrap(PSE, StridesMap, Ptr, TheLoop, RtCheck)) {
     auto *Expr = PSE.getSCEV(Ptr);
     if (!Assume || !isa<SCEVAddRecExpr>(Expr))
       return false;
@@ -1817,6 +1943,7 @@
   PtrRtChecking->Pointers.clear();
   PtrRtChecking->Need = false;
+  PtrRtChecking->ForkedPtrs.clear();
 
   const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
 
@@ -2192,7 +2319,8 @@
                                const TargetLibraryInfo *TLI, AAResults *AA,
                                DominatorTree *DT, LoopInfo *LI)
     : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
-      PtrRtChecking(std::make_unique<RuntimePointerChecking>(SE)),
+      PtrRtChecking(std::make_unique<RuntimePointerChecking>(
+          SE, EnableForkedPointers)),
       DepChecker(std::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L),
       NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false),
       HasConvergentOp(false),
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9086,6 +9086,165 @@
   return getSCEVAtScope(getSCEV(V), L);
 }
 
+// Walk back through the IR for a pointer, looking for a select like the
+// following:
+//
+//  %offset = select i1 %cmp, i64 %a, i64 %b
+//  %addr = getelementptr double, double* %base, i64 %offset
+//  %ld = load double, double* %addr, align 8
+//
+// We won't be able to form a single SCEVAddRecExpr from this since the
+// address for each loop iteration depends on %cmp. We could potentially
+// produce multiple valid SCEVAddRecExprs, though, and check all of them for
+// memory safety/aliasing if needed.
+//
+// If we encounter some IR we don't yet handle, or something obviously fine
+// like a constant, then we just add the SCEV for that term to the list passed
+// in by the caller. If we have a node that may potentially yield a valid
+// SCEVAddRecExpr then we decompose it into parts and build the SCEV terms
+// ourselves before adding to the list.
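+//
+// For the example above, assuming %a and %b are loop invariant and %base is
+// otherwise analyzable, the caller would receive two SCEVs, one per arm of
+// the select: roughly ((8 * %a) + %base) and ((8 * %b) + %base).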
+static void findForkedSCEVs(ScalarEvolution *SE, const Loop *L, Value *Ptr,
+                            SmallVectorImpl<const SCEV *> &ScevList) {
+  const SCEV *Scev = SE->getSCEV(Ptr);
+  if (SE->isLoopInvariant(Scev, L) || isa<SCEVCouldNotCompute>(Scev) ||
+      !isa<Instruction>(Ptr)) {
+    ScevList.push_back(Scev);
+    return;
+  }
+
+  auto GetBinOpExpr = [&SE](unsigned Opcode, const SCEV *L, const SCEV *R) {
+    switch (Opcode) {
+    case Instruction::Add:
+      return SE->getAddExpr(L, R);
+    case Instruction::Sub:
+      return SE->getMinusSCEV(L, R);
+    case Instruction::Mul:
+      return SE->getMulExpr(L, R);
+    default:
+      llvm_unreachable("Unexpected binary operator when walking ForkedPtrs");
+    }
+  };
+
+  Instruction *I = cast<Instruction>(Ptr);
+  unsigned Opcode = I->getOpcode();
+  switch (Opcode) {
+  case Instruction::BitCast:
+    findForkedSCEVs(SE, L, I->getOperand(0), ScevList);
+    break;
+  case Instruction::SExt:
+  case Instruction::ZExt: {
+    SmallVector<const SCEV *, 2> ExtScevs;
+    findForkedSCEVs(SE, L, I->getOperand(0), ExtScevs);
+    for (const SCEV *Scev : ExtScevs)
+      if (Opcode == Instruction::SExt)
+        ScevList.push_back(SE->getSignExtendExpr(Scev, I->getType()));
+      else
+        ScevList.push_back(SE->getZeroExtendExpr(Scev, I->getType()));
+    break;
+  }
+  case Instruction::GetElementPtr: {
+    GetElementPtrInst *GEP = cast<GetElementPtrInst>(I);
+    Type *SourceTy = GEP->getSourceElementType();
+    // We only handle base + single offset GEPs here for now.
+    // Not dealing with preexisting gathers yet, so no vectors.
+    if (I->getNumOperands() != 2 || SourceTy->isVectorTy()) {
+      ScevList.push_back(Scev);
+      break;
+    }
+    SmallVector<const SCEV *, 2> BaseScevs;
+    SmallVector<const SCEV *, 2> OffsetScevs;
+    findForkedSCEVs(SE, L, I->getOperand(0), BaseScevs);
+    findForkedSCEVs(SE, L, I->getOperand(1), OffsetScevs);
+
+    // Make sure we get the correct pointer type to extend to, including the
+    // address space.
+    const SCEV *BaseExpr = SE->getSCEV(GEP->getPointerOperand());
+    Type *IntPtrTy = SE->getEffectiveSCEVType(BaseExpr->getType());
+    SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
+                                               : SCEV::FlagAnyWrap;
+    // Find the size of the type being pointed to. We only have a single
+    // index term (guarded above) so we don't need to index into arrays or
+    // structures, just get the size of the scalar value.
+    const SCEV *Size = SE->getSizeOfExpr(IntPtrTy, SourceTy);
+
+    if (OffsetScevs.size() == 2 && BaseScevs.size() == 1) {
+      const SCEV *Off1 = SE->getTruncateOrSignExtend(OffsetScevs[0], IntPtrTy);
+      const SCEV *Off2 = SE->getTruncateOrSignExtend(OffsetScevs[1], IntPtrTy);
+      const SCEV *Mul1 = SE->getMulExpr(Size, Off1, Wrap);
+      const SCEV *Mul2 = SE->getMulExpr(Size, Off2, Wrap);
+      const SCEV *Add1 = SE->getAddExpr(BaseScevs[0], Mul1, Wrap);
+      const SCEV *Add2 = SE->getAddExpr(BaseScevs[0], Mul2, Wrap);
+      ScevList.push_back(Add1);
+      ScevList.push_back(Add2);
+    } else if (BaseScevs.size() == 2 && OffsetScevs.size() == 1) {
+      const SCEV *Off = SE->getTruncateOrSignExtend(OffsetScevs[0], IntPtrTy);
+      const SCEV *Mul = SE->getMulExpr(Size, Off, Wrap);
+      const SCEV *Add1 = SE->getAddExpr(BaseScevs[0], Mul, Wrap);
+      const SCEV *Add2 = SE->getAddExpr(BaseScevs[1], Mul, Wrap);
+      ScevList.push_back(Add1);
+      ScevList.push_back(Add2);
+    } else
+      ScevList.push_back(Scev);
+    break;
+  }
+  case Instruction::Select: {
+    SmallVector<const SCEV *, 2> ChildScevs;
+    // A select means we've found a forked pointer, but we currently only
+    // support a single select per pointer, so if there's another behind this
+    // one then we just bail out and return the generic SCEV.
+    findForkedSCEVs(SE, L, I->getOperand(1), ChildScevs);
+    findForkedSCEVs(SE, L, I->getOperand(2), ChildScevs);
+    if (ChildScevs.size() == 2) {
+      ScevList.push_back(ChildScevs[0]);
+      ScevList.push_back(ChildScevs[1]);
+    } else
+      ScevList.push_back(Scev);
+    break;
+  }
+  // If adding another binop to this list, update GetBinOpExpr above.
+  case Instruction::Add:
+  case Instruction::Sub:
+  case Instruction::Mul: {
+    SmallVector<const SCEV *, 2> LScevs;
+    SmallVector<const SCEV *, 2> RScevs;
+    findForkedSCEVs(SE, L, I->getOperand(0), LScevs);
+    findForkedSCEVs(SE, L, I->getOperand(1), RScevs);
+    if (LScevs.size() == 2 && RScevs.size() == 1) {
+      const SCEV *Op1 = GetBinOpExpr(Opcode, LScevs[0], RScevs[0]);
+      const SCEV *Op2 = GetBinOpExpr(Opcode, LScevs[1], RScevs[0]);
+      ScevList.push_back(Op1);
+      ScevList.push_back(Op2);
+    } else if (LScevs.size() == 1 && RScevs.size() == 2) {
+      const SCEV *Op1 = GetBinOpExpr(Opcode, LScevs[0], RScevs[0]);
+      const SCEV *Op2 = GetBinOpExpr(Opcode, LScevs[0], RScevs[1]);
+      ScevList.push_back(Op1);
+      ScevList.push_back(Op2);
+    } else
+      ScevList.push_back(Scev);
+    break;
+  }
+  default:
+    // Just return the current SCEV if we haven't handled the instruction yet.
+    LLVM_DEBUG(dbgs() << "ForkedPtr unhandled instruction: " << *I << "\n");
+    ScevList.push_back(Scev);
+    break;
+  }
+
+  return;
+}
+
+ForkedPointer ScalarEvolution::findForkedPointer(Value *V, const Loop *L) {
+  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
+  SmallVector<const SCEV *, 2> Scevs;
+  findForkedSCEVs(this, L, V, Scevs);
+
+  // For now, we will only accept a forked pointer with two options.
+  if (Scevs.size() == 2)
+    return std::make_pair(Scevs[0], Scevs[1]);
+
+  return None;
+}
+
 const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
     return stripInjectiveFunctions(ZExt->getOperand());
diff --git a/llvm/test/Transforms/LoopVectorize/forked-pointers.ll b/llvm/test/Transforms/LoopVectorize/forked-pointers.ll
new file
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/forked-pointers.ll
@@ -0,0 +1,306 @@
+; RUN: opt -loop-accesses -analyze -enable-new-pm=0 %s 2>&1 | FileCheck %s --check-prefix=NO-FORKED-PTRS
+; RUN: opt -disable-output -passes='require<aa>,require<scalar-evolution>,loop(print-access-info)' %s 2>&1 | FileCheck %s --check-prefix=NO-FORKED-PTRS
+; RUN: opt -enable-forked-pointer-detection -loop-accesses -analyze -enable-new-pm=0 %s 2>&1 | FileCheck %s --check-prefix=FORKED-PTRS
+; RUN: opt -enable-forked-pointer-detection -disable-output -passes='require<aa>,require<scalar-evolution>,loop(print-access-info)' %s 2>&1 | FileCheck %s --check-prefix=FORKED-PTRS
+; RUN: opt -loop-vectorize -instcombine -enable-forked-pointer-detection -force-vector-width=4 -S < %s 2>&1 | FileCheck %s --check-prefix=FP-VEC
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+
+; NO-FORKED-PTRS-LABEL: function 'forked_ptrs_different_base_same_offset':
+; NO-FORKED-PTRS: for.body:
+; NO-FORKED-PTRS: Report: cannot identify array bounds
+
+; FORKED-PTRS-LABEL: function 'forked_ptrs_different_base_same_offset':
+; FORKED-PTRS: for.body:
+; FORKED-PTRS: Memory dependences are safe with run-time checks
+; FORKED-PTRS: Dependences:
+; FORKED-PTRS: Run-time memory checks:
+; FORKED-PTRS: Check 0:
+; FORKED-PTRS: Comparing group
+; FORKED-PTRS: %1 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+; FORKED-PTRS: Against group
+; FORKED-PTRS: %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+; FORKED-PTRS: Check 1:
+; FORKED-PTRS: Comparing group
+; FORKED-PTRS: %1 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+; FORKED-PTRS: Against group
+; FORKED-PTRS: %.sink.in = getelementptr inbounds float, float* %spec.select, i64 %indvars.iv
+; FORKED-PTRS: Check 2:
+; FORKED-PTRS: Comparing group
+; FORKED-PTRS: %1 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+; FORKED-PTRS: Against group
+; FORKED-PTRS: %.sink.in = getelementptr inbounds float, float* %spec.select, i64 %indvars.iv
+; FORKED-PTRS: Grouped accesses:
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Dest High: (400 + %Dest))
+; FORKED-PTRS: Member: {%Dest,+,4}<%for.body>
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Preds High: (400 + %Preds))
+; FORKED-PTRS: Member: {%Preds,+,4}<%for.body>
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Base2 High: (400 + %Base2))
+; FORKED-PTRS: Member: {%Base2,+,4}<%for.body>
+; FORKED-PTRS: Member(Fork): {%Base1,+,4}<%for.body>
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Base1 High: (400 + %Base1))
+; FORKED-PTRS: Member: {%Base2,+,4}<%for.body>
+; FORKED-PTRS: Member(Fork): {%Base1,+,4}<%for.body>
+; FORKED-PTRS: Non vectorizable stores to invariant address were not found in loop.
+; FORKED-PTRS: SCEV assumptions:
+; FORKED-PTRS: Expressions re-written:
+
+; FP-VEC-LABEL: @forked_ptrs_different_base_same_offset
+; FP-VEC: vector.memcheck:
+; FP-VEC: [[DESTEND:%[a-zA-Z0-9]+]] = getelementptr float, float* %Dest, i64 100
+; FP-VEC: [[PREDSEND:%[a-zA-Z0-9]+]] = getelementptr i32, i32* %Preds, i64 100
+; FP-VEC: [[BASE2END:%[a-zA-Z0-9]+]] = getelementptr float, float* %Base2, i64 100
+; FP-VEC: [[BASE1END:%[a-zA-Z0-9]+]] = getelementptr float, float* %Base1, i64 100
+
+;;;; Preds vs. Dest
+; FP-VEC: [[PREDSCAST:%[0-9]+]] = bitcast i32* [[PREDSEND]] to float*
+; FP-VEC: [[PBOUND1:%[a-zA-Z0-9]+]] = icmp ugt float* [[PREDSCAST]], %Dest
+; FP-VEC: [[DESTCAST:%[0-9]+]] = bitcast float* [[DESTEND]] to i32*
+; FP-VEC: [[PBOUND2:%[a-zA-Z0-9]+]] = icmp ugt i32* [[DESTCAST]], %Preds
+; FP-VEC: [[PCONFLICT:%[a-zA-Z0-9\.]+]] = and i1 [[PBOUND1]], [[PBOUND2]]
+
+;;;; Base2 vs. Dest
+; FP-VEC: [[B2BOUND1:%[a-zA-Z0-9]+]] = icmp ugt float* [[BASE2END]], %Dest
+; FP-VEC: [[B2BOUND2:%[a-zA-Z0-9]+]] = icmp ugt float* [[DESTEND]], %Base2
+; FP-VEC: [[B2CONFLICT:%[a-zA-Z0-9\.]+]] = and i1 [[B2BOUND1]], [[B2BOUND2]]
+; FP-VEC: [[COMBINED1:%[a-zA-Z0-9\.]+]] = or i1 [[PCONFLICT]], [[B2CONFLICT]]
+
+;;;; Base1 vs. Dest
+; FP-VEC: [[B1BOUND1:%[a-zA-Z0-9]+]] = icmp ugt float* [[BASE1END]], %Dest
+; FP-VEC: [[B1BOUND2:%[a-zA-Z0-9]+]] = icmp ugt float* [[DESTEND]], %Base1
+; FP-VEC: [[B1CONFLICT:%[a-zA-Z0-9\.]+]] = and i1 [[B1BOUND1]], [[B1BOUND2]]
+; FP-VEC: [[COMBINED2:%[a-zA-Z0-9\.]+]] = or i1 [[COMBINED1]], [[B1CONFLICT]]
+
+; FP-VEC: br i1 [[COMBINED2]], label %scalar.ph, label %vector.ph
+
+;;;; Check we get a gather load. We could do two contiguous masked loads but
+;;;; haven't implemented that yet.
+; FP-VEC: [[LANE1:%[0-9]+]] = load float
+; FP-VEC: [[LANE2:%[0-9]+]] = load float
+; FP-VEC: [[LANE3:%[0-9]+]] = load float
+; FP-VEC: [[LANE4:%[0-9]+]] = load float
+; FP-VEC: [[INS1:%[0-9]+]] = insertelement <4 x float> poison, float [[LANE1]], i32 0
+; FP-VEC: [[INS2:%[0-9]+]] = insertelement <4 x float> [[INS1]], float [[LANE2]], i32 1
+; FP-VEC: [[INS3:%[0-9]+]] = insertelement <4 x float> [[INS2]], float [[LANE3]], i32 2
+; FP-VEC: [[INS4:%[0-9]+]] = insertelement <4 x float> [[INS3]], float [[LANE4]], i32 3
+
+;;;; Derived from the following C code
+;; void forked_ptrs_different_base_same_offset(float *A, float *B, float *C, int *D) {
+;;   for (int i=0; i<100; i++) {
+;;     if (D[i] != 0) {
+;;       C[i] = A[i];
+;;     } else {
+;;       C[i] = B[i];
+;;     }
+;;   }
+;; }
+
+define dso_local void @forked_ptrs_different_base_same_offset(float* nocapture readonly %Base1, float* nocapture readonly %Base2, float* nocapture %Dest, i32* nocapture readonly %Preds) {
+entry:
+  br label %for.body
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp1.not = icmp eq i32 %0, 0
+  %spec.select = select i1 %cmp1.not, float* %Base2, float* %Base1
+  %.sink.in = getelementptr inbounds float, float* %spec.select, i64 %indvars.iv
+  %.sink = load float, float* %.sink.in, align 4
+  %1 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+  store float %.sink, float* %1, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 100
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+; NO-FORKED-PTRS-LABEL: function 'forked_ptrs_same_base_different_offset':
+; NO-FORKED-PTRS: for.body:
+; NO-FORKED-PTRS: Report: cannot identify array bounds
+
+; FORKED-PTRS-LABEL: function 'forked_ptrs_same_base_different_offset':
+; FORKED-PTRS: for.body:
+; FORKED-PTRS: Memory dependences are safe with run-time checks
+; FORKED-PTRS: Dependences:
+; FORKED-PTRS: Run-time memory checks:
+; FORKED-PTRS: Check 0:
+; FORKED-PTRS: Comparing group
+; FORKED-PTRS: %arrayidx5 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+; FORKED-PTRS: Against group
+; FORKED-PTRS: %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+; FORKED-PTRS: Check 1:
+; FORKED-PTRS: Comparing group
+; FORKED-PTRS: %arrayidx5 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+; FORKED-PTRS: Against group
+; FORKED-PTRS: %arrayidx3 = getelementptr inbounds float, float* %Base, i64 %idxprom213
+; FORKED-PTRS: Grouped accesses:
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Dest High: (400 + %Dest))
+; FORKED-PTRS: Member: {%Dest,+,4}<%for.body>
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Preds High: (400 + %Preds))
+; FORKED-PTRS: Member: {%Preds,+,4}<%for.body>
+; FORKED-PTRS: Group
+; FORKED-PTRS: (Low: %Base High: (404 + %Base))
+; FORKED-PTRS: Member: {(4 + %Base),+,4}<%for.body>
+; FORKED-PTRS: Member(Fork): {%Base,+,4}<%for.body>
+; FORKED-PTRS: Non vectorizable stores to invariant address were not found in loop.
+; FORKED-PTRS: SCEV assumptions:
+; FORKED-PTRS: Expressions re-written:
+
+; FP-VEC-LABEL: @forked_ptrs_same_base_different_offset
+; FP-VEC: vector.memcheck:
+; FP-VEC: [[DESTEND:%[a-zA-Z0-9]+]] = getelementptr float, float* %Dest, i64 100
+; FP-VEC: [[PREDSEND:%[a-zA-Z0-9]+]] = getelementptr i32, i32* %Preds, i64 100
+; FP-VEC: [[BASEEND:%[a-zA-Z0-9]+]] = getelementptr float, float* %Base, i64 101
+
+;;;; Preds vs. Dest
+; FP-VEC: [[PREDSCAST:%[0-9]+]] = bitcast i32* [[PREDSEND]] to float*
+; FP-VEC: [[PBOUND1:%[a-zA-Z0-9]+]] = icmp ugt float* [[PREDSCAST]], %Dest
+; FP-VEC: [[DESTCAST:%[0-9]+]] = bitcast float* [[DESTEND]] to i32*
+; FP-VEC: [[PBOUND2:%[a-zA-Z0-9]+]] = icmp ugt i32* [[DESTCAST]], %Preds
+; FP-VEC: [[PCONFLICT:%[a-zA-Z0-9\.]+]] = and i1 [[PBOUND1]], [[PBOUND2]]
+
+;;;; Base vs. Dest
+; FP-VEC: [[BBOUND1:%[a-zA-Z0-9]+]] = icmp ugt float* [[BASEEND]], %Dest
+; FP-VEC: [[BBOUND2:%[a-zA-Z0-9]+]] = icmp ugt float* [[DESTEND]], %Base
+; FP-VEC: [[BCONFLICT:%[a-zA-Z0-9\.]+]] = and i1 [[BBOUND1]], [[BBOUND2]]
+; FP-VEC: [[COMBINED:%[a-zA-Z0-9\.]+]] = or i1 [[PCONFLICT]], [[BCONFLICT]]
+
+; FP-VEC: br i1 [[COMBINED]], label %scalar.ph, label %vector.ph
+
+;;;; Check we get a gather load. We could do a better job here, especially
+;;;; since the offset is only 1 element, but we haven't implemented that yet.
+; FP-VEC: [[LANE1:%[0-9]+]] = load float
+; FP-VEC: [[LANE2:%[0-9]+]] = load float
+; FP-VEC: [[LANE3:%[0-9]+]] = load float
+; FP-VEC: [[LANE4:%[0-9]+]] = load float
+; FP-VEC: [[INS1:%[0-9]+]] = insertelement <4 x float> poison, float [[LANE1]], i32 0
+; FP-VEC: [[INS2:%[0-9]+]] = insertelement <4 x float> [[INS1]], float [[LANE2]], i32 1
+; FP-VEC: [[INS3:%[0-9]+]] = insertelement <4 x float> [[INS2]], float [[LANE3]], i32 2
+; FP-VEC: [[INS4:%[0-9]+]] = insertelement <4 x float> [[INS3]], float [[LANE4]], i32 3
+
+;;;; Derived from the following C code
+;; void forked_ptrs_same_base_different_offset(float *A, float *B, int *C) {
+;;   int offset;
+;;   for (int i = 0; i < 100; i++) {
+;;     if (C[i] != 0)
+;;       offset = i;
+;;     else
+;;       offset = i+1;
+;;     B[i] = A[offset];
+;;   }
+;; }
+
+define dso_local void @forked_ptrs_same_base_different_offset(float* nocapture readonly %Base, float* nocapture %Dest, i32* nocapture readonly %Preds) {
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %i.014 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp1.not = icmp eq i32 %0, 0
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %add = add nuw nsw i32 %i.014, 1
+  %1 = trunc i64 %indvars.iv to i32
+  %offset.0 = select i1 %cmp1.not, i32 %add, i32 %1
+  %idxprom213 = zext i32 %offset.0 to i64
+  %arrayidx3 = getelementptr inbounds float, float* %Base, i64 %idxprom213
+  %2 = load float, float* %arrayidx3, align 4
+  %arrayidx5 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+  store float %2, float* %arrayidx5, align 4
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 100
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+;;;; Cases that can be handled by a forked pointer but are not currently allowed.
+
+; NO-FORKED-PTRS-LABEL: function 'forked_ptrs_uniform_and_contiguous_forks':
+; NO-FORKED-PTRS: for.body:
+; NO-FORKED-PTRS: Report: cannot identify array bounds
+
+; FORKED-PTRS-LABEL: function 'forked_ptrs_uniform_and_contiguous_forks':
+; FORKED-PTRS: for.body:
+; FORKED-PTRS: Report: cannot identify array bounds
+
+;;;; Derived from forked_ptrs_same_base_different_offset with a manually
+;;;; added uniform offset.
+
+define dso_local void @forked_ptrs_uniform_and_contiguous_forks(float* nocapture readonly %Base, float* nocapture %Dest, i32* nocapture readonly %Preds) {
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %i.014 = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp1.not = icmp eq i32 %0, 0
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %add = add nuw nsw i32 %i.014, 1
+  %1 = trunc i64 %indvars.iv to i32
+  %offset.0 = select i1 %cmp1.not, i32 4, i32 %1
+  %idxprom213 = zext i32 %offset.0 to i64
+  %arrayidx3 = getelementptr inbounds float, float* %Base, i64 %idxprom213
+  %2 = load float, float* %arrayidx3, align 4
+  %arrayidx5 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+  store float %2, float* %arrayidx5, align 4
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 100
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
+
+; NO-FORKED-PTRS-LABEL: function 'forked_ptrs_gather_and_contiguous_forks':
+; NO-FORKED-PTRS: for.body:
+; NO-FORKED-PTRS: Report: cannot identify array bounds
+
+; FORKED-PTRS-LABEL: function 'forked_ptrs_gather_and_contiguous_forks':
+; FORKED-PTRS: for.body:
+; FORKED-PTRS: Report: cannot identify array bounds
+
+;;;; Derived from forked_ptrs_same_base_different_offset with a gather
+;;;; added using Preds as an index array in addition to the per-iteration
+;;;; condition.
+
+define dso_local void @forked_ptrs_gather_and_contiguous_forks(float* nocapture readonly %Base1, float* nocapture readonly %Base2, float* nocapture %Dest, i32* nocapture readonly %Preds) {
+entry:
+  br label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ 0, %entry ], [ %indvars.iv.next, %for.body ]
+  %arrayidx = getelementptr inbounds i32, i32* %Preds, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp1.not = icmp eq i32 %0, 0
+  %arrayidx9 = getelementptr inbounds float, float* %Base2, i64 %indvars.iv
+  %idxprom4 = sext i32 %0 to i64
+  %arrayidx5 = getelementptr inbounds float, float* %Base1, i64 %idxprom4
+  %.sink.in = select i1 %cmp1.not, float* %arrayidx9, float* %arrayidx5
+  %.sink = load float, float* %.sink.in, align 4
+  %1 = getelementptr inbounds float, float* %Dest, i64 %indvars.iv
+  store float %.sink, float* %1, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, 100
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
+}
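
A minimal sketch of what the emitted run-time check amounts to for the first test function above (cf. the FP-VEC lines), written as standalone C++ for illustration only; the helper name and the fixed trip count of 100 elements are assumptions, not part of the patch.

#include <cstdint>

// Conceptual model of the vector.memcheck block for
// forked_ptrs_different_base_same_offset: both forks of the select (Base1 and
// Base2) are compared against Dest, as is Preds. If any pair of intervals may
// overlap, the vectorized loop is skipped and the scalar loop runs instead.
static bool accessesMayConflict(float *Dest, float *Base1, float *Base2,
                                int *Preds) {
  auto Lo = [](const void *P) { return reinterpret_cast<uintptr_t>(P); };
  // Each interval covers the 100 elements accessed by the loop.
  uintptr_t DestLo = Lo(Dest), DestHi = Lo(Dest + 100);
  uintptr_t PredsLo = Lo(Preds), PredsHi = Lo(Preds + 100);
  uintptr_t B1Lo = Lo(Base1), B1Hi = Lo(Base1 + 100);
  uintptr_t B2Lo = Lo(Base2), B2Hi = Lo(Base2 + 100);
  bool PredsVsDest = PredsHi > DestLo && DestHi > PredsLo;
  bool Base2VsDest = B2Hi > DestLo && DestHi > B2Lo;
  bool Base1VsDest = B1Hi > DestLo && DestHi > B1Lo;
  return PredsVsDest || Base2VsDest || Base1VsDest;
}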