diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -307,6 +307,24 @@
   }
 };
 
+// Fold SCEV expressions using conditions known to hold on entry to the loop.
+class SCEVFolder : public SCEVRewriteVisitor<SCEVFolder> {
+public:
+  ScalarEvolution &SE;
+  const Loop *CurLoop;
+  SCEVFolder(ScalarEvolution &SE, const Loop *CurLoop)
+      : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {}
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    // If the expression is guarded by CurLoop to be greater than or equal to
+    // zero, convert the sext to a zext; otherwise return it unchanged.
+    if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr,
+                                    SE.getZero(Expr->getType())))
+      return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType());
+    return Expr;
+  }
+};
+
 } // end anonymous namespace
 
 char LoopIdiomRecognizeLegacyPass::ID = 0;
 
@@ -968,12 +986,22 @@
                       << "\n");
 
     if (PositiveStrideSCEV != MemsetSizeSCEV) {
-      // TODO: folding can be done to the SCEVs
-      // The folding is to fold expressions that is covered by the loop guard
-      // at loop entry. After the folding, compare again and proceed
-      // optimization if equal.
-      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
-      return false;
+      // Fold expressions that are covered by the loop guard at loop entry,
+      // then compare again and proceed with the optimization if the folded
+      // SCEVs are equal.
+      SCEVFolder Folder(*SE, CurLoop);
+      const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV);
+      const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV);
+
+      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
+                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
+                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
+                        << "\n");
+
+      if (FoldedPositiveStride != FoldedMemsetSize) {
+        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+        return false;
+      }
     }
   }
 
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
@@ -19,6 +19,9 @@
 ; CHECK-NEXT: memset size is non-constant
 ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64))
 ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64)))
+; CHECK-NEXT: Try to fold SCEV based on loop guard
+; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64))
+; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64)))
 ; CHECK-NEXT: SCEV don't match, abort
 ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
 ; CHECK-NEXT: memset size is non-constant
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -26,23 +26,23 @@
 ;
 entry:
   %0 = shl nuw i64 %m, 2
-  br label %for.cond1.preheader
+  br label %for.header
 
-for.cond1.preheader:                              ; preds = %for.inc4, %entry
-  %i.017 = phi i64 [ 0, %entry ], [ %inc5, %for.inc4 ]
+for.header:                                       ; preds = %for.inc, %entry
+  %i.017 = phi i64 [ 0, %entry ], [ %inc5, %for.inc ]
   %1 = mul i64 %m, %i.017
   %scevgep = getelementptr i32, i32* %ar, i64 %1
   %scevgep1 = bitcast i32* %scevgep to i8*
   %mul = mul nsw i64 %i.017, %m
   call void @llvm.memset.p0i8.i64(i8* align 4 %scevgep1, i8 0, i64 %0, i1 false)
-  br label %for.inc4
+  br label %for.inc
 
-for.inc4:                                         ; preds = %for.cond1.preheader
+for.inc:                                          ; preds = %for.header
   %inc5 = add nuw nsw i64 %i.017, 1
   %exitcond18.not = icmp eq i64 %inc5, %n
-  br i1 %exitcond18.not, label %for.end6, label %for.cond1.preheader
+  br i1 %exitcond18.not, label %end, label %for.header
 
-for.end6:                                         ; preds = %for.inc4
+end:                                              ; preds = %for.inc
   ret void
 }
 
@@ -62,8 +62,8 @@
 ; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[N:%.*]], 1
 ; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[SUB]] to i64
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp sge i64 [[CONV]], 0
-; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]]
-; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    br i1 [[CMP1]], label [[FOR_PREHEADER:%.*]], label [[END:%.*]]
+; CHECK:       for.preheader:
 ; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[M:%.*]] to i64
 ; CHECK-NEXT:    [[CONV2:%.*]] = sext i32 [[M]] to i64
 ; CHECK-NEXT:    [[MUL3:%.*]] = mul i64 [[CONV2]], 4
@@ -71,24 +71,24 @@
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul i64 [[CONV1]], [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 2
 ; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP2]], i1 false)
-; CHECK-NEXT:    br label [[FOR_END]]
-; CHECK:       for.end:
+; CHECK-NEXT:    br label [[END]]
+; CHECK:       end:
 ; CHECK-NEXT:    ret void
 ;
 entry:
   %sub = sub nsw i32 %n, 1
   %conv = sext i32 %sub to i64
   %cmp1 = icmp sge i64 %conv, 0
-  br i1 %cmp1, label %for.body.lr.ph, label %for.end
+  br i1 %cmp1, label %for.preheader, label %end
 
-for.body.lr.ph:                                   ; preds = %entry
+for.preheader:                                    ; preds = %entry
   %conv1 = sext i32 %m to i64
   %conv2 = sext i32 %m to i64
   %mul3 = mul i64 %conv2, 4
   br label %for.body
 
-for.body:                                         ; preds = %for.body.lr.ph, %for.inc
-  %i.02 = phi i64 [ %conv, %for.body.lr.ph ], [ %dec, %for.inc ]
+for.body:                                         ; preds = %for.preheader, %for.inc
+  %i.02 = phi i64 [ %conv, %for.preheader ], [ %dec, %for.inc ]
   %mul = mul nsw i64 %i.02, %conv1
   %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %mul
   %0 = bitcast i32* %add.ptr to i8*
@@ -98,12 +98,119 @@
 for.inc:                                          ; preds = %for.body
   %dec = add nsw i64 %i.02, -1
   %cmp = icmp sge i64 %dec, 0
-  br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge
+  br i1 %cmp, label %for.body, label %for.exit
+
+for.exit:                                         ; preds = %for.inc
+  br label %end
+
+end:                                              ; preds = %for.exit, %entry
+  ret void
+}
+
+; The C code to generate this testcase:
+; void test(int n, int m, int o, int *ar)
+; {
+;   for (int i=0; i<n; i++)
+;     for (int j=0; j<m; j++)
+;       memset(ar+i*m*o+j*o, 0, o*sizeof(int));
+; }
+; The memset of the inner loop is first promoted, and when loop-idiom then processes the
+; outer loop, MemsetSizeSCEV contains a zext of %m while PositiveStrideSCEV contains a
+; sext of %m. The loop guard `0 < m` guarantees m >= 0
+; inside the loop, so m can be converted from sext to zext, making the two SCEVs equal.
+; Below is the debug log of LoopIdiomRecognize.
+; loop-idiom Scanning: F[NestedFor] Countable Loop %for.body3.us
+; memset size is non-constant
+; MemsetSizeSCEV: (4 * (sext i32 %o to i64))
+; PositiveStrideSCEV: (4 * (sext i32 %o to i64))
+; Formed memset: call void @llvm.memset.p0i8.i64(i8* align 4 %scevgep1, i8 0, i64 %6, i1 false)
+; loop-idiom Scanning: F[NestedFor] Countable Loop %for.body.us
+; memset size is non-constant
+; MemsetSizeSCEV: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; PositiveStrideSCEV: (4 * (sext i32 %m to i64) * (sext i32 %o to i64))
+; Try to fold SCEV based on loop guard
+; FoldedMemsetSize: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; FoldedPositiveStride: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; Formed memset: call void @llvm.memset.p0i8.i64(i8* align 4 %ar2, i8 0, i64 %8, i1 false)
+define void @NestedFor(i32 %n, i32 %m, i32 %o, i32* %ar) {
+; CHECK-LABEL: @NestedFor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8*
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]]
+; CHECK-NEXT:    br i1 [[CMP3]], label [[FOR_PREHEADER:%.*]], label [[FOR_END:%.*]]
+; CHECK:       for.preheader:
+; CHECK-NEXT:    [[CMP21:%.*]] = icmp slt i32 0, [[M:%.*]]
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[O:%.*]] to i64
+; CHECK-NEXT:    [[MUL8:%.*]] = mul i64 [[CONV]], 4
+; CHECK-NEXT:    br i1 [[CMP21]], label [[FOR_I_ENTERING:%.*]], label [[FOR_END]]
+; CHECK:       for.i.entering:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[O]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[O]] to i64
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT10:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[M]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP0]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 2
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i64 [[TMP5]], [[WIDE_TRIP_COUNT10]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP7]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP8]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END]]
+; CHECK:       for.end:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp3 = icmp slt i32 0, %n
+  br i1 %cmp3, label %for.preheader, label %for.end
+
+for.preheader:                                    ; preds = %entry
+  %cmp21 = icmp slt i32 0, %m
+  %conv = sext i32 %o to i64
+  %mul8 = mul i64 %conv, 4
+  br i1 %cmp21, label %for.i.entering, label %for.end
+
+for.i.entering:                                   ; preds = %for.preheader
+  %0 = sext i32 %o to i64
+  %1 = sext i32 %m to i64
+  %2 = sext i32 %o to i64
+  %wide.trip.count10 = zext i32 %n to i64
+  br label %for.i.header
+
+for.i.header:                                     ; preds = %for.i.inc, %for.i.entering
+  %indvars.iv6 = phi i64 [ %indvars.iv.next7, %for.i.inc ], [ 0, %for.i.entering ]
+  br label %for.j.entering
+
+for.j.body:                                       ; preds = %for.j.entering, %for.j.inc
+  %indvars.iv = phi i64 [ 0, %for.j.entering ], [ %indvars.iv.next, %for.j.inc ]
+  %3 = mul nsw i64 %indvars.iv, %0
+  %add.ptr7.us = getelementptr inbounds i32, i32* %add.ptr.us, i64 %3
+  %4 = bitcast i32* %add.ptr7.us to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %4, i8 0, i64 %mul8, i1 false)
+  br label %for.j.inc
+
+for.j.entering:                                   ; preds = %for.i.header
+  %5 = mul nsw i64 %indvars.iv6, %1
+  %6 = mul nsw i64 %5, %2
+  %add.ptr.us = getelementptr inbounds i32, i32* %ar, i64 %6
+  %wide.trip.count = zext i32 %m to i64
+  br label %for.j.body
+
+for.j.inc:                                        ; preds = %for.j.body
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.j.body, label %for.i.inc
 
-for.cond.for.end_crit_edge:                       ; preds = %for.inc
-  br label %for.end
+for.i.inc:                                        ; preds = %for.j.inc
+  %indvars.iv.next7 = add nuw nsw i64 %indvars.iv6, 1
+  %exitcond11 = icmp ne i64 %indvars.iv.next7, %wide.trip.count10
+  br i1 %exitcond11, label %for.i.header, label %for.end
 
-for.end:                                          ; preds = %for.cond.for.end_crit_edge, %entry
+for.end:                                          ; preds = %entry, %for.preheader, %for.i.inc
   ret void
 }
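
For illustration, here is a reduced, hypothetical C example of the pattern the new fold targets; it is not one of the test cases in this patch, and whether loop-idiom actually forms a memset for it depends on how the front end extends `m` in the size and stride expressions. The point is only that an `m >= 0` check dominating the loop entry is the kind of condition isLoopEntryGuardedByCond() recognizes, and under that guard the sign extension and zero extension of m denote the same value, so the folder may rewrite sext to zext before comparing the SCEVs.

#include <string.h>

/* Hypothetical example: the loop body only runs when m >= 0, so on every path
 * that reaches the memset the sign bit of m is known to be zero, and the sign
 * extension of m is interchangeable with its zero extension when the memset
 * size and pointer stride SCEVs are compared.                               */
void guarded_memset(int *ar, long n, int m) {
  if (m >= 0)
    for (long i = 0; i < n; i++)
      memset(ar + i * m, 0, (size_t)m * sizeof(int));
}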