Index: lib/Transforms/Scalar/LoopIdiomRecognize.cpp
===================================================================
--- lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -780,6 +780,41 @@
   return SE->getMinusSCEV(Start, Index);
 }
 
+/// Build the SCEV for the byte count written by the loop: (BECount + 1) *
+/// StoreSize, expressed in the integer type \p IntPtr.
+/// When \p BECount is narrower than \p IntPtr and the entry guard of
+/// \p CurLoop proves BECount != -1, the +1 is folded in before the zext.
+static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr,
+                               unsigned StoreSize, Loop *CurLoop,
+                               const DataLayout *DL, ScalarEvolution *SE) {
+  const SCEV *NumBytesS;
+  // The number of stored bytes is (BECount + 1) * StoreSize, computed in the
+  // pointer-sized type IntPtr.
+  //
+  // If BECount is narrower than IntPtr and the loop entry guard shows that
+  // BECount != -1, then BECount + 1 cannot wrap (hence NUW), so add the one
+  // first and zero extend the sum; this lets the +1 simplify better.
+  if (DL->getTypeSizeInBits(BECount->getType()) <
+          DL->getTypeSizeInBits(IntPtr) &&
+      SE->isLoopEntryGuardedByCond(
+          CurLoop, ICmpInst::ICMP_NE, BECount,
+          SE->getNegativeSCEV(SE->getOne(BECount->getType())))) {
+    NumBytesS = SE->getZeroExtendExpr(
+        SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW),
+        IntPtr);
+  } else {
+    NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr),
+                               SE->getOne(IntPtr), SCEV::FlagNUW);
+  }
+
+  // Scale the element count by the store size; a multiply by 1 is skipped.
+  if (StoreSize != 1) {
+    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
+                               SCEV::FlagNUW);
+  }
+  return NumBytesS;
+}
+
 /// processLoopStridedStore - We see a strided store of some value. If we can
 /// transform this into a memset or memset_pattern in the loop preheader, do so.
 bool LoopIdiomRecognize::processLoopStridedStore(
@@ -837,16 +872,8 @@
 
   // Okay, everything looks good, insert the memset.
 
-  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW);
-  if (StoreSize != 1) {
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize),
-                               SCEV::FlagNUW);
-  }
+      getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE);
 
   // TODO: ideally we should still be able to generate memset if SCEV expander
   // is taught to generate the dependencies at the latest point.
@@ -976,16 +1003,8 @@
 
   // Okay, everything is safe, we can transform this!
 
-  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
-  // pointer size if it isn't already.
-  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
-
   const SCEV *NumBytesS =
-      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
-
-  if (StoreSize != 1)
-    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
-                               SCEV::FlagNUW);
+      getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);
 
   Value *NumBytes =
       Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());
Index: test/Transforms/LoopIdiom/basic.ll
===================================================================
--- test/Transforms/LoopIdiom/basic.ll
+++ test/Transforms/LoopIdiom/basic.ll
@@ -563,6 +563,75 @@
 ; CHECK: ret void
 }
 
+; Handle loops where the trip count is a narrow (i32) integer that has to be
+; zero extended to form the i64 length of the memset/memcpy.
+define void @form_memset_narrow_size(i64* %ptr, i32 %size) {
+; CHECK-LABEL: @form_memset_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0 ; guard: the loop is entered only if %size > 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK: entry:
+; CHECK: %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK: loop.ph:
+; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %{{.*}}, i8 0, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom = sext i32 %storemerge4 to i64 ; i32 induction variable widened for the GEP
+  %arrayidx = getelementptr inbounds i64, i64* %ptr, i64 %idxprom
+  store i64 0, i64* %arrayidx, align 8 ; 8-byte zero store; matches the shl by 3 expected above
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
+define void @form_memcpy_narrow_size(i64* noalias %dst, i64* noalias %src, i32 %size) {
+; CHECK-LABEL: @form_memcpy_narrow_size(
+entry:
+  %cmp1 = icmp sgt i32 %size, 0 ; guard: the loop is entered only if %size > 0
+  br i1 %cmp1, label %loop.ph, label %exit
+; CHECK: entry:
+; CHECK: %[[C1:.*]] = icmp sgt i32 %size, 0
+; CHECK-NEXT: br i1 %[[C1]], label %loop.ph, label %exit
+
+loop.ph:
+  br label %loop.body
+; CHECK: loop.ph:
+; CHECK-NEXT: %[[ZEXT_SIZE:.*]] = zext i32 %size to i64
+; CHECK-NEXT: %[[SCALED_SIZE:.*]] = shl i64 %[[ZEXT_SIZE]], 3
+; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* %{{.*}}, i8* %{{.*}}, i64 %[[SCALED_SIZE]], i32 8, i1 false)
+
+loop.body:
+  %storemerge4 = phi i32 [ 0, %loop.ph ], [ %inc, %loop.body ]
+  %idxprom1 = sext i32 %storemerge4 to i64
+  %arrayidx1 = getelementptr inbounds i64, i64* %src, i64 %idxprom1
+  %v = load i64, i64* %arrayidx1, align 8 ; element read from %src
+  %idxprom2 = sext i32 %storemerge4 to i64
+  %arrayidx2 = getelementptr inbounds i64, i64* %dst, i64 %idxprom2
+  store i64 %v, i64* %arrayidx2, align 8 ; same element written to %dst at the same index
+  %inc = add nsw i32 %storemerge4, 1
+  %cmp2 = icmp slt i32 %inc, %size
+  br i1 %cmp2, label %loop.body, label %loop.exit
+
+loop.exit:
+  br label %exit
+
+exit:
+  ret void
+}
+
 ; Validate that "memset_pattern" has the proper attributes.
 ; CHECK: declare void @memset_pattern16(i8* nocapture, i8* nocapture readonly, i64) [[ATTRS:#[0-9]+]]
 ; CHECK: [[ATTRS]] = { argmemonly }