Index: lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- lib/Transforms/Scalar/LoopFlatten.cpp
+++ lib/Transforms/Scalar/LoopFlatten.cpp
@@ -302,7 +302,8 @@
 static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI,
                          BinaryOperator *InnerIncrement,
                          BinaryOperator *OuterIncrement, Value *InnerLimit,
-                         SmallPtrSetImpl<Value *> &LinearIVUses) {
+                         SmallPtrSetImpl<Value *> &LinearIVUses,
+                         SmallPtrSetImpl<GetElementPtrInst *> &GepIVUses) {
   // We require all uses of both induction variables to match this pattern:
   //
   // (OuterPHI * InnerLimit) + InnerPHI
@@ -328,6 +329,16 @@
       DEBUG(dbgs() << "Use is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       LinearIVUses.insert(U);
+    } else if (dyn_cast<ConstantInt>(InnerLimit) &&
+               dyn_cast<GetElementPtrInst>(U) && U->getNumOperands() == 3 &&
+               U->getOperand(1) == OuterPHI && U->getOperand(2) == InnerPHI &&
+               U->getOperand(0)->getType()->isPointerTy() &&
+               U->getOperand(0)->getType()->getPointerElementType()->isArrayTy() &&
+               U->getOperand(0)->getType()->getPointerElementType()->getArrayNumElements() ==
+                   cast<ConstantInt>(InnerLimit)->getUniqueInteger()) {
+      DEBUG(dbgs() << "Use is gep optimisable\n");
+      ValidOuterPHIUses.insert(U);
+      GepIVUses.insert(cast<GetElementPtrInst>(U));
     } else {
       DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
       return false;
@@ -350,10 +361,13 @@
     }
   }
 
-  DEBUG(dbgs() << "Found " << LinearIVUses.size()
+  DEBUG(dbgs() << "Found " << LinearIVUses.size() + GepIVUses.size()
               << " value(s) that can be replaced:\n";
         for (Value *V : LinearIVUses) {
           dbgs() << "  "; V->dump();
+        }
+        for (Value *V : GepIVUses) {
+          dbgs() << "  "; V->dump();
         });
 
   return true;
@@ -361,10 +375,11 @@
 
 // Return an OverflowResult dependant on if overflow of the multiplication of
 // InnerLimit and OuterLimit can be assumed not to happen.
-static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit,
-                                    Value *OuterLimit,
-                                    SmallPtrSetImpl<Value *> &LinearIVUses,
-                                    DominatorTree *DT, AssumptionCache *AC) {
+static OverflowResult
+checkOverflow(Loop *OuterLoop, Value *InnerLimit, Value *OuterLimit,
+              SmallPtrSetImpl<Value *> &LinearIVUses,
+              SmallPtrSetImpl<GetElementPtrInst *> &GepIVUses,
+              DominatorTree *DT, AssumptionCache *AC) {
   Function *F = OuterLoop->getHeader()->getParent();
   const DataLayout &DL = F->getParent()->getDataLayout();
 
@@ -380,23 +395,30 @@
   if (OR != OverflowResult::MayOverflow)
     return OR;
 
-  for (Value *V : LinearIVUses) {
-    for (Value *U : V->users()) {
-      if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
-        // The IV is used as the operand of a GEP, and the IV is at least as
-        // wide as the address space of the GEP. In this case, the GEP would
-        // wrap around the address space before the IV increment wraps, which
-        // would be UB.
-        if (GEP->isInBounds() &&
-            V->getType()->getIntegerBitWidth() >=
-                DL.getPointerTypeSizeInBits(GEP->getType())) {
-          DEBUG(dbgs() << "use of linear IV would be UB if overflow occurred: ";
-                GEP->dump());
-          return OverflowResult::NeverOverflows;
-        }
-      }
+  auto CheckGEPForOverflow = [&](GetElementPtrInst *GEP) {
+    // The IV is used as the operand of a GEP, and the IV is at least as
+    // wide as the address space of the GEP. In this case, the GEP would
+    // wrap around the address space before the IV increment wraps, which
+    // would be UB.
+    if (GEP->isInBounds() &&
+        GEP->getOperand(1)->getType()->getIntegerBitWidth() >=
+            DL.getPointerTypeSizeInBits(GEP->getType())) {
+      DEBUG(dbgs() << "use of linear IV would be UB if overflow occurred: ";
+            GEP->dump());
+      return true;
     }
-  }
+    return false;
+  };
+
+  for (Value *V : LinearIVUses)
+    for (Value *U : V->users())
+      if (auto *GEP = dyn_cast<GetElementPtrInst>(U))
+        if (CheckGEPForOverflow(GEP))
+          return OverflowResult::NeverOverflows;
+
+  for (GetElementPtrInst *V : GepIVUses)
+    if (CheckGEPForOverflow(V))
+      return OverflowResult::NeverOverflows;
 
   return OverflowResult::MayOverflow;
 }
@@ -407,7 +429,6 @@
                             TargetTransformInfo *TTI,
                             std::function<void(Loop *)> markLoopAsDeleted) {
   Function *F = OuterLoop->getHeader()->getParent();
-
   DEBUG(dbgs() << "Running on outer loop " << OuterLoop->getHeader()->getName()
                << " and inner loop " << InnerLoop->getHeader()->getName()
                << " in " << F->getName() << "\n");
@@ -460,8 +481,9 @@
   // transformation, but we'd have to insert a div/mod to calculate the
   // original IVs, so it wouldn't be profitable.
   SmallPtrSet<Value *, 4> LinearIVUses;
+  SmallPtrSet<GetElementPtrInst *, 4> GepIVUses;
   if (!checkIVUsers(InnerInductionPHI, OuterInductionPHI, InnerIncrement,
-                    OuterIncrement, InnerLimit, LinearIVUses))
+                    OuterIncrement, InnerLimit, LinearIVUses, GepIVUses))
     return false;
 
   // Check if the new iteration variable might overflow. In this case, we
@@ -470,8 +492,8 @@
   // TODO: it might be worth using a wider iteration variable rather than
   // versioning the loop, if a wide enough type is legal.
   bool MustVersionLoop = true;
-  OverflowResult OR =
-      checkOverflow(OuterLoop, InnerLimit, OuterLimit, LinearIVUses, DT, AC);
+  OverflowResult OR = checkOverflow(OuterLoop, InnerLimit, OuterLimit,
+                                    LinearIVUses, GepIVUses, DT, AC);
   if (OR == OverflowResult::AlwaysOverflows) {
     DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n");
     return false;
@@ -552,6 +574,10 @@
   // variables with the one new one.
   for (Value *V : LinearIVUses)
     V->replaceAllUsesWith(OuterInductionPHI);
+  for (GetElementPtrInst *V : GepIVUses) {
+    V->setOperand(1, Constant::getNullValue(V->getOperand(1)->getType()));
+    V->setOperand(2, OuterInductionPHI);
+  }
 
   // If we made a fallback copy of the loop, it will still be flattenable if
   // this pass is run again, but that wouldn't be profitable, so disable
Index: test/Transforms/LoopFlatten/loop-flatten-negative.ll
===================================================================
--- test/Transforms/LoopFlatten/loop-flatten-negative.ll
+++ test/Transforms/LoopFlatten/loop-flatten-negative.ll
@@ -341,7 +341,7 @@
 
 ; Outer loop conditional phi
 
-define i32 @e() {
+define i32 @test_10() {
 entry:
   br label %for.body
 
@@ -393,3 +393,31 @@
 for.end19:                                        ; preds = %for.end16
   ret i32 undef
 }
+
+define i32 @test11([8 x float]* nocapture %A) {
+entry:
+  br label %for.body
+
+for.body:
+  %i.016 = phi i32 [ 0, %entry ], [ %inc6, %for.inc5 ]
+  %mul = mul nuw nsw i32 %i.016, 10
+  br label %for.body3
+
+for.body3:
+  %j.015 = phi i32 [ 0, %for.body ], [ %inc, %for.body3 ]
+  %add = add nuw nsw i32 %j.015, %mul
+  %conv = sitofp i32 %add to float
+  %arrayidx4 = getelementptr inbounds [8 x float], [8 x float]* %A, i32 %i.016, i32 %j.015
+  store float %conv, float* %arrayidx4, align 4
+  %inc = add nuw nsw i32 %j.015, 1
+  %exitcond = icmp ne i32 %inc, 10
+  br i1 %exitcond, label %for.body3, label %for.inc5
+
+for.inc5:
+  %inc6 = add nuw nsw i32 %i.016, 1
+  %exitcond17 = icmp ne i32 %inc6, 10
+  br i1 %exitcond17, label %for.body, label %for.end7
+
+for.end7:
+  ret i32 0
+}
Index: test/Transforms/LoopFlatten/loop-flatten.ll
===================================================================
--- test/Transforms/LoopFlatten/loop-flatten.ll
+++ test/Transforms/LoopFlatten/loop-flatten.ll
@@ -780,5 +780,55 @@
 ; CHECK: ret void
 }
 
 
-declare i32 @func(i32)
+; CHECK-LABEL: test11
+; A[i][j]
+define i32 @test11([10 x float]* nocapture %A) local_unnamed_addr #0 {
+entry:
+  br label %for.body
+; CHECK: entry:
+; CHECK: %flatten.tripcount = mul i32 10, 10
+; CHECK: br label %for.body
+
+for.body:
+  %i.016 = phi i32 [ 0, %entry ], [ %inc6, %for.inc5 ]
+  %mul = mul nuw nsw i32 %i.016, 10
+  br label %for.body3
+; CHECK: for.body:
+; CHECK: %i.016 = phi i32 [ 0, %entry ], [ %inc6, %for.inc5 ]
+; CHECK: %mul = mul nuw nsw i32 %i.016, 10
+; CHECK: br label %for.body3
+for.body3:
+  %j.015 = phi i32 [ 0, %for.body ], [ %inc, %for.body3 ]
+  %add = add nuw nsw i32 %j.015, %mul
+  %conv = sitofp i32 %add to float
+  %arrayidx4 = getelementptr inbounds [10 x float], [10 x float]* %A, i32 %i.016, i32 %j.015
+  store float %conv, float* %arrayidx4, align 4
+  %inc = add nuw nsw i32 %j.015, 1
+  %exitcond = icmp ne i32 %inc, 10
+  br i1 %exitcond, label %for.body3, label %for.inc5
+; CHECK: for.body3:
+; CHECK: %j.015 = phi i32 [ 0, %for.body ]
+; CHECK: %add = add nuw nsw i32 %j.015, %mul
+; CHECK: %conv = sitofp i32 %i.016 to float
+; CHECK: %arrayidx4 = getelementptr inbounds [10 x float], [10 x float]* %A, i32 0, i32 %i.016
+; CHECK: store float %conv, float* %arrayidx4, align 4
+; CHECK: %inc = add nuw nsw i32 %j.015, 1
+; CHECK: %exitcond = icmp ne i32 %inc, 10
+; CHECK: br label %for.inc5
+
+for.inc5:
+  %inc6 = add nuw nsw i32 %i.016, 1
+  %exitcond17 = icmp ne i32 %inc6, 10
+  br i1 %exitcond17, label %for.body, label %for.end7
+; CHECK: for.inc5:
+; CHECK: %inc6 = add nuw nsw i32 %i.016, 1
+; CHECK: %exitcond17 = icmp ne i32 %inc6, %flatten.tripcount
+; CHECK: br i1 %exitcond17, label %for.body, label %for.end7
+
+for.end7:
+  ret i32 0
+}
+
+
+declare i32 @func(i32)
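
Note (illustrative only, not part of the patch): the new test11 functions exercise the case this change adds, where the inner induction variable is used directly as a GEP index into a fixed-size array (A[i][j]) rather than through the (OuterPHI * InnerLimit) + InnerPHI form that checkIVUsers already matched. A rough C/C++ source equivalent of the positive test, with the function name, bounds, and element type assumed from the test IR, is:

// Sketch of the source pattern behind the positive test11: a 10x10 loop
// nest whose store address is formed as A[i][j] instead of the already
// handled flattened form A[0][i*10+j].
void test11(float A[][10]) {
  for (int i = 0; i < 10; ++i)
    for (int j = 0; j < 10; ++j)
      A[i][j] = (float)(i * 10 + j);
}

The negative test in loop-flatten-negative.ll uses the same shape but with an [8 x float] array and an inner trip count of 10, so the array-size check against InnerLimit rejects it.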