diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -897,7 +897,7 @@
 bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
                                            const SCEV *BECount) {
   // We can only handle non-volatile memsets with a constant size.
-  if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength()))
+  if (MSI->isVolatile())
     return false;
 
   // If we're not allowed to hack on memset, we fail.
@@ -913,20 +913,62 @@
   if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine())
     return false;
 
-  // Reject memsets that are so large that they overflow an unsigned.
-  uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-  if ((SizeInBytes >> 32) != 0)
+  const SCEV *StrideSCEV = Ev->getOperand(1);
+  const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength());
+  if (!StrideSCEV || !MemsetSizeSCEV)
     return false;
 
-  // Check to see if the stride matches the size of the memset.  If so, then we
-  // know that every byte is touched in the loop.
-  const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-  if (!ConstStride)
-    return false;
+  bool IsNegStride;
+  const bool IsConstantSize = isa<ConstantInt>(MSI->getLength());
 
-  APInt Stride = ConstStride->getAPInt();
-  if (SizeInBytes != Stride && SizeInBytes != -Stride)
-    return false;
+  if (IsConstantSize) {
+    // Memset size is constant
+    // Reject memsets that are so large that they overflow an unsigned.
+    LLVM_DEBUG(dbgs() << "  memset size is constant\n");
+    uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
+    if ((SizeInBytes >> 32) != 0)
+      return false;
+
+    // Check to see if the stride matches the size of the memset.  If so, then
+    // we know that every byte is touched in the loop.
+    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
+    if (!ConstStride)
+      return false;
+
+    APInt Stride = ConstStride->getAPInt();
+    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+      return false;
+
+    IsNegStride = SizeInBytes == -Stride;
+  } else {
+    // Memset size is non-constant
+    // Check if the pointer stride matches the memset size.
+    // Also, the pass only handles memset length and stride that are invariant
+    // for the top level loop.
+    // If the original StrideSCEV and MemsetSizeSCEV do not match, the pass
+    // will fold expressions that are covered by the loop guard at loop entry.
+    // The pass will compare again after the folding and proceed if equal.
+    LLVM_DEBUG(dbgs() << "  memset size is non-constant\n");
+    if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) {
+      LLVM_DEBUG(dbgs() << "  memset size is not a loop-invariant, "
+                        << "abort\n");
+      return false;
+    }
+
+    // Compare positive direction StrideSCEV with MemsetSizeSCEV.
+    IsNegStride = StrideSCEV->isNonConstantNegative();
+    const SCEV *PositiveStrideSCEV =
+        IsNegStride ? SE->getNegativeSCEV(StrideSCEV) : StrideSCEV;
+    LLVM_DEBUG(dbgs() << "  MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n"
+                      << "  PositiveStrideSCEV: " << *PositiveStrideSCEV
+                      << "\n");
+
+    if (PositiveStrideSCEV != MemsetSizeSCEV) {
+      // TODO: folding can be done to the SCEVs
+      LLVM_DEBUG(dbgs() << "  SCEV don't match, abort\n");
+      return false;
+    }
+  }
 
   // Verify that the memset value is loop invariant.  If not, we can't promote
   // the memset.
@@ -936,7 +978,6 @@
   SmallPtrSet<MemSetInst *, 1> MSIs;
   MSIs.insert(MSI);
-  bool IsNegStride = SizeInBytes == -Stride;
   return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()),
                                  MaybeAlign(MSI->getDestAlignment()),
                                  SplatValue, MSI, MSIs, Ev, BECount,
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll
@@ -0,0 +1,87 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: -p --function-signature --check-attributes --check-globals --include-generated-funcs --force-update
+; REQUIRES: asserts
+; RUN: opt < %s -S -debug -passes=loop-idiom -debug-only=loop-idiom -disable-output 2>&1 | FileCheck %s
+; The C code to generate this testcase:
+; void test(int *ar, int n, int m)
+; {
+;   long i;
+;   for (i=0; i<n; ++i) {
+;     int *arr = ar + i * m + i;
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
+define void @test(i32* %0, i32 %1, i32 %2) {
+; CHECK: memset size is non-constant
+; CHECK: MemsetSizeSCEV: (4 * (sext i32 %2 to i64))
+; CHECK: PositiveStrideSCEV: (4 + (4 * (sext i32 %2 to i64)))
+; CHECK: SCEV don't match, abort
+  %4 = icmp slt i32 0, %1
+  br i1 %4, label %.lr.ph, label %14
+
+.lr.ph:                                           ; preds = %3
+  %5 = sext i32 %2 to i64
+  %6 = mul i64 %5, 4
+  %7 = sext i32 %2 to i64
+  %wide.trip.count = zext i32 %1 to i64
+  br label %8
+
+8:                                                ; preds = %.lr.ph, %13
+  %indvars.iv = phi i64 [ 0, %.lr.ph ], [ %indvars.iv.next, %13 ]
+  %9 = mul nsw i64 %indvars.iv, %7
+  %10 = getelementptr inbounds i32, i32* %0, i64 %9
+  %11 = getelementptr inbounds i32, i32* %10, i64 %indvars.iv
+  %12 = bitcast i32* %11 to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %12, i8 0, i64 %6, i1 false)
+  br label %13
+
+13:                                               ; preds = %8
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond = icmp ne i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond, label %8, label %._crit_edge
+
+._crit_edge:                                      ; preds = %13
+  br label %14
+
+14:                                               ; preds = %._crit_edge, %3
+  ret void
+}
+
+; Function Attrs: argmemonly nofree nounwind willreturn writeonly
+declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)
diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopIdiom/memset-runtime.ll
@@ -0,0 +1,111 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; REQUIRES: asserts
+; RUN: opt -passes="function(loop(loop-idiom,loop-deletion),simplifycfg)" -S < %s | FileCheck %s
+; The C code to generate this testcase:
+; void test(int n, int m, int *ar)
+; {
+;   long i;
+;   for (i=n-1; i>=0; i--) {
+;     int *arr = ar + i * m; // ar[i];
+;     memset(arr, 0, m * sizeof(int));
+;   }
+; }
+define void @For_NegativeStride(i32* %0, i32 %1, i32 %2) {
+; CHECK-LABEL: @For_NegativeStride(
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i32* [[TMP0:%.*]] to i8*
+; CHECK-NEXT:    [[TMP5:%.*]] = sub nsw i32 [[TMP1:%.*]], 1
+; CHECK-NEXT:    [[TMP6:%.*]] = icmp sge i32 [[TMP5]], 0
+; CHECK-NEXT:    br i1 [[TMP6]], label [[DOTLR_PH:%.*]], label [[TMP15:%.*]]
+; CHECK:       .lr.ph:
+; CHECK-NEXT:    [[TMP7:%.*]] = sext i32 [[TMP2:%.*]] to i64
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i64 [[TMP7]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = add i32 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP10:%.*]] = sext i32 [[TMP9]] to i64
+; CHECK-NEXT:    [[TMP11:%.*]] = sext i32 [[TMP2]] to i64
+; CHECK-NEXT:    [[TMP12:%.*]] = add nuw nsw i64 [[TMP10]], 1
+; CHECK-NEXT:    [[TMP13:%.*]] = mul i64 [[TMP12]], [[TMP11]]
+; CHECK-NEXT:    [[TMP14:%.*]] = shl i64 [[TMP13]], 2
+; CHECK-NEXT:    call void @llvm.memset.p0i8.i64(i8* align 4 [[TMP4]], i8 0, i64 [[TMP14]], i1 false)
+; CHECK-NEXT:    br label [[TMP15]]
+; CHECK:       15:
+; CHECK-NEXT:    ret void
+;
+  %4 = sub nsw i32 %1, 1
+  %5 = icmp sge i32 %4, 0
+  br i1 %5, label %.lr.ph, label %17
+
+.lr.ph:                                           ; preds = %3
+  %6 = sext i32 %2 to i64
+  %7 = mul i64 %6, 4
+  %8 = add i32 %1, -1
+  %9 = sext i32 %8 to i64
+  %10 = sext i32 %2 to i64
+  br label %11
+
+11:                                               ; preds = %.lr.ph, %15
+  %indvars.iv = phi i64 [ %9, %.lr.ph ], [ %indvars.iv.next, %15 ]
+  %12 = mul nsw i64 %indvars.iv, %10
+  %13 = getelementptr inbounds i32, i32* %0, i64 %12
+  %14 = bitcast i32* %13 to i8*
+  call void @llvm.memset.p0i8.i64(i8* align 4 %14, i8 0, i64 %7, i1 false)
+  br label %15
+
+15:                                               ; preds = %11
+  %indvars.iv.next = add nsw i64 %indvars.iv, -1
+  %16 = icmp sge i64 %indvars.iv.next, 0
+  br i1 %16, label %11, label %._crit_edge
+
+._crit_edge:                                      ; preds = %15
+  br label %17
+
+17:                                               ; preds = %._crit_edge, %3
+  ret void
+}
+
+declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)