diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -307,6 +307,25 @@ } }; +// Rewrites sext into zext when the loop-entry guard proves the value is >= 0. +class SCEVSignToZeroExtensionRewriter + : public SCEVRewriteVisitor<SCEVSignToZeroExtensionRewriter> { +public: + ScalarEvolution &SE; + const Loop *CurLoop; + SCEVSignToZeroExtensionRewriter(ScalarEvolution &SE, const Loop *CurLoop) + : SCEVRewriteVisitor(SE), SE(SE), CurLoop(CurLoop) {} + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + // If the entry guard of CurLoop proves Expr >= 0, sext and zext agree, so + // rewrite to zext (folding the operand too); else keep the original Expr. + if (SE.isLoopEntryGuardedByCond(CurLoop, ICmpInst::ICMP_SGE, Expr, + SE.getZero(Expr->getType()))) + return SE.getZeroExtendExpr(visit(Expr->getOperand()), Expr->getType()); + return Expr; + } +}; + } // end anonymous namespace char LoopIdiomRecognizeLegacyPass::ID = 0; @@ -967,12 +986,22 @@ << "\n"); if (PositiveStrideSCEV != MemsetSizeSCEV) { - // TODO: folding can be done to the SCEVs - // The folding is to fold expressions that is covered by the loop guard - // at loop entry. After the folding, compare again and proceed - // optimization if equal. - LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); - return false; + // Fold the expressions that are covered by the loop guard at loop entry + // (sext -> zext when provably non-negative), then compare again and + // proceed with the optimization only if the folded SCEVs are equal.
+ SCEVSignToZeroExtensionRewriter Folder(*SE, CurLoop); + const SCEV *FoldedPositiveStride = Folder.visit(PositiveStrideSCEV); + const SCEV *FoldedMemsetSize = Folder.visit(MemsetSizeSCEV); + + LLVM_DEBUG(dbgs() << " Try to fold SCEV based on loop guard\n" + << " FoldedMemsetSize: " << *FoldedMemsetSize << "\n" + << " FoldedPositiveStride: " << *FoldedPositiveStride + << "\n"); + + if (FoldedPositiveStride != FoldedMemsetSize) { + LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); + return false; + } } } diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-32bit.ll @@ -369,4 +369,52 @@ ret void } +; void NegStart(int n, int m, int *ar) { +; for (int i = -100; i < n; i++) { +; int *arr = ar + (i + 100) * m; +; memset(arr, 0, m * sizeof(int)); +; } +; } +define void @NegStart(i32 %n, i32 %m, i32* %ar) { +; CHECK-LABEL: @NegStart( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8* +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]] +; CHECK-NEXT: br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: [[MUL1:%.*]] = mul i32 [[M:%.*]], 4 +; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[N]], 100 +; CHECK-NEXT: [[TMP1:%.*]] = mul i32 [[M]], [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = shl i32 [[TMP1]], 2 +; CHECK-NEXT: call void @llvm.memset.p0i8.i32(i8* align 4 [[AR1]], i8 0, i32 [[TMP2]], i1 false) +; CHECK-NEXT: br label [[FOR_END]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + %cmp1 = icmp slt i32 -100, %n + br i1 %cmp1, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %mul1 = mul i32 %m, 4 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %for.body + %i.02 = phi i32 [ -100, %for.body.lr.ph ], [ %inc, %for.body ] + %add = add nsw i32 %i.02, 100 + %mul = mul nsw 
i32 %add, %m + %add.ptr = getelementptr inbounds i32, i32* %ar, i32 %mul + %0 = bitcast i32* %add.ptr to i8* + call void @llvm.memset.p0i8.i32(i8* align 4 %0, i8 0, i32 %mul1, i1 false) + %inc = add nsw i32 %i.02, 1 + %exitcond = icmp ne i32 %inc, %n + br i1 %exitcond, label %for.body, label %for.end.loopexit + +for.end.loopexit: ; preds = %for.body + br label %for.end + +for.end: ; preds = %for.end.loopexit, %entry + ret void +} + declare void @llvm.memset.p0i8.i32(i8* nocapture writeonly, i8, i32, i1 immarg) diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll --- a/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-64bit.ll @@ -268,6 +268,12 @@ for.end: ; preds = %for.body, %entry ret void } +; This case requires the sext-to-zext SCEV folding in LoopIdiomRecognize.cpp to +; fold SCEVs prior to comparison. The inner loop does not need that folding, but +; the promoted memset size is based on the inner loop's trip count (an unsigned +; value). In the outer loop, the pointer stride SCEV of the memset must then be +; folded using the loop guard to become equal to the memset size SCEV. The guard +; guarantees m >= 0 inside the loop, so sext(m) can be rewritten as zext(m).
; void NestedFor32(int *ar, int n, int m, int o) ; { ; int i, j; @@ -281,6 +287,7 @@ define void @NestedFor32(i32* %ar, i32 %n, i32 %m, i32 %o) { ; CHECK-LABEL: @NestedFor32( ; CHECK-NEXT: entry: +; CHECK-NEXT: [[AR2:%.*]] = bitcast i32* [[AR:%.*]] to i8* ; CHECK-NEXT: [[CMP3:%.*]] = icmp slt i32 0, [[N:%.*]] ; CHECK-NEXT: br i1 [[CMP3]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END11:%.*]] ; CHECK: for.body.lr.ph: @@ -296,17 +303,10 @@ ; CHECK-NEXT: [[TMP3:%.*]] = zext i32 [[M]] to i64 ; CHECK-NEXT: [[TMP4:%.*]] = mul i64 [[TMP0]], [[TMP3]] ; CHECK-NEXT: [[TMP5:%.*]] = shl i64 [[TMP4]], 2 -; CHECK-NEXT: br label [[FOR_BODY_US:%.*]] -; CHECK: for.body.us: -; CHECK-NEXT: [[INDVARS_IV6:%.*]] = phi i64 [ 0, [[FOR_BODY_US_PREHEADER]] ], [ [[INDVARS_IV_NEXT7:%.*]], [[FOR_BODY_US]] ] -; CHECK-NEXT: [[TMP6:%.*]] = mul i64 [[TMP2]], [[INDVARS_IV6]] -; CHECK-NEXT: [[SCEVGEP:%.*]] = getelementptr i32, i32* [[AR:%.*]], i64 [[TMP6]] -; CHECK-NEXT: [[SCEVGEP1:%.*]] = bitcast i32* [[SCEVGEP]] to i8* -; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[M]] to i64 -; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[SCEVGEP1]], i8 0, i64 [[TMP5]], i1 false) -; CHECK-NEXT: [[INDVARS_IV_NEXT7]] = add nuw nsw i64 [[INDVARS_IV6]], 1 -; CHECK-NEXT: [[EXITCOND11:%.*]] = icmp ne i64 [[INDVARS_IV_NEXT7]], [[WIDE_TRIP_COUNT10]] -; CHECK-NEXT: br i1 [[EXITCOND11]], label [[FOR_BODY_US]], label [[FOR_END11]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i64 [[TMP4]], [[WIDE_TRIP_COUNT10]] +; CHECK-NEXT: [[TMP7:%.*]] = shl i64 [[TMP6]], 2 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[AR2]], i8 0, i64 [[TMP7]], i1 false) +; CHECK-NEXT: br label [[FOR_END11]] ; CHECK: for.end11: ; CHECK-NEXT: ret void ; @@ -357,4 +357,58 @@ ret void } +; void NegStart(int n, int m, int *ar) { +; for (int i = -100; i < n; i++) { +; int *arr = ar + (i + 100) * m; +; memset(arr, 0, m * sizeof(int)); +; } +; } +define void @NegStart(i32 %n, i32 %m, i32* %ar) { +; CHECK-LABEL: @NegStart( +; CHECK-NEXT: 
entry: +; CHECK-NEXT: [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8* +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i32 -100, [[N:%.*]] +; CHECK-NEXT: br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[M:%.*]] to i64 +; CHECK-NEXT: [[MUL1:%.*]] = mul i64 [[CONV]], 4 +; CHECK-NEXT: [[TMP0:%.*]] = sext i32 [[M]] to i64 +; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = sext i32 [[N]] to i64 +; CHECK-NEXT: [[TMP1:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], 100 +; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], [[TMP0]] +; CHECK-NEXT: [[TMP3:%.*]] = shl i64 [[TMP2]], 2 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP3]], i1 false) +; CHECK-NEXT: br label [[FOR_END]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + %cmp1 = icmp slt i32 -100, %n + br i1 %cmp1, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %conv = sext i32 %m to i64 + %mul1 = mul i64 %conv, 4 + %0 = sext i32 %m to i64 + %wide.trip.count = sext i32 %n to i64 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %for.body + %indvars.iv = phi i64 [ -100, %for.body.lr.ph ], [ %indvars.iv.next, %for.body ] + %1 = add nsw i64 %indvars.iv, 100 + %2 = mul nsw i64 %1, %0 + %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %2 + %3 = bitcast i32* %add.ptr to i8* + call void @llvm.memset.p0i8.i64(i8* align 4 %3, i8 0, i64 %mul1, i1 false) + %indvars.iv.next = add nsw i64 %indvars.iv, 1 + %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count + br i1 %exitcond, label %for.body, label %for.end.loopexit + +for.end.loopexit: ; preds = %for.body + br label %for.end + +for.end: ; preds = %for.end.loopexit, %entry + ret void +} + declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll --- 
a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll @@ -19,6 +19,9 @@ ; CHECK-NEXT: memset size is non-constant ; CHECK-NEXT: MemsetSizeSCEV: (4 * (sext i32 %m to i64)) ; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))) +; CHECK-NEXT: Try to fold SCEV based on loop guard +; CHECK-NEXT: FoldedMemsetSize: (4 * (sext i32 %m to i64)) +; CHECK-NEXT: FoldedPositiveStride: (4 + (4 * (sext i32 %m to i64))) ; CHECK-NEXT: SCEV don't match, abort ; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader ; CHECK-NEXT: memset size is non-constant