Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9945,6 +9945,75 @@
   return SDValue();
 }
 
+/// Return true if \p N is a single-use load that an extension of kind
+/// \p ExtOpcode may later be merged into. Non-extending and any-extending
+/// loads always qualify; an existing sextload/zextload only qualifies when
+/// \p ExtOpcode extends in the same way.
+static bool isCompatibleLoad(SDValue N, unsigned ExtOpcode) {
+  if (!N.hasOneUse())
+    return false;
+
+  auto *Load = dyn_cast<LoadSDNode>(N);
+  if (!Load)
+    return false;
+
+  ISD::LoadExtType LoadExt = Load->getExtensionType();
+  if (LoadExt == ISD::NON_EXTLOAD || LoadExt == ISD::EXTLOAD)
+    return true;
+
+  // Conservatively reject mixing extension kinds, e.g. (zext (sextload x)):
+  // distributing the outer extension over the select would not let such an
+  // arm collapse into a single extending load.
+  return (LoadExt == ISD::SEXTLOAD && ExtOpcode == ISD::SIGN_EXTEND) ||
+         (LoadExt == ISD::ZEXTLOAD && ExtOpcode == ISD::ZERO_EXTEND);
+}
+
+/// Fold
+///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// The fold only fires when the select and both loads have a single use and
+/// the target reports the extending load as legal for the memory type of both.
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+                                         SelectionDAG &DAG) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+          Opcode == ISD::ANY_EXTEND) &&
+         "Expected EXTEND dag node in input!");
+
+  if (N0->getOpcode() != ISD::SELECT || !N0.hasOneUse())
+    return SDValue();
+
+  SDValue Op1 = N0->getOperand(1);
+  SDValue Op2 = N0->getOperand(2);
+  if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
+    return SDValue();
+
+  ISD::LoadExtType ExtLoadOpcode = ISD::EXTLOAD;
+  if (Opcode == ISD::SIGN_EXTEND)
+    ExtLoadOpcode = ISD::SEXTLOAD;
+  else if (Opcode == ISD::ZERO_EXTEND)
+    ExtLoadOpcode = ISD::ZEXTLOAD;
+
+  // The two arms may load different memory types (e.g. a plain load and an
+  // extload producing the same value type), so legality must hold for both.
+  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+    return SDValue();
+
+  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext @@ -10729,6 +10769,9 @@ return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT)); } + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11041,6 +11084,9 @@ if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11193,6 +11239,9 @@ if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } Index: llvm/test/CodeGen/X86/select-ext.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/X86/select-ext.ll @@ -0,0 +1,20 @@ +; RUN: llc -mtriple=x86_64-unknown-unknown < %s | FileCheck %s + +; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2) + +; CHECK-LABEL: foo +; CHECK: movzbl (%rdi), %ecx +; CHECK-NEXT: movzbl 1(%rdi), %eax +; CHECK-NEXT: testl %esi, %esi +; CHECK-NEXT: cmoveq %rcx, %rax +; CHECK-NEXT: retq + +define i64 @foo(i8* %p, i1 zeroext %c) { + %ld1 = load volatile i8, i8* %p + %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1 + %ld2 = load volatile i8, i8* %arrayidx1 + %cond.v = select i1 %c, i8 %ld2, i8 %ld1 + %cond = zext i8 %cond.v to i64 + ret i64 %cond +} +