Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -93,6 +93,19 @@
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
 
+// Helper for findLoopComponents: records RHS as the loop's trip count, adds
+// Increment to IterationInstructions, and always returns true.
+static bool
+setLoopComponent(Value *RHS, Value *&TripCount, BinaryOperator *Increment,
+                 SmallPtrSetImpl<Instruction *> &IterationInstructions) {
+  TripCount = RHS;
+  IterationInstructions.insert(Increment);
+  LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
+  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
+  LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
+  return true;
+}
+
 // Finds the induction variable, increment and trip count for a simple loop that
 // we can flatten.
 static bool findLoopComponents(
@@ -164,49 +177,72 @@
     return false;
   }
   // The trip count is the RHS of the compare. If this doesn't match the trip
-  // count computed by SCEV then this is either because the trip count variable
-  // has been widened (then leave the trip count as it is), or because it is a
-  // constant and another transformation has changed the compare, e.g.
-  // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1.
-  TripCount = Compare->getOperand(1);
+  // count computed by SCEV then this is because the trip count variable
+  // has been widened so the types don't match, or because it is a constant and
+  // another transformation has changed the compare (e.g. icmp ult %inc,
+  // tripcount -> icmp ult %j, tripcount-1), or both.
+  Value *RHS = Compare->getOperand(1);
   const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
   if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
     LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
     return false;
   }
   const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
-  if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
-    ConstantInt *RHS = dyn_cast<ConstantInt>(TripCount);
-    if (!RHS) {
-      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
-      return false;
-    }
-    // The L->isCanonical check above ensures we only get here if the loop
-    // increments by 1 on each iteration, so the RHS of the Compare is
-    // tripcount-1 (i.e equivalent to the backedge taken count).
-    assert(SE->getSCEV(RHS) == BackedgeTakenCount &&
-           "Expected RHS of compare to be equal to the backedge taken count");
-    ConstantInt *One = ConstantInt::get(RHS->getType(), 1);
-    TripCount = ConstantInt::get(TripCount->getContext(),
-                                 RHS->getValue() + One->getValue());
-  } else if (SE->getSCEV(TripCount) != SCEVTripCount) {
-    auto *TripCountInst = dyn_cast<Instruction>(TripCount);
-    if (!TripCountInst) {
-      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
-      return false;
-    }
-    if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
-        SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
-      LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
-      return false;
-    }
-  }
-  IterationInstructions.insert(Increment);
-  LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
-  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
-
-  LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
-  return true;
+  const SCEV *SCEVRHS = SE->getSCEV(RHS);
+  if (SCEVRHS == SCEVTripCount)
+    return setLoopComponent(RHS, TripCount, Increment, IterationInstructions);
+
+  if (auto *ConstantRHS = dyn_cast<ConstantInt>(RHS)) {
+    const SCEV *BackedgeTCExt = nullptr;
+    if (IsWidened) {
+      // Find the extended backedge taken count and extended trip count using
+      // SCEV. One of these should now match the RHS of the compare.
+      BackedgeTCExt =
+          SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
+      const SCEV *SCEVTripCountExt =
+          SE->getTripCountFromExitCount(BackedgeTCExt);
+      if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
+        LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+        return false;
+      }
+    } else if (SCEVRHS != BackedgeTakenCount) {
+      // The L->isCanonical check above ensures the loop increments by 1, so
+      // without widening a constant RHS that is not the trip count must be
+      // tripcount-1 (the backedge-taken count). Reject anything else instead
+      // of silently recording a wrong trip count.
+      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+      return false;
+    }
+    // If the RHS of the compare is equal to the backedge taken count we need
+    // to add one to get the trip count.
+    if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
+      ConstantInt *One = ConstantInt::get(ConstantRHS->getType(), 1);
+      Value *NewRHS = ConstantInt::get(
+          ConstantRHS->getContext(), ConstantRHS->getValue() + One->getValue());
+      return setLoopComponent(NewRHS, TripCount, Increment,
+                              IterationInstructions);
+    }
+    return setLoopComponent(RHS, TripCount, Increment, IterationInstructions);
+  }
+
+  // If the RHS isn't a constant then check that the reason it doesn't match
+  // the SCEV trip count is because the RHS is a ZExt or SExt instruction
+  // (and take the trip count to be the RHS).
+  if (!IsWidened) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  auto *TripCountInst = dyn_cast<Instruction>(RHS);
+  if (!TripCountInst) {
+    LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+    return false;
+  }
+  if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
+      SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
+    LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n");
+    return false;
+  }
+  return setLoopComponent(RHS, TripCount, Increment, IterationInstructions);
 }
 
 static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI) {
Index: llvm/test/Transforms/LoopFlatten/widen-iv.ll
===================================================================
--- llvm/test/Transforms/LoopFlatten/widen-iv.ll
+++ llvm/test/Transforms/LoopFlatten/widen-iv.ll
@@ -525,6 +525,52 @@
   ret void
 }
 
+; Identify trip count when it is constant and the IV has been widened.
+define i32 @constTripCount() {
+; CHECK-LABEL: @constTripCount(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[FLATTEN_TRIPCOUNT:%.*]] = mul i64 20, 20
+; CHECK-NEXT:    br label [[I_LOOP:%.*]]
+; CHECK:       i.loop:
+; CHECK-NEXT:    [[INDVAR1:%.*]] = phi i64 [ [[INDVAR_NEXT2:%.*]], [[J_LOOPDONE:%.*]] ], [ 0, [[ENTRY:%.*]] ]
+; CHECK-NEXT:    br label [[J_LOOP:%.*]]
+; CHECK:       j.loop:
+; CHECK-NEXT:    [[INDVAR:%.*]] = phi i64 [ 0, [[I_LOOP]] ]
+; CHECK-NEXT:    call void @payload()
+; CHECK-NEXT:    [[INDVAR_NEXT:%.*]] = add i64 [[INDVAR]], 1
+; CHECK-NEXT:    [[J_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT]], 20
+; CHECK-NEXT:    br label [[J_LOOPDONE]]
+; CHECK:       j.loopdone:
+; CHECK-NEXT:    [[INDVAR_NEXT2]] = add i64 [[INDVAR1]], 1
+; CHECK-NEXT:    [[I_ATEND:%.*]] = icmp eq i64 [[INDVAR_NEXT2]], [[FLATTEN_TRIPCOUNT]]
+; CHECK-NEXT:    br i1 [[I_ATEND]], label [[I_LOOPDONE:%.*]], label [[I_LOOP]]
+; CHECK:       i.loopdone:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  br label %i.loop
+
+i.loop:
+  %i = phi i8 [ 0, %entry ], [ %i.inc, %j.loopdone ] ; i8 IV, widened to i64 in the CHECK lines above
+  br label %j.loop
+
+j.loop:
+  %j = phi i8 [ 0, %i.loop ], [ %j.inc, %j.loop ]
+  call void @payload()
+  %j.inc = add i8 %j, 1
+  %j.atend = icmp eq i8 %j.inc, 20 ; inner loop exits after a constant 20 iterations
+  br i1 %j.atend, label %j.loopdone, label %j.loop
+
+j.loopdone:
+  %i.inc = add i8 %i, 1
+  %i.atend = icmp eq i8 %i.inc, 20 ; outer loop exits after a constant 20 iterations
+  br i1 %i.atend, label %i.loopdone, label %i.loop
+
+i.loopdone:
+  ret i32 0
+}
+
+declare void @payload()
 declare dso_local i32 @use_32(i32)
 declare dso_local i32 @use_16(i16)
 declare dso_local i32 @use_64(i64)