diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -295,31 +295,69 @@ return copyFlags(*CI, emitStrLenMemCpy(Src, Dst, SrcLen, B)); } +// Helper to transform memchr(S, C, N) == S to N && *S == C and, when +// NBytes is null, strchr(S, C) to *S == C. A precondition of the function +// is that either S is dereferenceable or the value of N is nonzero. +static Value* memChrToCharCompare(CallInst *CI, Value *NBytes, + IRBuilderBase &B, const DataLayout &DL) +{ + Value *Src = CI->getArgOperand(0); + Value *CharVal = CI->getArgOperand(1); + + // Fold memchr(A, C, N) == A to N && *A == C. + Type *CharTy = B.getInt8Ty(); + Value *Char0 = B.CreateLoad(CharTy, Src); + CharVal = B.CreateTrunc(CharVal, CharTy); + Value *Cmp = B.CreateICmpEQ(Char0, CharVal, "char0cmp"); + + if (NBytes) { + Value *Zero = ConstantInt::get(NBytes->getType(), 0); + Value *And = B.CreateICmpNE(NBytes, Zero); + Cmp = B.CreateLogicalAnd(And, Cmp); + } + + Value *NullPtr = Constant::getNullValue(CI->getType()); + return B.CreateSelect(Cmp, Src, NullPtr); +} + Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) { - Function *Callee = CI->getCalledFunction(); - FunctionType *FT = Callee->getFunctionType(); Value *SrcStr = CI->getArgOperand(0); + Value *CharVal = CI->getArgOperand(1); annotateNonNullNoUndefBasedOnAccess(CI, 0); + if (isOnlyUsedInEqualityComparison(CI, SrcStr)) + return memChrToCharCompare(CI, nullptr, B, DL); + // If the second operand is non-constant, see if we can compute the length // of the input string and turn this into memchr. - ConstantInt *CharC = dyn_cast(CI->getArgOperand(1)); + ConstantInt *CharC = dyn_cast(CharVal); if (!CharC) { uint64_t Len = GetStringLength(SrcStr); if (Len) annotateDereferenceableBytes(CI, 0, Len); else return nullptr; + + Function *Callee = CI->getCalledFunction(); + FunctionType *FT = Callee->getFunctionType(); if (!FT->getParamType(1)->isIntegerTy(32)) // memchr needs i32. return nullptr; return copyFlags( *CI, - emitMemChr(SrcStr, CI->getArgOperand(1), // include nul. + emitMemChr(SrcStr, CharVal, // include nul. ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len), B, DL, TLI)); } + if (CharC->isZero()) { + Value *NullPtr = Constant::getNullValue(CI->getType()); + if (isOnlyUsedInEqualityComparison(CI, NullPtr)) + // Pre-empt the transformation to strlen below and fold + // strchr(A, '\0') == null to false. + return B.CreateIntToPtr(B.getTrue(), CI->getType()); + } + // Otherwise, the character is a constant, see if the first argument is // a string literal. If so, we can constant fold. StringRef Str; @@ -1008,8 +1046,12 @@ Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { Value *SrcStr = CI->getArgOperand(0); Value *Size = CI->getArgOperand(2); - if (isKnownNonZero(Size, DL)) + + if (isKnownNonZero(Size, DL)) { annotateNonNullNoUndefBasedOnAccess(CI, 0); + if (isOnlyUsedInEqualityComparison(CI, SrcStr)) + return memChrToCharCompare(CI, Size, B, DL); + } Value *CharVal = CI->getArgOperand(1); ConstantInt *CharC = dyn_cast(CharVal); @@ -1099,9 +1141,16 @@ return B.CreateSelect(And, SrcStr, Sel1, "memchr.sel2"); } - if (!LenC) + if (!LenC) { + if (isOnlyUsedInEqualityComparison(CI, SrcStr)) + // S is dereferenceable so it's safe to load from it and fold + // memchr(S, C, N) == S to N && *S == C for any C and N. + // TODO: This is safe even even for nonconstant S. + return memChrToCharCompare(CI, Size, B, DL); + // From now on we need a constant length and constant array. return nullptr; + } // If the char is variable but the input str and length are not we can turn // this memchr call into a simple bit field test. Of course this only works diff --git a/llvm/test/Transforms/InstCombine/memchr-11.ll b/llvm/test/Transforms/InstCombine/memchr-11.ll --- a/llvm/test/Transforms/InstCombine/memchr-11.ll +++ b/llvm/test/Transforms/InstCombine/memchr-11.ll @@ -13,9 +13,9 @@ define i1 @fold_memchr_a_c_5_eq_a(i32 %c) { ; CHECK-LABEL: @fold_memchr_a_c_5_eq_a( -; CHECK-NEXT: [[Q:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 5) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[Q]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0) -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49 +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0 %q = call i8* @memchr(i8* %p, i32 %c, i64 5) @@ -30,9 +30,11 @@ define i1 @fold_memchr_a_c_n_eq_a(i32 %c, i64 %n) { ; CHECK-LABEL: @fold_memchr_a_c_n_eq_a( -; CHECK-NEXT: [[Q:%.*]] = call i8* @memchr(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 [[N:%.*]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[Q]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0) -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49 +; CHECK-NEXT: [[TMP2:%.*]] = icmp ne i64 [[N:%.*]], 0 +; CHECK-NEXT: [[TMP3:%.*]] = select i1 [[TMP2]], i1 [[CHAR0CMP]], i1 false +; CHECK-NEXT: ret i1 [[TMP3]] ; %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0 %q = call i8* @memchr(i8* %p, i32 %c, i64 %n) @@ -61,9 +63,10 @@ define i1 @fold_memchr_s_c_15_eq_s(i8* %s, i32 %c) { ; CHECK-LABEL: @fold_memchr_s_c_15_eq_s( -; CHECK-NEXT: [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 15) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[P]], [[S]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = call i8* @memchr(i8* %s, i32 %c, i64 15) %cmp = icmp eq i8* %p, %s @@ -75,9 +78,10 @@ define i1 @fold_memchr_s_c_17_neq_s(i8* %s, i32 %c) { ; CHECK-LABEL: @fold_memchr_s_c_17_neq_s( -; CHECK-NEXT: [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 17) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8* [[P]], [[S]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp ne i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = call i8* @memchr(i8* %s, i32 %c, i64 17) %cmp = icmp ne i8* %p, %s @@ -89,10 +93,10 @@ define i1 @fold_memchr_s_c_nz_eq_s(i8* %s, i32 %c, i64 %n) { ; CHECK-LABEL: @fold_memchr_s_c_nz_eq_s( -; CHECK-NEXT: [[NZ:%.*]] = or i64 [[N:%.*]], 1 -; CHECK-NEXT: [[P:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]], i64 [[NZ]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[P]], [[S]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %nz = or i64 %n, 1 %p = call i8* @memchr(i8* %s, i32 %c, i64 %nz) diff --git a/llvm/test/Transforms/InstCombine/strchr-4.ll b/llvm/test/Transforms/InstCombine/strchr-4.ll --- a/llvm/test/Transforms/InstCombine/strchr-4.ll +++ b/llvm/test/Transforms/InstCombine/strchr-4.ll @@ -11,9 +11,10 @@ define i1 @fold_strchr_s_c_eq_s(i8* %s, i32 %c) { ; CHECK-LABEL: @fold_strchr_s_c_eq_s( -; CHECK-NEXT: [[P:%.*]] = call i8* @strchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[P]], [[S]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = call i8* @strchr(i8* %s, i32 %c) %cmp = icmp eq i8* %p, %s @@ -25,9 +26,10 @@ define i1 @fold_strchr_s_c_neq_s(i8* %s, i32 %c) { ; CHECK-LABEL: @fold_strchr_s_c_neq_s( -; CHECK-NEXT: [[P:%.*]] = call i8* @strchr(i8* noundef nonnull dereferenceable(1) [[S:%.*]], i32 [[C:%.*]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8* [[P]], [[S]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = load i8, i8* [[S:%.*]], align 1 +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp ne i8 [[TMP1]], [[TMP2]] +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = call i8* @strchr(i8* %s, i32 %c) %cmp = icmp ne i8* %p, %s @@ -40,8 +42,7 @@ define i1 @fold_strchr_s_nul_eqz(i8* %s) { ; CHECK-LABEL: @fold_strchr_s_nul_eqz( -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[S:%.*]], null -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 false ; %p = call i8* @strchr(i8* %s, i32 0) %cmp = icmp eq i8* %p, null @@ -53,8 +54,7 @@ define i1 @fold_strchr_s_nul_nez(i8* %s) { ; CHECK-LABEL: @fold_strchr_s_nul_nez( -; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8* [[S:%.*]], null -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %p = call i8* @strchr(i8* %s, i32 0) %cmp = icmp ne i8* %p, null @@ -68,9 +68,9 @@ define i1 @fold_strchr_a_c_eq_a(i32 %c) { ; CHECK-LABEL: @fold_strchr_a_c_eq_a( -; CHECK-NEXT: [[MEMCHR:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0), i32 [[C:%.*]], i64 6) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8* [[MEMCHR]], getelementptr inbounds ([5 x i8], [5 x i8]* @a5, i64 0, i64 0) -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 49 +; CHECK-NEXT: ret i1 [[CHAR0CMP]] ; %p = getelementptr [5 x i8], [5 x i8]* @a5, i32 0, i32 0 %q = call i8* @strchr(i8* %p, i32 %c)