Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -884,26 +884,62 @@
   Value *SrcStr = CI->getArgOperand(0);
   Value *Size = CI->getArgOperand(2);
   annotateNonNullAndDereferenceable(CI, 0, Size, DL);
-  ConstantInt *CharC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  Value *CharVal = CI->getArgOperand(1);
+  ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
   ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
+  // Set to the Size argument value if it's constant or left at maximum.
+  uint64_t MaxLen = UINT64_MAX;
 
   // memchr(x, y, 0) -> null
   if (LenC) {
     if (LenC->isZero())
       return Constant::getNullValue(CI->getType());
-  } else {
-    // From now on we need at least constant length and string.
-    return nullptr;
+    MaxLen = LenC->getZExtValue();
+  }
+
+  if (MaxLen == 1) {
+    // Fold memchr(x, y, 1) --> *x == y ? x : null for any x and y,
+    // constant or otherwise.
+    Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memchr.char0");
+    // Slice off the character's high end bits.
+    CharVal = B.CreateTrunc(CharVal, B.getInt8Ty());
+    Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memchr.char0cmp");
+    Value *NullPtr = Constant::getNullValue(CI->getType());
+    return B.CreateSelect(Cmp, SrcStr, NullPtr, "memchr.sel");
   }
 
   StringRef Str;
   if (!getConstantStringInfo(SrcStr, Str, 0, /*TrimAtNul=*/false))
     return nullptr;
 
-  // Truncate the string to LenC. If Str is smaller than LenC we will still only
-  // scan the string, as reading past the end of it is undefined and we can just
-  // return null if we don't find the char.
-  Str = Str.substr(0, LenC->getZExtValue());
+  if (CharC) {
+    size_t Pos = Str.find(CharC->getZExtValue());
+    if (Pos == StringRef::npos)
+      // When the character is not in the source array fold the result
+      // to null regardless of Size.
+      return Constant::getNullValue(CI->getType());
+
+    // Fold memchr(s, c, n) -> n <= Pos ? null : s + Pos
+    // When the constant Size is less than or equal to the character
+    // position also fold the result to null.
+    Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos),
+                                 "memchr.cmp");
+    Value *NullPtr = Constant::getNullValue(CI->getType());
+    Value *SrcPlus =
+        B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos), "memchr.ptr");
+    return B.CreateSelect(Cmp, NullPtr, SrcPlus);
+  }
+
+  if (!LenC)
+    // From now on we need a constant length and constant array.
+    return nullptr;
+
+  // Truncate the string to at most the constant LenC which at this point
+  // is greater than 1.
+  Str = Str.substr(0, MaxLen);
+
+  if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
+    return nullptr;
 
   // If the char is variable but the input str and length are not we can turn
   // this memchr call into a simple bit field test. Of course this only works
@@ -915,58 +951,44 @@
   // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n')))
   // != 0
   //   after bounds check.
-  if (!CharC && !Str.empty() && isOnlyUsedInZeroEqualityComparison(CI)) {
-    unsigned char Max =
-        *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()),
-                          reinterpret_cast<const unsigned char *>(Str.end()));
-
-    // Make sure the bit field we're about to create fits in a register on the
-    // target.
-    // FIXME: On a 64 bit architecture this prevents us from using the
-    // interesting range of alpha ascii chars. We could do better by emitting
-    // two bitfields or shifting the range by 64 if no lower chars are used.
-    if (!DL.fitsInLegalInteger(Max + 1))
-      return nullptr;
-
-    // For the bit field use a power-of-2 type with at least 8 bits to avoid
-    // creating unnecessary illegal types.
-    unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max));
-
-    // Now build the bit field.
-    APInt Bitfield(Width, 0);
-    for (char C : Str)
-      Bitfield.setBit((unsigned char)C);
-    Value *BitfieldC = B.getInt(Bitfield);
-
-    // Adjust width of "C" to the bitfield width, then mask off the high bits.
-    Value *C = B.CreateZExtOrTrunc(CI->getArgOperand(1), BitfieldC->getType());
-    C = B.CreateAnd(C, B.getIntN(Width, 0xFF));
+  unsigned char Max =
+      *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()),
+                        reinterpret_cast<const unsigned char *>(Str.end()));
+
+  // Make sure the bit field we're about to create fits in a register on the
+  // target.
+  // FIXME: On a 64 bit architecture this prevents us from using the
+  // interesting range of alpha ascii chars. We could do better by emitting
+  // two bitfields or shifting the range by 64 if no lower chars are used.
+  if (!DL.fitsInLegalInteger(Max + 1))
+    return nullptr;
 
-    // First check that the bit field access is within bounds.
-    Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width),
-                                 "memchr.bounds");
+  // For the bit field use a power-of-2 type with at least 8 bits to avoid
+  // creating unnecessary illegal types.
+  unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max));
 
-    // Create code that checks if the given bit is set in the field.
-    Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C);
-    Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits");
+  // Now build the bit field.
+  APInt Bitfield(Width, 0);
+  for (char C : Str)
+    Bitfield.setBit((unsigned char)C);
+  Value *BitfieldC = B.getInt(Bitfield);
 
-    // Finally merge both checks and cast to pointer type. The inttoptr
-    // implicitly zexts the i1 to intptr type.
-    return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"),
-                            CI->getType());
-  }
+  // Adjust width of "C" to the bitfield width, then mask off the high bits.
+  Value *C = B.CreateZExtOrTrunc(CharVal, BitfieldC->getType());
+  C = B.CreateAnd(C, B.getIntN(Width, 0xFF));
 
-  // Check if all arguments are constants.  If so, we can constant fold.
-  if (!CharC)
-    return nullptr;
+  // First check that the bit field access is within bounds.
+  Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width),
+                               "memchr.bounds");
 
-  // Compute the offset.
-  size_t I = Str.find(CharC->getSExtValue() & 0xFF);
-  if (I == StringRef::npos) // Didn't find the char.  memchr returns null.
-    return Constant::getNullValue(CI->getType());
+  // Create code that checks if the given bit is set in the field.
+  Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C);
+  Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits");
 
-  // memchr(s+n,c,l) -> gep(s+n+i,c)
-  return B.CreateGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "memchr");
+  // Finally merge both checks and cast to pointer type. The inttoptr
+  // implicitly zexts the i1 to intptr type.
+  return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"),
+                          CI->getType());
 }
 
 static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS,
Index: llvm/test/Transforms/InstCombine/memchr-2.ll
===================================================================
--- llvm/test/Transforms/InstCombine/memchr-2.ll
+++ llvm/test/Transforms/InstCombine/memchr-2.ll
@@ -8,14 +8,14 @@
 
 @ax = external global [0 x i8]
 @a12345 = constant [5 x i8] c"\01\02\03\04\05"
+@a123f45 = constant [5 x i8] c"\01\02\03\f4\05"
 
 
 ; Fold memchr(a12345, '\06', n) to null.
 
 define i8* @fold_memchr_a12345_6_n(i64 %n) {
 ; CHECK-LABEL: @fold_memchr_a12345_6_n(
-; CHECK-NEXT:    [[RES:%.*]] = call i8* @memchr(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a12345, i64 0, i64 0), i32 6, i64 [[N:%.*]])
-; CHECK-NEXT:    ret i8* [[RES]]
+; CHECK-NEXT:    ret i8* null
 ;
 
   %ptr = getelementptr [5 x i8], [5 x i8]* @a12345, i32 0, i32 0
@@ -76,12 +76,27 @@
 }
 
 
-; Fold memchr(a12345, '\03', n) to n < 3 ? null : a12345 + 3.
+; Fold memchr(a123f45, 500, 9) to a123f45 + 3 (verify that 500 is
+; truncated to (unsigned char)500 == '\xf4')
 
-define i8* @call_a12345_3_n(i64 %n) {
-; CHECK-LABEL: @call_a12345_3_n(
-; CHECK-NEXT:    [[RES:%.*]] = call i8* @memchr(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a12345, i64 0, i64 0), i32 3, i64 [[N:%.*]])
-; CHECK-NEXT:    ret i8* [[RES]]
+define i8* @fold_memchr_a123f45_500_9() {
+; CHECK-LABEL: @fold_memchr_a123f45_500_9(
+; CHECK-NEXT:    ret i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a123f45, i64 0, i64 3)
+;
+
+  %ptr = getelementptr [5 x i8], [5 x i8]* @a123f45, i32 0, i32 0
+  %res = call i8* @memchr(i8* %ptr, i32 500, i64 9)
+  ret i8* %res
+}
+
+
+; Fold memchr(a12345, '\03', n) to n < 3 ? null : a12345 + 2.
+
+define i8* @fold_a12345_3_n(i64 %n) {
+; CHECK-LABEL: @fold_a12345_3_n(
+; CHECK-NEXT:    [[MEMCHR_CMP:%.*]] = icmp ult i64 [[N:%.*]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[MEMCHR_CMP]], i8* null, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a12345, i64 0, i64 2)
+; CHECK-NEXT:    ret i8* [[TMP1]]
 ;
 
   %ptr = getelementptr [5 x i8], [5 x i8]* @a12345, i32 0, i32 0
@@ -90,13 +105,14 @@
 }
 
 
-; Fold memchr(a12345, 259, n) to n < 4 ? null : a12345 + 3
+; Fold memchr(a12345, 259, n) to n < 3 ? null : a12345 + 2
 ; to verify the constant 259 is converted to unsigned char (yielding 3).
 
-define i8* @call_a12345_259_n(i64 %n) {
-; CHECK-LABEL: @call_a12345_259_n(
-; CHECK-NEXT:    [[RES:%.*]] = call i8* @memchr(i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a12345, i64 0, i64 0), i32 259, i64 [[N:%.*]])
-; CHECK-NEXT:    ret i8* [[RES]]
+define i8* @fold_a12345_259_n(i64 %n) {
+; CHECK-LABEL: @fold_a12345_259_n(
+; CHECK-NEXT:    [[MEMCHR_CMP:%.*]] = icmp ult i64 [[N:%.*]], 3
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[MEMCHR_CMP]], i8* null, i8* getelementptr inbounds ([5 x i8], [5 x i8]* @a12345, i64 0, i64 2)
+; CHECK-NEXT:    ret i8* [[TMP1]]
 ;
 
   %ptr = getelementptr [5 x i8], [5 x i8]* @a12345, i32 0, i32 0
Index: llvm/test/Transforms/InstCombine/memchr-3.ll
===================================================================
--- llvm/test/Transforms/InstCombine/memchr-3.ll
+++ llvm/test/Transforms/InstCombine/memchr-3.ll
@@ -41,8 +41,10 @@
 
 define i8* @fold_memchr_ax_257_1(i32 %chr, i64 %n) {
 ; CHECK-LABEL: @fold_memchr_ax_257_1(
-; CHECK-NEXT:    [[RES:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), i32 257, i64 1)
-; CHECK-NEXT:    ret i8* [[RES]]
+; CHECK-NEXT:    [[MEMCHR_CHAR0:%.*]] = load i8, i8* getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), align 1
+; CHECK-NEXT:    [[MEMCHR_CHAR0CMP:%.*]] = icmp eq i8 [[MEMCHR_CHAR0]], 1
+; CHECK-NEXT:    [[MEMCHR_SEL:%.*]] = select i1 [[MEMCHR_CHAR0CMP]], i8* getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), i8* null
+; CHECK-NEXT:    ret i8* [[MEMCHR_SEL]]
 ;
 
   %ptr = getelementptr [0 x i8], [0 x i8]* @ax, i32 0, i32 0
@@ -55,8 +57,11 @@
 
 define i8* @fold_memchr_ax_c_1(i32 %chr, i64 %n) {
 ; CHECK-LABEL: @fold_memchr_ax_c_1(
-; CHECK-NEXT:    [[RES:%.*]] = call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), i32 [[CHR:%.*]], i64 1)
-; CHECK-NEXT:    ret i8* [[RES]]
+; CHECK-NEXT:    [[MEMCHR_CHAR0:%.*]] = load i8, i8* getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[CHR:%.*]] to i8
+; CHECK-NEXT:    [[MEMCHR_CHAR0CMP:%.*]] = icmp eq i8 [[MEMCHR_CHAR0]], [[TMP1]]
+; CHECK-NEXT:    [[MEMCHR_SEL:%.*]] = select i1 [[MEMCHR_CHAR0CMP]], i8* getelementptr inbounds ([0 x i8], [0 x i8]* @ax, i64 0, i64 0), i8* null
+; CHECK-NEXT:    ret i8* [[MEMCHR_SEL]]
 ;
 
   %ptr = getelementptr [0 x i8], [0 x i8]* @ax, i32 0, i32 0
Index: llvm/test/Transforms/InstCombine/memchr.ll
===================================================================
--- llvm/test/Transforms/InstCombine/memchr.ll
+++ llvm/test/Transforms/InstCombine/memchr.ll
@@ -174,9 +174,9 @@
 
 define i1 @test14(i32 %C) {
 ; CHECK-LABEL: @test14(
-; CHECK-NEXT:    [[TMP1:%.*]] = and i32 [[C:%.*]], 255
-; CHECK-NEXT:    [[MEMCHR_BITS:%.*]] = icmp eq i32 [[TMP1]], 31
-; CHECK-NEXT:    ret i1 [[MEMCHR_BITS]]
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 [[C:%.*]] to i8
+; CHECK-NEXT:    [[MEMCHR_CHAR0CMP:%.*]] = icmp eq i8 [[TMP1]], 31
+; CHECK-NEXT:    ret i1 [[MEMCHR_CHAR0CMP]]
 ;
   %dst = call i8* @memchr(i8* getelementptr inbounds ([2 x i8], [2 x i8]* @single, i64 0, i64 0), i32 %C, i32 1)
   %cmp = icmp ne i8* %dst, null