Teach LibCallSimplifier::optimizeMemChr to fold a null-checked memchr over a
constant string into an OR of per-byte equality compares when the bit-field
test cannot be used because the mask would not fit in a legal integer.
Both folds are skipped when the function or block is optimized for size.

Index: llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -999,6 +999,7 @@
 
 Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
   Value *SrcStr = CI->getArgOperand(0);
+  Value *Char = CI->getArgOperand(1);
   Value *Size = CI->getArgOperand(2);
   if (isKnownNonZero(Size, DL))
     annotateNonNullNoUndefBasedOnAccess(CI, 0);
@@ -1089,6 +1090,10 @@
     // From now on we need a constant length and constant array.
     return nullptr;
 
+  bool OptForSize = CI->getFunction()->hasOptSize() ||
+                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI,
+                                                PGSOQueryType::IRPass);
+
   // If the char is variable but the input str and length are not we can turn
   // this memchr call into a simple bit field test. Of course this only works
   // when the return value is only checked against null.
@@ -1099,7 +1104,7 @@
   // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n')))
   // != 0
   //   after bounds check.
-  if (Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
+  if (OptForSize || Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
     return nullptr;
 
   unsigned char Max =
@@ -1111,8 +1116,20 @@
   // FIXME: On a 64 bit architecture this prevents us from using the
   // interesting range of alpha ascii chars. We could do better by emitting
   // two bitfields or shifting the range by 64 if no lower chars are used.
-  if (!DL.fitsInLegalInteger(Max + 1))
-    return nullptr;
+  if (!DL.fitsInLegalInteger(Max + 1)) {
+    // Build chain of logical ORs
+    // Transform:
+    //    memchr("abcd", C, 4) != nullptr
+    // to:
+    //    (C == 'a' || C == 'b' || C == 'c' || C == 'd') != 0
+    SmallVector<Value *> CharCompares;
+    // memchr compares bytes as unsigned char; avoid sign-extending C here.
+    for (unsigned char C : Str)
+      CharCompares.push_back(
+          B.CreateICmpEQ(Char, ConstantInt::get(Char->getType(), C)));
+
+    return B.CreateIntToPtr(B.CreateOr(CharCompares), CI->getType());
+  }
 
   // For the bit field use a power-of-2 type with at least 8 bits to avoid
   // creating unnecessary illegal types.
Index: llvm/test/Transforms/InstCombine/memchr-7.ll
===================================================================
--- llvm/test/Transforms/InstCombine/memchr-7.ll
+++ llvm/test/Transforms/InstCombine/memchr-7.ll
@@ -9,9 +9,11 @@
 
 define zeroext i1 @strchr_to_memchr_n_equals_len(i32 %c) {
 ; CHECK-LABEL: @strchr_to_memchr_n_equals_len(
-; CHECK-NEXT:    [[MEMCHR:%.*]] = tail call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 [[C:%.*]], i64 27)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[MEMCHR]], null
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[C]], -97
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 26
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %call = tail call i8* @strchr(i8* nonnull dereferenceable(27) getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 %c)
   %cmp = icmp ne i8* %call, null
@@ -33,9 +35,9 @@
 
 define zeroext i1 @memchr_n_less_than_len(i32 %c) {
 ; CHECK-LABEL: @memchr_n_less_than_len(
-; CHECK-NEXT:    [[CALL:%.*]] = tail call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 [[C:%.*]], i64 15)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[CALL]], null
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = add i32 [[C:%.*]], -97
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ult i32 [[TMP1]], 15
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %call = tail call i8* @memchr(i8* getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 %c, i64 15)
   %cmp = icmp ne i8* %call, null
@@ -45,9 +47,11 @@
 
 define zeroext i1 @memchr_n_more_than_len(i32 %c) {
 ; CHECK-LABEL: @memchr_n_more_than_len(
-; CHECK-NEXT:    [[CALL:%.*]] = tail call i8* @memchr(i8* noundef nonnull dereferenceable(1) getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 [[C:%.*]], i64 30)
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8* [[CALL]], null
-; CHECK-NEXT:    ret i1 [[CMP]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i32 [[C:%.*]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[C]], -97
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ult i32 [[TMP2]], 26
+; CHECK-NEXT:    [[TMP4:%.*]] = or i1 [[TMP3]], [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP4]]
 ;
   %call = tail call i8* @memchr(i8* getelementptr inbounds ([27 x i8], [27 x i8]* @.str, i64 0, i64 0), i32 %c, i64 30)
   %cmp = icmp ne i8* %call, null