Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9949,6 +9949,63 @@
   return SDValue();
 }
 
+/// Fold
+///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+///
+/// \param N   the SIGN_EXTEND / ZERO_EXTEND / ANY_EXTEND node being visited.
+/// \param TLI queried for whether the matching extending load is legal.
+/// \param DAG used to create the replacement nodes.
+/// \returns the new select, or an empty SDValue when the pattern does not
+/// match or the extending load is not legal for either arm.
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+                                         SelectionDAG &DAG) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+          Opcode == ISD::ANY_EXTEND) &&
+         "Expected EXTEND dag node in input!");
+
+  // Only fold when this extend is the select's sole user.
+  if (N0->getOpcode() != ISD::SELECT || !N0.hasOneUse())
+    return SDValue();
+
+  // Both arms must be loads with no other users.
+  SDValue Op1 = N0->getOperand(1);
+  SDValue Op2 = N0->getOperand(2);
+  if (!isa<LoadSDNode>(Op1) || !isa<LoadSDNode>(Op2) || !Op1.hasOneUse() ||
+      !Op2.hasOneUse())
+    return SDValue();
+
+  // Pick the extending-load kind that matches the extend being folded.
+  auto ExtLoadOpcode = ISD::EXTLOAD;
+  if (Opcode == ISD::SIGN_EXTEND)
+    ExtLoadOpcode = ISD::SEXTLOAD;
+  else if (Opcode == ISD::ZERO_EXTEND)
+    ExtLoadOpcode = ISD::ZEXTLOAD;
+
+  // Each load has its own memory type, so the extending load must be legal
+  // for both arms, not just the first one.
+  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+    return SDValue();
+
+  // Push the extend onto each arm as a plain extend node. Presumably a later
+  // combine merges each extend/load pair into the extending load named in the
+  // header comment -- this function does not build those loads itself.
+  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
/// This function is called by the DAGCombiner when visiting sext/zext/aext @@ -10733,6 +10774,9 @@ return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT)); } + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11045,6 +11089,9 @@ if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } @@ -11197,6 +11244,9 @@ if (SDValue NewCtPop = widenCtPop(N, DAG)) return NewCtPop; + if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG)) + return Res; + return SDValue(); } Index: llvm/test/CodeGen/X86/select-ext.ll =================================================================== --- llvm/test/CodeGen/X86/select-ext.ll +++ llvm/test/CodeGen/X86/select-ext.ll @@ -1,13 +1,12 @@ ; RUN: llc -mtriple=x86_64-unknown-unknown < %s | FileCheck %s -; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2) +; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2) ; CHECK-LABEL: foo -; CHECK: movzbl (%rdi), %eax -; CHECK-NEXT: movzbl 1(%rdi), %ecx +; CHECK: movzbl (%rdi), %ecx +; CHECK-NEXT: movzbl 1(%rdi), %eax ; CHECK-NEXT: testl %esi, %esi -; CHECK-NEXT: cmovel %eax, %ecx -; CHECK-NEXT: movzbl %cl, %eax +; CHECK-NEXT: cmoveq %rcx, %rax ; CHECK-NEXT: retq define i64 @foo(i8* %p, i1 zeroext %c) {