diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -340,7 +340,11 @@
 
   /// If we can compute the length of the string pointed to by the specified
   /// pointer, return 'len+1'.  If we can't, return 0.
-  uint64_t GetStringLength(const Value *V, unsigned CharSize = 8);
+  /// If \p TLI is provided, the length may also be derived by looking through
+  /// calls to strdup.
+  uint64_t GetStringLength(const Value *V,
+                           const TargetLibraryInfo *TLI = nullptr,
+                           unsigned CharSize = 8);
 
   /// This function returns call pointer argument that is considered the same by
   /// aliasing rules. You CAN'T use it to replace one value with another. If
diff --git a/llvm/lib/Analysis/MemoryBuiltins.cpp b/llvm/lib/Analysis/MemoryBuiltins.cpp
--- a/llvm/lib/Analysis/MemoryBuiltins.cpp
+++ b/llvm/lib/Analysis/MemoryBuiltins.cpp
@@ -374,7 +374,7 @@
 
   // Handle strdup-like functions separately.
   if (FnData->AllocTy == StrDupLike) {
-    APInt Size(IntTyBits, GetStringLength(Mapper(CB->getArgOperand(0))));
+    APInt Size(IntTyBits, GetStringLength(Mapper(CB->getArgOperand(0)), TLI));
     if (!Size)
       return None;
 
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4199,7 +4199,8 @@
 /// If we can compute the length of the string pointed to by
 /// the specified pointer, return 'len+1'.  If we can't, return 0.
 static uint64_t GetStringLengthH(const Value *V,
-                                 SmallPtrSetImpl<const PHINode*> &PHIs,
+                                 SmallPtrSetImpl<const PHINode *> &PHIs,
+                                 const TargetLibraryInfo *TLI,
                                  unsigned CharSize) {
   // Look through noop bitcast instructions.
   V = V->stripPointerCasts();
@@ -4213,7 +4214,7 @@
     // If it was new, see if all the input strings are the same length.
uint64_t LenSoFar = ~0ULL; for (Value *IncValue : PN->incoming_values()) { - uint64_t Len = GetStringLengthH(IncValue, PHIs, CharSize); + uint64_t Len = GetStringLengthH(IncValue, PHIs, TLI, CharSize); if (Len == 0) return 0; // Unknown length -> unknown. if (Len == ~0ULL) continue; @@ -4229,9 +4230,9 @@ // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y) if (const SelectInst *SI = dyn_cast(V)) { - uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, CharSize); + uint64_t Len1 = GetStringLengthH(SI->getTrueValue(), PHIs, TLI, CharSize); if (Len1 == 0) return 0; - uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, CharSize); + uint64_t Len2 = GetStringLengthH(SI->getFalseValue(), PHIs, TLI, CharSize); if (Len2 == 0) return 0; if (Len1 == ~0ULL) return Len2; if (Len2 == ~0ULL) return Len1; @@ -4239,6 +4240,22 @@ return Len1; } + if (auto *CB = dyn_cast(V)) { + Function *Callee = CB->getCalledFunction(); + if (!Callee) + return 0; + + LibFunc TLIFn; + if (!TLI || !TLI->getLibFunc(*CB->getCalledFunction(), TLIFn) || + !TLI->has(TLIFn)) + return 0; + + if (TLIFn == LibFunc_strdup || TLIFn == LibFunc_dunder_strdup) + return GetStringLengthH(CB->getArgOperand(0), PHIs, TLI, CharSize); + + return 0; + } + // Otherwise, see if we can read the string. ConstantDataArraySlice Slice; if (!getConstantDataArrayInfo(V, Slice, CharSize)) @@ -4259,12 +4276,12 @@ /// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. -uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { +uint64_t llvm::GetStringLength(const Value *V, const TargetLibraryInfo *TLI, unsigned CharSize) { if (!V->getType()->isPointerTy()) return 0; SmallPtrSet PHIs; - uint64_t Len = GetStringLengthH(V, PHIs, CharSize); + uint64_t Len = GetStringLengthH(V, PHIs, TLI, CharSize); // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return // an empty string as a length. return Len == ~0ULL ? 
1 : Len; diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -212,7 +212,7 @@ annotateNonNullNoUndefBasedOnAccess(CI, {0, 1}); // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); + uint64_t Len = GetStringLength(Src, TLI); if (Len) annotateDereferenceableBytes(CI, 1, Len); else @@ -269,7 +269,7 @@ } // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); + uint64_t SrcLen = GetStringLength(Src, TLI); if (SrcLen) { annotateDereferenceableBytes(CI, 1, SrcLen); --SrcLen; // Unbias length. @@ -300,7 +300,7 @@ // of the input string and turn this into memchr. ConstantInt *CharC = dyn_cast(CI->getArgOperand(1)); if (!CharC) { - uint64_t Len = GetStringLength(SrcStr); + uint64_t Len = GetStringLength(SrcStr, TLI); if (Len) annotateDereferenceableBytes(CI, 0, Len); else @@ -387,10 +387,10 @@ CI->getType()); // strcmp(P, "x") -> memcmp(P, "x", 2) - uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len1 = GetStringLength(Str1P, TLI); if (Len1) annotateDereferenceableBytes(CI, 0, Len1); - uint64_t Len2 = GetStringLength(Str2P); + uint64_t Len2 = GetStringLength(Str2P, TLI); if (Len2) annotateDereferenceableBytes(CI, 1, Len2); @@ -464,10 +464,10 @@ return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"), CI->getType()); - uint64_t Len1 = GetStringLength(Str1P); + uint64_t Len1 = GetStringLength(Str1P, TLI); if (Len1) annotateDereferenceableBytes(CI, 0, Len1); - uint64_t Len2 = GetStringLength(Str2P); + uint64_t Len2 = GetStringLength(Str2P, TLI); if (Len2) annotateDereferenceableBytes(CI, 1, Len2); @@ -496,7 +496,7 @@ Value *LibCallSimplifier::optimizeStrNDup(CallInst *CI, IRBuilderBase &B) { Value *Src = CI->getArgOperand(0); ConstantInt *Size = dyn_cast(CI->getArgOperand(1)); - uint64_t SrcLen = GetStringLength(Src); + 
uint64_t SrcLen = GetStringLength(Src, TLI); if (SrcLen && Size) { annotateDereferenceableBytes(CI, 0, SrcLen); if (SrcLen <= Size->getZExtValue() + 1) @@ -513,7 +513,7 @@ annotateNonNullNoUndefBasedOnAccess(CI, {0, 1}); // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); + uint64_t Len = GetStringLength(Src, TLI); if (Len) annotateDereferenceableBytes(CI, 1, Len); else @@ -544,7 +544,7 @@ } // See if we can get the length of the input string. - uint64_t Len = GetStringLength(Src); + uint64_t Len = GetStringLength(Src, TLI); if (Len) annotateDereferenceableBytes(CI, 1, Len); else @@ -584,7 +584,7 @@ return Dst; // See if we can get the length of the input string. - uint64_t SrcLen = GetStringLength(Src); + uint64_t SrcLen = GetStringLength(Src, TLI); if (SrcLen) { annotateDereferenceableBytes(CI, 1, SrcLen); --SrcLen; // Unbias length. @@ -633,7 +633,7 @@ Value *Src = CI->getArgOperand(0); // Constant folding: strlen("xyz") -> 3 - if (uint64_t Len = GetStringLength(Src, CharSize)) + if (uint64_t Len = GetStringLength(Src, TLI, CharSize)) return ConstantInt::get(CI->getType(), Len - 1); // If s is a constant pointer pointing to a string literal, we can fold @@ -688,8 +688,10 @@ // strlen(x?"foo":"bars") --> x ? 
3 : 4 if (SelectInst *SI = dyn_cast(Src)) { - uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize); - uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize); + uint64_t LenTrue = + GetStringLength(SI->getTrueValue(), TLI, CharSize); + uint64_t LenFalse = + GetStringLength(SI->getFalseValue(), TLI, CharSize); if (LenTrue && LenFalse) { ORE.emit([&]() { return OptimizationRemark("instcombine", "simplify-libcalls", CI) @@ -2511,7 +2513,7 @@ // sprintf(dest, "%s", str) -> strcpy(dest, str) return copyFlags(*CI, emitStrCpy(Dest, CI->getArgOperand(2), B, TLI)); - uint64_t SrcLen = GetStringLength(CI->getArgOperand(2)); + uint64_t SrcLen = GetStringLength(CI->getArgOperand(2), TLI); if (SrcLen) { B.CreateMemCpy( Dest, Align(1), CI->getArgOperand(2), Align(1), @@ -2803,7 +2805,7 @@ return nullptr; // fputs(s,F) --> fwrite(s,strlen(s),1,F) - uint64_t Len = GetStringLength(CI->getArgOperand(0)); + uint64_t Len = GetStringLength(CI->getArgOperand(0), TLI); if (!Len) return nullptr; @@ -3247,7 +3249,8 @@ if (OnlyLowerUnknownSize) return false; if (StrOp) { - uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp)); + uint64_t Len = + GetStringLength(CI->getArgOperand(*StrOp), TLI); // If the length is 0 we don't know how long it is and so we can't // remove the check. if (Len) @@ -3351,7 +3354,7 @@ return nullptr; // Maybe we can stil fold __st[rp]cpy_chk to __memcpy_chk. - uint64_t Len = GetStringLength(Src); + uint64_t Len = GetStringLength(Src, TLI); if (Len) annotateDereferenceableBytes(CI, 1, Len); else diff --git a/llvm/test/Transforms/InstCombine/strlen-1.ll b/llvm/test/Transforms/InstCombine/strlen-1.ll --- a/llvm/test/Transforms/InstCombine/strlen-1.ll +++ b/llvm/test/Transforms/InstCombine/strlen-1.ll @@ -14,6 +14,7 @@ @null_hello_mid = constant [13 x i8] c"hello wor\00ld\00" declare i32 @strlen(i8*) +declare noalias i8* @strdup(i8*) ; Check strlen(string constant) -> integer constant. 
@@ -203,7 +204,7 @@
 
 define i32 @test1(i8* %str) {
 ; CHECK-LABEL: @test1(
-; CHECK-NEXT:    [[LEN:%.*]] = tail call i32 @strlen(i8* noundef nonnull dereferenceable(1) [[STR:%.*]]) [[ATTR1:#.*]]
+; CHECK-NEXT:    [[LEN:%.*]] = tail call i32 @strlen(i8* noundef nonnull dereferenceable(1) [[STR:%.*]]) #[[ATTR1:[0-9]+]]
 ; CHECK-NEXT:    ret i32 [[LEN]]
 ;
   %len = tail call i32 @strlen(i8* %str) nounwind
@@ -212,7 +213,7 @@
 
 define i32 @test2(i8* %str) #0 {
 ; CHECK-LABEL: @test2(
-; CHECK-NEXT:    [[LEN:%.*]] = tail call i32 @strlen(i8* noundef [[STR:%.*]]) [[ATTR1]]
+; CHECK-NEXT:    [[LEN:%.*]] = tail call i32 @strlen(i8* noundef [[STR:%.*]]) #[[ATTR1]]
 ; CHECK-NEXT:    ret i32 [[LEN]]
 ;
   %len = tail call i32 @strlen(i8* %str) nounwind
@@ -270,4 +271,32 @@
   ret i1 %cmp
 }
 
+; Check strlen(strdup(string constant)) -> integer constant.
+
+define i32 @test_simplify_strduped_constant() {
+; CHECK-LABEL: @test_simplify_strduped_constant(
+; CHECK-NEXT:    ret i32 5
+;
+  %hello_p = getelementptr [6 x i8], [6 x i8]* @hello, i32 0, i32 0
+  %hello_s = call i8* @strdup(i8* %hello_p)
+  %hello_l = call i32 @strlen(i8* %hello_s)
+  ret i32 %hello_l
+}
+
+; Check strlen(strdup(string constant)) is not folded when the copy is
+; written to before the call.
+
+define i32 @test_no_simplify_strduped_modified() {
+; CHECK-LABEL: @test_no_simplify_strduped_modified(
+; CHECK-NEXT:    [[HELLO_S:%.*]] = call {{.*}}i8* @strdup(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @hello, i64 0, i64 0))
+; CHECK-NEXT:    store i8 0, i8* [[HELLO_S]], align 1
+; CHECK-NEXT:    [[HELLO_L:%.*]] = call i32 @strlen(i8* noundef nonnull dereferenceable(1) [[HELLO_S]])
+; CHECK-NEXT:    ret i32 [[HELLO_L]]
+;
+  %hello_s = call i8* @strdup(i8* getelementptr inbounds ([6 x i8], [6 x i8]* @hello, i64 0, i64 0))
+  store i8 0, i8* %hello_s
+  %hello_l = call i32 @strlen(i8* %hello_s)
+  ret i32 %hello_l
+}
+
 attributes #0 = { null_pointer_is_valid }