Index: llvm/lib/Transforms/IPO/PassManagerBuilder.cpp =================================================================== --- llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -441,6 +441,10 @@ MPM.add(createCFGSimplificationPass()); MPM.add(createInstructionCombiningPass()); // We resume loop passes creating a second loop pipeline here. + if (EnableLoopFlatten) { + MPM.add(createLoopFlattenPass()); // Flatten loops + MPM.add(createLoopSimplifyCFGPass()); + } MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. addExtensionsToPM(EP_LateLoopOptimizations, MPM); @@ -448,10 +452,6 @@ if (EnableLoopInterchange) MPM.add(createLoopInterchangePass()); // Interchange loops - if (EnableLoopFlatten) { - MPM.add(createLoopFlattenPass()); // Flatten loops - MPM.add(createLoopSimplifyCFGPass()); - } // Unroll small loops MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, @@ -1039,12 +1039,12 @@ PM.add(createMergedLoadStoreMotionPass()); // Merge ld/st in diamonds. // More loops are countable; try to optimize them. + if (EnableLoopFlatten) + PM.add(createLoopFlattenPass()); PM.add(createIndVarSimplifyPass()); PM.add(createLoopDeletionPass()); if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); - if (EnableLoopFlatten) - PM.add(createLoopFlattenPass()); // Unroll small loops PM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp =================================================================== --- llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -382,14 +382,43 @@ for (Value *V : LinearIVUses) { for (Value *U : V->users()) { + // + // For the overflow check, we first calculate an 'effective' GEP + // bitwidth. 
I.e., we look through a ZExt instruction which widens the + // induction to a larger type, but only if it is used as an operand of a + // GEP instruction. This means that we match this pattern: + // + // add.us = add i32 %j, %mul + // %idxprom = zext i32 %add.us to i64 + // %arrayidx.us = getelementptr inbounds i32, i32* %p, i64 %idxprom + // + // The zext instruction here doesn't change values, thus we consider + // %idxprom to be an i32 value, which helps in determining whether the + // GEP will overflow. Promotion of address calculation to the widest + // available integer happens on e.g. 64-bit targets when 32-bit + // loop iterators are used as array indices, so this helps for these + // cases. + // + unsigned GEPBitwidth = 0; + if (auto *ZExt = dyn_cast(U)) { + for (auto *ZExtUser : ZExt->users()) + if (!dyn_cast(ZExtUser)) + return OverflowResult::MayOverflow; + + GEPBitwidth = V->getType()->getScalarSizeInBits(); + U = *ZExt->user_begin(); + } + if (auto *GEP = dyn_cast(U)) { // The IV is used as the operand of a GEP, and the IV is at least as // wide as the address space of the GEP. In this case, the GEP would // wrap around the address space before the IV increment wraps, which // would be UB. + if (!GEPBitwidth) + GEPBitwidth = DL.getPointerTypeSizeInBits(GEP->getType()); + if (GEP->isInBounds() && - V->getType()->getIntegerBitWidth() >= - DL.getPointerTypeSizeInBits(GEP->getType())) { + V->getType()->getIntegerBitWidth() >= GEPBitwidth) { LLVM_DEBUG( dbgs() << "use of linear IV would be UB if overflow occurred: "; GEP->dump()); @@ -443,10 +472,6 @@ OuterInductionPHI, TTI)) return false; - // FIXME: it should be possible to handle different types correctly. 
- if (InnerInductionPHI->getType() != OuterInductionPHI->getType()) - return false; - if (!checkOuterLoopInsts(OuterLoop, InnerLoop, IterationInstructions, InnerLimit, OuterInductionPHI, TTI)) return false; Index: llvm/test/Transforms/LoopFlatten/zext-i64.ll =================================================================== --- llvm/test/Transforms/LoopFlatten/zext-i64.ll +++ llvm/test/Transforms/LoopFlatten/zext-i64.ll @@ -10,27 +10,28 @@ ; CHECK-NEXT: br i1 [[CMP26_NOT]], label [[FOR_END12:%.*]], label [[FOR_COND1_PREHEADER_LR_PH:%.*]] ; CHECK: for.cond1.preheader.lr.ph: ; CHECK-NEXT: [[CONV4:%.*]] = sext i16 [[VAL:%.*]] to i32 +; CHECK-NEXT: [[FLATTEN_TRIPCOUNT:%.*]] = mul i32 [[N]], [[N]] ; CHECK-NEXT: br label [[FOR_COND1_PREHEADER_US:%.*]] ; CHECK: for.cond1.preheader.us: ; CHECK-NEXT: [[I_027_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_LR_PH]] ], [ [[INC11_US:%.*]], [[FOR_COND1_FOR_INC10_CRIT_EDGE_US:%.*]] ] ; CHECK-NEXT: [[MUL_US:%.*]] = mul i32 [[I_027_US]], [[N]] ; CHECK-NEXT: br label [[FOR_BODY3_US:%.*]] ; CHECK: for.body3.us: -; CHECK-NEXT: [[J_025_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_US]] ], [ [[INC_US:%.*]], [[FOR_BODY3_US]] ] +; CHECK-NEXT: [[J_025_US:%.*]] = phi i32 [ 0, [[FOR_COND1_PREHEADER_US]] ] ; CHECK-NEXT: [[ADD_US:%.*]] = add i32 [[J_025_US]], [[MUL_US]] -; CHECK-NEXT: [[IDXPROM_US:%.*]] = zext i32 [[ADD_US]] to i64 +; CHECK-NEXT: [[IDXPROM_US:%.*]] = zext i32 [[I_027_US]] to i64 ; CHECK-NEXT: [[ARRAYIDX_US:%.*]] = getelementptr inbounds i16, i16* [[A:%.*]], i64 [[IDXPROM_US]] ; CHECK-NEXT: [[TMP0:%.*]] = load i16, i16* [[ARRAYIDX_US]], align 2 ; CHECK-NEXT: [[CONV_US:%.*]] = sext i16 [[TMP0]] to i32 ; CHECK-NEXT: [[MUL5_US:%.*]] = mul nsw i32 [[CONV_US]], [[CONV4]] ; CHECK-NEXT: [[ARRAYIDX9_US:%.*]] = getelementptr inbounds i32, i32* [[C:%.*]], i64 [[IDXPROM_US]] ; CHECK-NEXT: store i32 [[MUL5_US]], i32* [[ARRAYIDX9_US]], align 4 -; CHECK-NEXT: [[INC_US]] = add nuw i32 [[J_025_US]], 1 +; CHECK-NEXT: [[INC_US:%.*]] = add nuw i32 
[[J_025_US]], 1 ; CHECK-NEXT: [[CMP2_US:%.*]] = icmp ult i32 [[INC_US]], [[N]] -; CHECK-NEXT: br i1 [[CMP2_US]], label [[FOR_BODY3_US]], label [[FOR_COND1_FOR_INC10_CRIT_EDGE_US]] +; CHECK-NEXT: br label [[FOR_COND1_FOR_INC10_CRIT_EDGE_US]] ; CHECK: for.cond1.for.inc10_crit_edge.us: ; CHECK-NEXT: [[INC11_US]] = add i32 [[I_027_US]], 1 -; CHECK-NEXT: [[CMP_US:%.*]] = icmp ult i32 [[INC11_US]], [[N]] +; CHECK-NEXT: [[CMP_US:%.*]] = icmp ult i32 [[INC11_US]], [[FLATTEN_TRIPCOUNT]] ; CHECK-NEXT: br i1 [[CMP_US]], label [[FOR_COND1_PREHEADER_US]], label [[FOR_END12_LOOPEXIT:%.*]] ; CHECK: for.end12.loopexit: ; CHECK-NEXT: br label [[FOR_END12]]