diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -307,6 +307,24 @@
   }
 };
 
+// The Folder will fold expressions that are guarded by the loop entry.
+// NOTE(review): SCEVRewriteVisitor is a CRTP template; the derived class must
+// be passed as the template argument (it was stripped in the mangled paste).
+class SCEVFolder : public SCEVRewriteVisitor<SCEVFolder> {
+public:
+  ScalarEvolution &SE;
+  const Loop *CurLoop;
+  SCEVFolder(ScalarEvolution &SE, const Loop *CurLoop)
+      : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {}
+
+  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+    // If expression is guarded by CurLoop to be greater or equal to zero
+    // then convert sext to zext. Otherwise return the original expression.
+    if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr,
+                                    SE.getZero(Expr->getType())))
+      return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType());
+    return Expr;
+  }
+};
+
 } // end anonymous namespace
 
 char LoopIdiomRecognizeLegacyPass::ID = 0;
@@ -968,12 +986,22 @@
                       << "\n");
 
     if (PositiveStrideSCEV != MemsetSizeSCEV) {
-      // TODO: folding can be done to the SCEVs
-      // The folding is to fold expressions that is covered by the loop guard
-      // at loop entry. After the folding, compare again and proceed
-      // optimization if equal.
-      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
-      return false;
+      // The folding is to fold an expression that is covered by the loop guard
+      // at loop entry. After the folding, compare again and proceed with
+      // optimization, if equal.
+      SCEVFolder Folder(*SE, CurLoop);
+      const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV);
+      const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV);
+
+      LLVM_DEBUG(dbgs() << "  Try to fold SCEV based on loop guard\n"
+                        << "    FoldedMemsetSize: " << *FoldedMemsetSize << "\n"
+                        << "    FoldedPositiveStride: " << *FoldedPositiveStride
+                        << "\n");
+
+      if (FoldedPositiveStride != FoldedMemsetSize) {
+        LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+        return false;
+      }
     }
   }
 
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
@@ -19,6 +19,9 @@
 ; CHECK-NEXT: memset size is non-constant
 ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64))
 ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64)))
+; CHECK-NEXT: Try to fold SCEV based on loop guard
+; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64))
+; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64)))
 ; CHECK-NEXT: SCEV don't match, abort
 ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader
 ; CHECK-NEXT: memset size is non-constant
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
--- a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -107,4 +107,129 @@
   ret void
 }
 
+; The C code to generate this testcase:
+; void test(int n, int m, int o, int *ar)
+; {
+;   for (int i=0; i<n; i++) {
+;     int *arr = ar + i*m*o;
+;     for (int j=0; j<m; j++) {
+;       int *slice = arr + j*o;
+;       memset(slice, 0, o*sizeof(int));
+;     }
+;   }
+; }
+; The optimization relies on the loop guard to know that m >= 0
+; inside the loop, so m can be converted from sext to zext, making the two SCEV-s equal.
+; Below are the debug log of LoopIdiomRecognize.
+; loop-idiom Scanning: F[NestedFor] Countable Loop %for.body3.us
+; memset size is non-constant
+; MemsetSizeSCEV: (4 * (sext i32 %o to i64))
+; PositiveStrideSCEV: (4 * (sext i32 %o to i64))
+; Formed memset: call void @llvm.memset.p0i8.i64(i8* align 4 %scevgep1, i8 0, i64 %6, i1 false)
+; loop-idiom Scanning: F[NestedFor] Countable Loop %for.body.us
+; memset size is non-constant
+; MemsetSizeSCEV: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; PositiveStrideSCEV: (4 * (sext i32 %m to i64) * (sext i32 %o to i64))
+; Try to fold SCEV based on loop guard
+; FoldedMemsetSize: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; FoldedPositiveStride: (4 * (zext i32 %m to i64) * (sext i32 %o to i64))
+; Formed memset: call void @llvm.memset.p0i8.i64(i8* align 4 %ar2, i8 0, i64 %8, i1 false)
+define void @NestedFor(i32 %n, i32 %m, i32 %o, i32* %ar) {
+; CHECK-LABEL: @NestedFor(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8*
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]]
+; CHECK-NEXT:    br i1 [[CMP3]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END11:%.*]]
+; CHECK:       for.body.lr.ph:
+; CHECK-NEXT:    [[CMP21:%.*]] = icmp slt i32 0, [[M:%.*]]
+; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[O:%.*]] to i64
+; CHECK-NEXT:    [[MUL8:%.*]] = mul i64 [[CONV]], 4
+; CHECK-NEXT:    br i1 [[CMP21]], label [[FOR_BODY_LR_PH_SPLIT_US:%.*]], label [[FOR_END11]]
+; CHECK:       for.body.lr.ph.split.us:
+; CHECK-NEXT:    [[TMP0:%.*]] = sext i32 [[O]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = sext i32 [[M]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = sext i32 [[O]] to i64
+; CHECK-NEXT:    [[WIDE_TRIP_COUNT10:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = zext i32 [[M]] to i64
+; CHECK-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP0]], [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i64 [[TMP5]], 2
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i64 [[TMP5]], [[WIDE_TRIP_COUNT10]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP7]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP8]], i1 false)
+; CHECK-NEXT:    br label [[FOR_END11]]
+; CHECK:       for.end11:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %cmp3 = icmp slt i32 0, %n
+  br i1 %cmp3, label %for.body.lr.ph, label %for.end11
+
+for.body.lr.ph:                                   ; preds = %entry
+  %cmp21 = icmp slt i32 0, %m
+  %conv = sext i32 %o to i64
+  %mul8 = mul i64 %conv, 4
+  br i1 %cmp21, label %for.body.lr.ph.split.us, label %for.body.lr.ph.split
+
+for.body.lr.ph.split.us:                          ; preds = %for.body.lr.ph
+  %0 = sext i32 %o to i64
+  %1 = sext i32 %m to i64
+  %2 = sext i32 %o to i64
+  %wide.trip.count10 = zext i32 %n to i64
+  br label %for.body.us
+
+for.body.us:                                      ; preds = %for.inc9.us, %for.body.lr.ph.split.us
+  %indvars.iv6 = phi i64 [ %indvars.iv.next7, %for.inc9.us ], [ 0, %for.body.lr.ph.split.us ]
+  br label %for.body3.lr.ph.us
+
+for.end.us:                                       ; preds = %for.cond1.for.end_crit_edge.us
+  br label %for.inc9.us
+
+for.inc9.us:                                      ; preds = %for.end.us
+  %indvars.iv.next7 = add nuw nsw i64 %indvars.iv6, 1
+  %exitcond11 = icmp ne i64 %indvars.iv.next7, %wide.trip.count10
+  br i1 %exitcond11, label %for.body.us, label %for.cond.for.end11_crit_edge.split.us
+
+for.body3.us:                                     ; preds = %for.body3.lr.ph.us, %for.inc.us
+  %indvars.iv = phi i64 [ 0, %for.body3.lr.ph.us ], [ %indvars.iv.next, %for.inc.us ]
+  %3 = mul nsw i64 %indvars.iv, %0
+  %add.ptr7.us = getelementptr inbounds i32, i32* %add.ptr.us, i64 %3
+  %4 = bitcast i32* %add.ptr7.us to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %4, i8 0, i64 %mul8, i1 false)
+  br label %for.inc.us
+
+for.inc.us:                                       ; preds = %for.body3.us
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %for.body3.us, label %for.cond1.for.end_crit_edge.us
+
+for.body3.lr.ph.us:                               ; preds = %for.body.us
+  %5 = mul nsw i64 %indvars.iv6, %1
+  %6 = mul nsw i64 %5, %2
+  %add.ptr.us = getelementptr inbounds i32, i32* %ar, i64 %6
+  %wide.trip.count = zext i32 %m to i64
+  br label %for.body3.us
+
+for.cond1.for.end_crit_edge.us:                   ; preds = %for.inc.us
+  br label %for.end.us
+
+for.cond.for.end11_crit_edge.split.us:            ; preds = %for.inc9.us
+  br label %for.cond.for.end11_crit_edge
+
+for.body.lr.ph.split:                             ; preds = %for.body.lr.ph
+  br label %for.cond.for.end11_crit_edge.split
+
+for.cond.for.end11_crit_edge.split:               ; preds = %for.body.lr.ph.split
+  br label %for.cond.for.end11_crit_edge
+
+for.cond.for.end11_crit_edge:                     ; preds = %for.cond.for.end11_crit_edge.split.us, %for.cond.for.end11_crit_edge.split
+  br label %for.end11
+
+for.end11:                                        ; preds = %for.cond.for.end11_crit_edge, %entry
+  ret void
+}
+
 declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)