Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6715,12 +6715,13 @@
 /// The caller already checked that \p I calls the appropriate LibFunc with a
 /// correct prototype.
 bool SelectionDAGBuilder::visitMemCmpCall(const CallInst &I) {
+  const auto &TLI = DAG.getTargetLoweringInfo();
+  const auto &DL = DAG.getDataLayout();
   const Value *LHS = I.getArgOperand(0), *RHS = I.getArgOperand(1);
   const Value *Size = I.getArgOperand(2);
   const ConstantInt *CSize = dyn_cast<ConstantInt>(Size);
   if (CSize && CSize->getZExtValue() == 0) {
-    EVT CallVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(),
-                                                          I.getType(), true);
+    EVT CallVT = TLI.getValueType(DL, I.getType(), true);
     setValue(&I, DAG.getConstant(0, getCurSDLoc(), CallVT));
     return true;
   }
@@ -6735,70 +6736,119 @@
     return true;
   }
 
-  // memcmp(S1,S2,2) != 0 -> (*(short*)LHS != *(short*)RHS) != 0
-  // memcmp(S1,S2,4) != 0 -> (*(int*)LHS != *(int*)RHS) != 0
-  if (!CSize || !isOnlyUsedInZeroEqualityComparison(&I))
+  if (!isOnlyUsedInZeroEqualityComparison(&I))
     return false;
 
-  // If the target has a fast compare for the given size, it will return a
-  // preferred load type for that size. Require that the load VT is legal and
-  // that the target supports unaligned loads of that type. Otherwise, return
-  // INVALID.
-  auto hasFastLoadsAndCompare = [&](unsigned NumBits) {
-    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-    MVT LVT = TLI.hasFastEqualityCompare(NumBits);
-    if (LVT != MVT::INVALID_SIMPLE_VALUE_TYPE) {
-      // TODO: Handle 5 byte compare as 4-byte + 1 byte.
-      // TODO: Handle 8 byte compare on x86-32 as two 32-bit loads.
-      // TODO: Check alignment of src and dest ptrs.
-      unsigned DstAS = LHS->getType()->getPointerAddressSpace();
-      unsigned SrcAS = RHS->getType()->getPointerAddressSpace();
-      if (!TLI.isTypeLegal(LVT) ||
-          !TLI.allowsMisalignedMemoryAccesses(LVT, SrcAS) ||
-          !TLI.allowsMisalignedMemoryAccesses(LVT, DstAS))
-        LVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
-    }
-
-    return LVT;
-  };
+  // We're only interested in the boolean comparison value (equal/not equal).
+
+  // If the size is a compile-time constant, we first try to lower to a single
+  // comparison between two loads:
+  // memcmp(S1,S2,2) != 0 -> (*(short*)LHS != *(short*)RHS) != 0
+  // memcmp(S1,S2,4) != 0 -> (*(int*)LHS != *(int*)RHS) != 0
+  if (CSize) {
+    // If the target has a fast compare for the given size, it will return a
+    // preferred load type for that size. Require that the load VT is legal and
+    // that the target supports unaligned loads of that type. Otherwise, return
+    // INVALID.
+    auto hasFastLoadsAndCompare = [&](unsigned NumBits) {
+      const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+      MVT LVT = TLI.hasFastEqualityCompare(NumBits);
+      if (LVT != MVT::INVALID_SIMPLE_VALUE_TYPE) {
+        // TODO: Handle 5 byte compare as 4-byte + 1 byte.
+        // TODO: Handle 8 byte compare on x86-32 as two 32-bit loads.
+        // TODO: Check alignment of src and dest ptrs.
+        unsigned DstAS = LHS->getType()->getPointerAddressSpace();
+        unsigned SrcAS = RHS->getType()->getPointerAddressSpace();
+        if (!TLI.isTypeLegal(LVT) ||
+            !TLI.allowsMisalignedMemoryAccesses(LVT, SrcAS) ||
+            !TLI.allowsMisalignedMemoryAccesses(LVT, DstAS))
+          LVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+      }
 
-  // This turns into unaligned loads. We only do this if the target natively
-  // supports the MVT we'll be loading or if it is small enough (<= 4) that
-  // we'll only produce a small number of byte loads.
-  MVT LoadVT;
-  unsigned NumBitsToCompare = CSize->getZExtValue() * 8;
-  switch (NumBitsToCompare) {
-  default:
-    return false;
-  case 16:
-    LoadVT = MVT::i16;
-    break;
-  case 32:
-    LoadVT = MVT::i32;
-    break;
-  case 64:
-  case 128:
-  case 256:
-    LoadVT = hasFastLoadsAndCompare(NumBitsToCompare);
-    break;
-  }
+      return LVT;
+    };
 
-  if (LoadVT == MVT::INVALID_SIMPLE_VALUE_TYPE)
-    return false;
+    // This turns into unaligned loads. We only do this if the target natively
+    // supports the MVT we'll be loading or if it is small enough (<= 4) that
+    // we'll only produce a small number of byte loads.
+    MVT LoadVT;
+    unsigned NumBitsToCompare = CSize->getZExtValue() * 8;
+    switch (NumBitsToCompare) {
+    default:
+      LoadVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
+      break;
+    case 16:
+      LoadVT = MVT::i16;
+      break;
+    case 32:
+      LoadVT = MVT::i32;
+      break;
+    case 64:
+    case 128:
+    case 256:
+      LoadVT = hasFastLoadsAndCompare(NumBitsToCompare);
+      break;
+    }
+
+    if (LoadVT != MVT::INVALID_SIMPLE_VALUE_TYPE) {
+      SDValue LoadL = getMemCmpLoad(LHS, LoadVT, *this);
+      SDValue LoadR = getMemCmpLoad(RHS, LoadVT, *this);
 
-  SDValue LoadL = getMemCmpLoad(LHS, LoadVT, *this);
-  SDValue LoadR = getMemCmpLoad(RHS, LoadVT, *this);
+      // Bitcast to a wide integer type if the loads are vectors.
+      if (LoadVT.isVector()) {
+        EVT CmpVT =
+            EVT::getIntegerVT(LHS->getContext(), LoadVT.getSizeInBits());
+        LoadL = DAG.getBitcast(CmpVT, LoadL);
+        LoadR = DAG.getBitcast(CmpVT, LoadR);
+      }
 
-  // Bitcast to a wide integer type if the loads are vectors.
-  if (LoadVT.isVector()) {
-    EVT CmpVT = EVT::getIntegerVT(LHS->getContext(), LoadVT.getSizeInBits());
-    LoadL = DAG.getBitcast(CmpVT, LoadL);
-    LoadR = DAG.getBitcast(CmpVT, LoadR);
+      SDValue Cmp =
+          DAG.getSetCC(getCurSDLoc(), MVT::i1, LoadL, LoadR, ISD::SETNE);
+      processIntegerCallValue(I, Cmp, false);
+      return true;
+    }
   }
 
-  SDValue Cmp = DAG.getSetCC(getCurSDLoc(), MVT::i1, LoadL, LoadR, ISD::SETNE);
-  processIntegerCallValue(I, Cmp, false);
-  return true;
+  // The size is not constant or it's not efficient to use the strategy above.
+  // If the module provided a `memeq` library function, call it.
+  if (const MDString *MemeqLibFunction = dyn_cast_or_null<MDString>(
+          DAG.getMachineFunction().getFunction().getParent()->getModuleFlag(
+              "memeq_lib_function"))) {
+    TargetLowering::ArgListTy Args;
+    TargetLowering::ArgListEntry Entry;
+    // signature: bool(const char*, const char*, size_t)
+    Entry.Ty = DL.getIntPtrType(*DAG.getContext());
+    Entry.Node = getValue(LHS);
+    Args.push_back(Entry);
+    Entry.Node = getValue(RHS);
+    Args.push_back(Entry);
+    Entry.Node = getValue(Size);
+    Args.push_back(Entry);
+    TargetLowering::CallLoweringInfo CLI(DAG);
+    CLI.setDebugLoc(getCurSDLoc())
+        .setChain(DAG.getRoot())
+        .setLibCallee(
+            TLI.getLibcallCallingConv(RTLIB::MEMCPY),
+            Type::getInt1Ty(*DAG.getContext()),
+            DAG.getExternalSymbol(MemeqLibFunction->getString().data(),
+                                  TLI.getPointerTy(DL)),
+            std::move(Args))
+        .setDiscardResult(false);
+
+    std::pair<SDValue, SDValue> CallResult = TLI.LowerCallTo(CLI);
+    processIntegerCallValue(
+        I,
+        // We have now turned `memcmp() != 0` into `memeq() != 0`; we need to
+        // add a not to get `(!memeq()) != 0`, i.e. `!memeq()`.
+        DAG.getLogicalNOT(getCurSDLoc(), CallResult.first,
                          EVT::getIntegerVT(*DAG.getContext(), 1)),
+        false);
+    PendingLoads.push_back(CallResult.second);
+    return true;
+  }
+
+  // Nothing better, just call memcmp().
+  return false;
 }
 
 /// See if we can lower a memchr call into an optimized form. If so, return
Index: test/CodeGen/X86/memcmp-memeq.ll
===================================================================
--- /dev/null
+++ test/CodeGen/X86/memcmp-memeq.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown | FileCheck %s --check-prefix=MEMEQ
+
+; This tests optimization of `memcmp() == 0` with a user-provided `memeq()`
+; function.
+
+!llvm.module.flags = !{!0}
+!0 = !{i32 1, !"memeq_lib_function", !"user_memeq"}
+
+declare i32 @memcmp(i8*, i8*, i64)
+
+define i1 @huge_length_eq(i8* %X, i8* %Y) nounwind {
+; MEMEQ-LABEL: huge_length_eq:
+; MEMEQ:       # %bb.0:
+; MEMEQ-NEXT:    pushq %rax
+; MEMEQ-NEXT:    movabsq $9223372036854775807, %rdx # imm = 0x7FFFFFFFFFFFFFFF
+; MEMEQ-NEXT:    callq user_memeq
+; MEMEQ-NEXT:    andb $1, %al
+; MEMEQ-NEXT:    popq %rcx
+; MEMEQ-NEXT:    retq
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 9223372036854775807) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+define i1 @nonconst_length_eq(i8* %X, i8* %Y, i64 %size) nounwind {
+; MEMEQ-LABEL: nonconst_length_eq:
+; MEMEQ:       # %bb.0:
+; MEMEQ-NEXT:    pushq %rax
+; MEMEQ-NEXT:    callq user_memeq
+; MEMEQ-NEXT:    andb $1, %al
+; MEMEQ-NEXT:    popq %rcx
+; MEMEQ-NEXT:    retq
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 %size) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+; Check that we do not optimize the inline case.
+define i1 @length8_eq(i8* %X, i8* %Y) nounwind {
+; MEMEQ-LABEL: length8_eq:
+; MEMEQ:       # %bb.0:
+; MEMEQ-NEXT:    movq (%rdi), %rax
+; MEMEQ-NEXT:    cmpq (%rsi), %rax
+; MEMEQ-NEXT:    sete %al
+; MEMEQ-NEXT:    retq
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 8) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+; Check that we do not optimize the non-equality case.
+
+define i32 @huge_length(i8* %X, i8* %Y) nounwind {
+; MEMEQ-LABEL: huge_length:
+; MEMEQ:       # %bb.0:
+; MEMEQ-NEXT:    movabsq $9223372036854775807, %rdx # imm = 0x7FFFFFFFFFFFFFFF
+; MEMEQ-NEXT:    jmp memcmp # TAILCALL
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 9223372036854775807) nounwind
+  ret i32 %m
+}
+
+define i32 @nonconst_length(i8* %X, i8* %Y, i64 %size) nounwind {
+; MEMEQ-LABEL: nonconst_length:
+; MEMEQ:       # %bb.0:
+; MEMEQ-NEXT:    jmp memcmp # TAILCALL
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 %size) nounwind
+  ret i32 %m
+}
+
Index: test/CodeGen/X86/memcmp.ll
===================================================================
--- test/CodeGen/X86/memcmp.ll
+++ test/CodeGen/X86/memcmp.ll
@@ -1344,3 +1344,70 @@
   %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 9223372036854775807) nounwind
   ret i32 %m
 }
+
+define i1 @huge_length_eq(i8* %X, i8* %Y) nounwind {
+; X86-LABEL: huge_length_eq:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl $2147483647 # imm = 0x7FFFFFFF
+; X86-NEXT:    pushl $-1
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    calll memcmp
+; X86-NEXT:    addl $16, %esp
+; X86-NEXT:    testl %eax, %eax
+; X86-NEXT:    sete %al
+; X86-NEXT:    retl
+;
+; X64-LABEL: huge_length_eq:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    movabsq $9223372036854775807, %rdx # imm = 0x7FFFFFFFFFFFFFFF
+; X64-NEXT:    callq memcmp
+; X64-NEXT:    testl %eax, %eax
+; X64-NEXT:    sete %al
+; X64-NEXT:    popq %rcx
+; X64-NEXT:    retq
+
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 9223372036854775807) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
+
+; This checks non-constant sizes.
+define i32 @nonconst_length(i8* %X, i8* %Y, i64 %size) nounwind {
+; X86-LABEL: nonconst_length:
+; X86:       # %bb.0:
+; X86-NEXT:    jmp memcmp # TAILCALL
+;
+; X64-LABEL: nonconst_length:
+; X64:       # %bb.0:
+; X64-NEXT:    jmp memcmp # TAILCALL
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 %size) nounwind
+  ret i32 %m
+}
+
+define i1 @nonconst_length_eq(i8* %X, i8* %Y, i64 %size) nounwind {
+; X86-LABEL: nonconst_length_eq:
+; X86:       # %bb.0:
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    pushl {{[0-9]+}}(%esp)
+; X86-NEXT:    calll memcmp
+; X86-NEXT:    addl $16, %esp
+; X86-NEXT:    testl %eax, %eax
+; X86-NEXT:    sete %al
+; X86-NEXT:    retl
+;
+; X64-LABEL: nonconst_length_eq:
+; X64:       # %bb.0:
+; X64-NEXT:    pushq %rax
+; X64-NEXT:    callq memcmp
+; X64-NEXT:    testl %eax, %eax
+; X64-NEXT:    sete %al
+; X64-NEXT:    popq %rcx
+; X64-NEXT:    retq
+  %m = tail call i32 @memcmp(i8* %X, i8* %Y, i64 %size) nounwind
+  %c = icmp eq i32 %m, 0
+  ret i1 %c
+}
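Usage sketch (illustrative, not part of the patch): a module opts in by naming its equality helper in the `memeq_lib_function` module flag, exactly as the new test does with `!{i32 1, !"memeq_lib_function", !"user_memeq"}`. The lowering assumes the named symbol follows the bool(const char*, const char*, size_t) signature noted in the code comment; a hypothetical definition (the name user_memeq simply mirrors the test, and plain memcmp stands in for whatever specialized comparison a runtime would really provide) could look like:

  #include <cstddef>
  #include <cstring>

  // Must return true exactly when the two buffers are byte-wise equal.
  extern "C" bool user_memeq(const char *a, const char *b, std::size_t n) {
    return std::memcmp(a, b, n) == 0;
  }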