diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -896,8 +896,8 @@ /// processLoopMemSet - See if this memset can be promoted to a large memset. bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, const SCEV *BECount) { - // We can only handle non-volatile memsets with a constant size. - if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength())) + // We can only handle non-volatile memsets. + if (MSI->isVolatile()) return false; // If we're not allowed to hack on memset, we fail. @@ -910,23 +910,72 @@ // loop, which indicates a strided store. If we have something else, it's a // random store we can't handle. const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer)); - if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine()) + if (!Ev || Ev->getLoop() != CurLoop) return false; - - // Reject memsets that are so large that they overflow an unsigned. - uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); - if ((SizeInBytes >> 32) != 0) + if (!Ev->isAffine()) { + LLVM_DEBUG(dbgs() << " Pointer is not affine, abort\n"); return false; + } - // Check to see if the stride matches the size of the memset. If so, then we - // know that every byte is touched in the loop. - const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); - if (!ConstStride) + const SCEV *PointerStrideSCEV = Ev->getOperand(1); + const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength()); + if (!PointerStrideSCEV || !MemsetSizeSCEV) return false; - APInt Stride = ConstStride->getAPInt(); - if (SizeInBytes != Stride && SizeInBytes != -Stride) - return false; + bool IsNegStride = false; + const bool IsConstantSize = isa<ConstantInt>(MSI->getLength()); + + if (IsConstantSize) { + // Memset size is constant. + // Check if the pointer stride matches the memset size. If so, then + // we know that every byte is touched in the loop. 
+ LLVM_DEBUG(dbgs() << " memset size is constant\n"); + uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); + const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); + if (!ConstStride) + return false; + + APInt Stride = ConstStride->getAPInt(); + if (SizeInBytes != Stride && SizeInBytes != -Stride) + return false; + + IsNegStride = SizeInBytes == -Stride; + } else { + // Memset size is non-constant. + // Check if the pointer stride matches the memset size. + // To be conservative, the pass would not promote pointers that aren't in + // address space zero. Also, the pass only handles memset length and stride + // that are invariant for the top level loop. + LLVM_DEBUG(dbgs() << " memset size is non-constant\n"); + if (Pointer->getType()->getPointerAddressSpace() != 0) { + LLVM_DEBUG(dbgs() << " pointer is not in address space zero, " + << "abort\n"); + return false; + } + if (!SE->isLoopInvariant(MemsetSizeSCEV, CurLoop)) { + LLVM_DEBUG(dbgs() << " memset size is not a loop-invariant, " + << "abort\n"); + return false; + } + + // Compare positive direction PointerStrideSCEV with MemsetSizeSCEV + IsNegStride = PointerStrideSCEV->isNonConstantNegative(); + const SCEV *PositiveStrideSCEV = + IsNegStride ? SE->getNegativeSCEV(PointerStrideSCEV) + : PointerStrideSCEV; + LLVM_DEBUG(dbgs() << " MemsetSizeSCEV: " << *MemsetSizeSCEV << "\n" + << " PositiveStrideSCEV: " << *PositiveStrideSCEV + << "\n"); + + if (PositiveStrideSCEV != MemsetSizeSCEV) { + // TODO: folding can be done to the SCEVs + // The folding is to fold expressions that are covered by the loop guard + // at loop entry. After the folding, compare again and proceed + // optimization if equal. + LLVM_DEBUG(dbgs() << " SCEV don't match, abort\n"); + return false; + } + } // Verify that the memset value is loop invariant. If not, we can't promote // the memset. 
@@ -936,7 +985,6 @@ SmallPtrSet<Instruction *, 1> MSIs; MSIs.insert(MSI); - bool IsNegStride = SizeInBytes == -Stride; return processLoopStridedStore(Pointer, SE->getSCEV(MSI->getLength()), MaybeAlign(MSI->getDestAlignment()), SplatValue, MSI, MSIs, Ev, BECount, @@ -1028,20 +1076,6 @@ /// /// This also maps the SCEV into the provided type and tries to handle the /// computation in a way that will fold cleanly. -static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, - unsigned StoreSize, Loop *CurLoop, - const DataLayout *DL, ScalarEvolution *SE) { - const SCEV *TripCountSCEV = getTripCount(BECount, IntPtr, CurLoop, DL, SE); - - // And scale it based on the store size. - if (StoreSize != 1) { - return SE->getMulExpr(TripCountSCEV, SE->getConstant(IntPtr, StoreSize), - SCEV::FlagNUW); - } - return TripCountSCEV; -} - -/// getNumBytes that takes StoreSize as a SCEV static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, const SCEV *StoreSizeSCEV, Loop *CurLoop, const DataLayout *DL, ScalarEvolution *SE) { @@ -1342,8 +1376,8 @@ // Okay, everything is safe, we can transform this! 
- const SCEV *NumBytesS = - getNumBytes(BECount, IntIdxTy, StoreSize, CurLoop, DL, SE); + const SCEV *NumBytesS = getNumBytes( + BECount, IntIdxTy, SE->getConstant(IntIdxTy, StoreSize), CurLoop, DL, SE); Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntIdxTy, Preheader->getTerminator()); diff --git a/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopIdiom/memset-runtime-debug.ll @@ -0,0 +1,270 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; REQUIRES: asserts +; RUN: opt < %s -S -debug -passes=loop-idiom 2>&1 | FileCheck %s +; The C code to generate this testcase: +; void test(int *ar, int n, int m) +; { +; long i; +; for (i=0; i +; CHECK-NEXT: PositiveStrideSCEV: (4 + (4 * (sext i32 %m to i64))) +; CHECK-NEXT: SCEV don't match, abort +; CHECK: loop-idiom Scanning: F[NonZeroAddressSpace] Countable Loop %for.cond1.preheader +; CHECK-NEXT: memset size is non-constant +; CHECK-NEXT: pointer is not in address space zero, abort +; CHECK: loop-idiom Scanning: F[NonAffinePointer] Countable Loop %for.body +; CHECK-NEXT: Pointer is not affine, abort + +define void @MemsetSize_LoopVariant(i32* %ar, i32 %n, i32 %m) { +; CHECK-LABEL: @MemsetSize_LoopVariant( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CONV:%.*]] = sext i32 [[N:%.*]] to i64 +; CHECK-NEXT: [[CMP1:%.*]] = icmp slt i64 0, [[CONV]] +; CHECK-NEXT: br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[M:%.*]] to i64 +; CHECK-NEXT: [[CONV2:%.*]] = sext i32 [[M]] to i64 +; CHECK-NEXT: [[MUL3:%.*]] = mul i64 [[CONV2]], 4 +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[I_02:%.*]] = phi i64 [ 0, [[FOR_BODY_LR_PH]] ], [ [[INC:%.*]], [[FOR_INC:%.*]] ] +; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[I_02]], [[CONV1]] +; CHECK-NEXT: [[ADD_PTR:%.*]] = 
getelementptr inbounds i32, i32* [[AR:%.*]], i64 [[MUL]] +; CHECK-NEXT: [[TMP0:%.*]] = bitcast i32* [[ADD_PTR]] to i8* +; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[I_02]], [[MUL3]] +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[TMP0]], i8 0, i64 [[ADD]], i1 false) +; CHECK-NEXT: br label [[FOR_INC]] +; CHECK: for.inc: +; CHECK-NEXT: [[INC]] = add nuw nsw i64 [[I_02]], 1 +; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[INC]], [[CONV]] +; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_COND_FOR_END_CRIT_EDGE:%.*]] +; CHECK: for.cond.for.end_crit_edge: +; CHECK-NEXT: br label [[FOR_END]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + %conv = sext i32 %n to i64 + %cmp1 = icmp slt i64 0, %conv + br i1 %cmp1, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %conv1 = sext i32 %m to i64 + %conv2 = sext i32 %m to i64 + %mul3 = mul i64 %conv2, 4 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %for.inc + %i.02 = phi i64 [ 0, %for.body.lr.ph ], [ %inc, %for.inc ] + %mul = mul nsw i64 %i.02, %conv1 + %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %mul + %0 = bitcast i32* %add.ptr to i8* + %add = add nsw i64 %i.02, %mul3 + call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 %add, i1 false) + br label %for.inc + +for.inc: ; preds = %for.body + %inc = add nuw nsw i64 %i.02, 1 + %cmp = icmp slt i64 %inc, %conv + br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge + +for.cond.for.end_crit_edge: ; preds = %for.inc + br label %for.end + +for.end: ; preds = %for.cond.for.end_crit_edge, %entry + ret void +} +; void test(int *ar, int n, int m) +; { +; long i; +; for (i=0; i=0; i--) { +; int *arr = ar + i * m; +; memset(arr, 0, m * sizeof(int)); +; } +; } +define void @For_NegativeStride(i32* %ar, i32 %n, i32 %m) { +; CHECK-LABEL: @For_NegativeStride( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[AR1:%.*]] = bitcast i32* [[AR:%.*]] to i8* +; CHECK-NEXT: [[SUB:%.*]] = sub nsw i32 [[N:%.*]], 1 +; CHECK-NEXT: 
[[CONV:%.*]] = sext i32 [[SUB]] to i64 +; CHECK-NEXT: [[CMP1:%.*]] = icmp sge i64 [[CONV]], 0 +; CHECK-NEXT: br i1 [[CMP1]], label [[FOR_BODY_LR_PH:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.lr.ph: +; CHECK-NEXT: [[CONV1:%.*]] = sext i32 [[M:%.*]] to i64 +; CHECK-NEXT: [[CONV2:%.*]] = sext i32 [[M]] to i64 +; CHECK-NEXT: [[MUL3:%.*]] = mul i64 [[CONV2]], 4 +; CHECK-NEXT: [[TMP0:%.*]] = sub i64 [[CONV]], -1 +; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[CONV1]], [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 2 +; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 4 [[AR1]], i8 0, i64 [[TMP2]], i1 false) +; CHECK-NEXT: br label [[FOR_END]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; +entry: + %sub = sub nsw i32 %n, 1 + %conv = sext i32 %sub to i64 + %cmp1 = icmp sge i64 %conv, 0 + br i1 %cmp1, label %for.body.lr.ph, label %for.end + +for.body.lr.ph: ; preds = %entry + %conv1 = sext i32 %m to i64 + %conv2 = sext i32 %m to i64 + %mul3 = mul i64 %conv2, 4 + br label %for.body + +for.body: ; preds = %for.body.lr.ph, %for.inc + %i.02 = phi i64 [ %conv, %for.body.lr.ph ], [ %dec, %for.inc ] + %mul = mul nsw i64 %i.02, %conv1 + %add.ptr = getelementptr inbounds i32, i32* %ar, i64 %mul + %0 = bitcast i32* %add.ptr to i8* + call void @llvm.memset.p0i8.i64(i8* align 4 %0, i8 0, i64 %mul3, i1 false) + br label %for.inc + +for.inc: ; preds = %for.body + %dec = add nsw i64 %i.02, -1 + %cmp = icmp sge i64 %dec, 0 + br i1 %cmp, label %for.body, label %for.cond.for.end_crit_edge + +for.cond.for.end_crit_edge: ; preds = %for.inc + br label %for.end + +for.end: ; preds = %for.cond.for.end_crit_edge, %entry + ret void +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg)