diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp --- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -923,7 +923,28 @@ } } - return nullptr; + // Truncate the string to search at most EndOff characters. + Str = Str.substr(0, EndOff); + if (Str.find_first_not_of(Str[0]) != StringRef::npos) + return nullptr; + + // If the source array consists of all equal characters, then for any + // C and N, fold memrchr(S, C, N) to + // N != 0 && N <= sizeof S && *S == C ? S + N - 1 : null + // AKA + // N - 1 < sizeof S && *S == C ? S + N - 1 : null. + Type *SizeTy = Size->getType(); + Type *Int8Ty = B.getInt8Ty(); + Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0)); + Value *NLeSize = B.CreateICmpULE(Size, ConstantInt::get(SizeTy, Str.size())); + // Slice off the sought character's high end bits. + CharVal = B.CreateTrunc(CharVal, Int8Ty); + Value *And = B.CreateAnd(NNeZ, NLeSize); + Value *CEqS0 = B.CreateICmpEQ(ConstantInt::get(Int8Ty, Str[0]), CharVal); + And = B.CreateAnd(And, CEqS0); + Value *SizeM1 = B.CreateSub(Size, ConstantInt::get(SizeTy, 1)); + Value *SrcPlus = B.CreateGEP(Int8Ty, SrcStr, SizeM1, "memrchr.ptr_plus"); + return B.CreateSelect(And, SrcPlus, NullPtr, "memrchr.sel"); } Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) { diff --git a/llvm/test/Transforms/InstCombine/memrchr-4.ll b/llvm/test/Transforms/InstCombine/memrchr-4.ll --- a/llvm/test/Transforms/InstCombine/memrchr-4.ll +++ b/llvm/test/Transforms/InstCombine/memrchr-4.ll @@ -6,16 +6,67 @@ declare i8* @memrchr(i8*, i32, i64) +@ax = external global [0 x i8] +@a1 = constant [1 x i8] c"\01" @a11111 = constant [5 x i8] c"\01\01\01\01\01" @a1110111 = constant [7 x i8] c"\01\01\01\00\01\01\01" +; Fold memrchr(a1, c, n) to *a1 == c ? a1 : null. 
+ +define i8* @fold_memrchr_a1_c_n(i32 %c, i64 %n) { +; CHECK-LABEL: @fold_memrchr_a1_c_n( +; CHECK-NEXT: [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i64 [[N:%.*]], 1 +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP1]], 1 +; CHECK-NEXT: [[TMP4:%.*]] = and i1 [[TMP2]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = add i64 [[N]], -1 +; CHECK-NEXT: [[MEMRCHR_PTR_PLUS:%.*]] = getelementptr [1 x i8], [1 x i8]* @a1, i64 0, i64 [[TMP5]] +; CHECK-NEXT: [[MEMRCHR_SEL:%.*]] = select i1 [[TMP4]], i8* [[MEMRCHR_PTR_PLUS]], i8* null +; CHECK-NEXT: ret i8* [[MEMRCHR_SEL]] +; + %ptr = getelementptr [1 x i8], [1 x i8]* @a1, i64 0, i64 0 + %call = call i8* @memrchr(i8* %ptr, i32 %c, i64 %n) + ret i8* %call +} + + +; Don't fold memrchr(a1 + 1, c, n) to a past-the-end load from a1[1] +; (this could be folded to null). + +define i8* @call_memrchr_a1_p1_c_n(i32 %c, i64 %n) { +; CHECK-LABEL: @call_memrchr_a1_p1_c_n( +; CHECK-NEXT: [[CALL:%.*]] = call i8* @memrchr(i8* getelementptr inbounds ([1 x i8], [1 x i8]* @a1, i64 1, i64 0), i32 [[C:%.*]], i64 [[N:%.*]]) +; CHECK-NEXT: ret i8* [[CALL]] +; + %ptr = getelementptr [1 x i8], [1 x i8]* @a1, i64 0, i64 1 + %call = call i8* @memrchr(i8* %ptr, i32 %c, i64 %n) + ret i8* %call +} + + +; Don't fold memrchr(ax + 1, c, n) to what could be a past-the-end load +; from ax[1] (if ax is defined to be a char[1]). + +define i8* @call_memrchr_ax_p1_c_n(i32 %c, i64 %n) { +; CHECK-LABEL: @call_memrchr_ax_p1_c_n( +; CHECK-NEXT: [[CALL:%.*]] = call i8* @memrchr(i8* getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 1), i32 [[C:%.*]], i64 [[N:%.*]]) +; CHECK-NEXT: ret i8* [[CALL]] +; + %ptr = getelementptr [0 x i8], [0 x i8]* @ax, i64 0, i64 1 + %call = call i8* @memrchr(i8* %ptr, i32 %c, i64 %n) + ret i8* %call +} + + ; Fold memrchr(a11111, c, 5) to *a11111 == c ? a11111 + 5 - 1 : null. 
define i8* @fold_memrchr_a11111_c_5(i32 %0) { ; CHECK-LABEL: @fold_memrchr_a11111_c_5( -; CHECK-NEXT: [[RET:%.*]] = call i8* @memrchr(i8* noundef nonnull dereferenceable(5) getelementptr inbounds ([5 x i8], [5 x i8]* @a11111, i64 0, i64 0), i32 [[TMP0:%.*]], i64 5) -; CHECK-NEXT: ret i8* [[RET]] +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[TMP0:%.*]] to i8 +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP2]], 1 +; CHECK-NEXT: [[MEMRCHR_SEL:%.*]] = select i1 [[TMP3]], i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a11111, i64 0, i64 4), i8* null +; CHECK-NEXT: ret i8* [[MEMRCHR_SEL]] ; %ptr = getelementptr [5 x i8], [5 x i8]* @a11111, i64 0, i64 0 @@ -28,8 +79,10 @@ define i8* @fold_memrchr_a1110111_c_3(i32 %0) { ; CHECK-LABEL: @fold_memrchr_a1110111_c_3( -; CHECK-NEXT: [[RET:%.*]] = call i8* @memrchr(i8* noundef nonnull dereferenceable(3) getelementptr inbounds ([7 x i8], [7 x i8]* @a1110111, i64 0, i64 0), i32 [[TMP0:%.*]], i64 3) -; CHECK-NEXT: ret i8* [[RET]] +; CHECK-NEXT: [[TMP2:%.*]] = trunc i32 [[TMP0:%.*]] to i8 +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP2]], 1 +; CHECK-NEXT: [[MEMRCHR_SEL:%.*]] = select i1 [[TMP3]], i8* getelementptr inbounds ([7 x i8], [7 x i8]* @a1110111, i64 0, i64 2), i8* null +; CHECK-NEXT: ret i8* [[MEMRCHR_SEL]] ; %ptr = getelementptr [7 x i8], [7 x i8]* @a1110111, i64 0, i64 0