Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9949,6 +9949,70 @@
   return SDValue();
 }
 
+/// Return true if \p N is a single-use load that can be merged with an
+/// extend of opcode \p ExtOpc into one extending load: plain loads and
+/// any-extending loads are always compatible, whereas a sextload/zextload
+/// only combines with an extend of the same kind.
+static bool isCompatibleLoad(SDValue N, unsigned ExtOpc) {
+  if (!N.hasOneUse() || !isa<LoadSDNode>(N))
+    return false;
+
+  ISD::LoadExtType LoadExt = cast<LoadSDNode>(N)->getExtensionType();
+  if (LoadExt == ISD::SEXTLOAD)
+    return ExtOpc == ISD::SIGN_EXTEND;
+  if (LoadExt == ISD::ZEXTLOAD)
+    return ExtOpc == ISD::ZERO_EXTEND;
+  return true; // NON_EXTLOAD or EXTLOAD.
+}
+
+/// Fold
+///   (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+///   (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+///   (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+/// Only (ext load) nodes are built here; turning them into extending loads
+/// is left to the ext-of-load combines, so the fold is only attempted when
+/// both loads are compatible with the extend and the extending load is
+/// legal for each of their memory types.
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+                                         SelectionDAG &DAG) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+          Opcode == ISD::ANY_EXTEND) &&
+         "Expected EXTEND dag node in input!");
+
+  if (N0->getOpcode() != ISD::SELECT || !N0.hasOneUse())
+    return SDValue();
+
+  SDValue Op1 = N0->getOperand(1);
+  SDValue Op2 = N0->getOperand(2);
+  if (!isCompatibleLoad(Op1, Opcode) || !isCompatibleLoad(Op2, Opcode))
+    return SDValue();
+
+  auto ExtLoadOpcode = ISD::EXTLOAD;
+  if (Opcode == ISD::SIGN_EXTEND)
+    ExtLoadOpcode = ISD::SEXTLOAD;
+  else if (Opcode == ISD::ZERO_EXTEND)
+    ExtLoadOpcode = ISD::ZEXTLOAD;
+
+  // The loads may have different memory types (e.g. when one of them is
+  // already an extending load), so the legality check covers both.
+  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+    return SDValue();
+
+  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
 /// This function is called by the DAGCombiner when visiting sext/zext/aext
@@ -10733,6 +10797,9 @@
     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
   }
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11045,6 +11112,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11197,6 +11267,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
Index: llvm/test/CodeGen/X86/select-ext.ll
===================================================================
--- llvm/test/CodeGen/X86/select-ext.ll
+++ llvm/test/CodeGen/X86/select-ext.ll
@@ -5,11 +5,10 @@
 define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
 ; CHECK-LABEL: zext_scalar:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movzbl (%rdi), %eax
-; CHECK-NEXT:    movzbl 1(%rdi), %ecx
+; CHECK-NEXT:    movzbl (%rdi), %ecx
+; CHECK-NEXT:    movzbl 1(%rdi), %eax
 ; CHECK-NEXT:    testl %esi, %esi
-; CHECK-NEXT:    cmovel %eax, %ecx
-; CHECK-NEXT:    movzbl %cl, %eax
+; CHECK-NEXT:    cmoveq %rcx, %rax
 ; CHECK-NEXT:    retq
   %ld1 = load volatile i8, i8* %p
   %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -23,11 +22,10 @@
 define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
 ; CHECK-LABEL: sext_scalar:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movzbl (%rdi), %eax
-; CHECK-NEXT:    movzbl 1(%rdi), %ecx
+; CHECK-NEXT:    movsbq (%rdi), %rcx
+; CHECK-NEXT:    movsbq 1(%rdi), %rax
 ; CHECK-NEXT:    testl %esi, %esi
-; CHECK-NEXT:    cmovel %eax, %ecx
-; CHECK-NEXT:    movsbq %cl, %rax
+; CHECK-NEXT:    cmoveq %rcx, %rax
 ; CHECK-NEXT:    retq
   %ld1 = load volatile i8, i8* %p
   %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
@@ -41,14 +39,13 @@
 define <2 x i64> @zext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
 ; CHECK-LABEL: zext_vector_i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm1 = mem[0],zero,mem[1],zero
+; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm0 = mem[0],zero,mem[1],zero
 ; CHECK-NEXT:    testl %esi, %esi
 ; CHECK-NEXT:    jne .LBB2_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movdqa %xmm1, %xmm0
 ; CHECK-NEXT:  .LBB2_2:
-; CHECK-NEXT:    pmovzxdq {{.*#+}} xmm0 = xmm0[0],zero,xmm0[1],zero
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
@@ -80,14 +77,13 @@
 define <2 x i64> @sext_vector_i1(<2 x i32>* %p, i1 zeroext %c) {
 ; CHECK-LABEL: sext_vector_i1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    movq {{.*#+}} xmm1 = mem[0],zero
-; CHECK-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; CHECK-NEXT:    pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT:    pmovsxdq 8(%rdi), %xmm0
 ; CHECK-NEXT:    testl %esi, %esi
 ; CHECK-NEXT:    jne .LBB4_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    movdqa %xmm1, %xmm0
 ; CHECK-NEXT:  .LBB4_2:
-; CHECK-NEXT:    pmovsxdq %xmm0, %xmm0
 ; CHECK-NEXT:    retq
   %ld1 = load volatile <2 x i32>, <2 x i32>* %p
   %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1