Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9949,6 +9949,53 @@
   return SDValue();
 }
 
+/// Fold
+/// (sext (select c, load x, load y)) -> (select c, sextload x, sextload y)
+/// (zext (select c, load x, load y)) -> (select c, zextload x, zextload y)
+/// (aext (select c, load x, load y)) -> (select c, extload x, extload y)
+/// This function is called by the DAGCombiner when visiting sext/zext/aext
+/// dag nodes (see for example method DAGCombiner::visitSIGN_EXTEND).
+static SDValue tryToFoldExtendSelectLoad(SDNode *N, const TargetLowering &TLI,
+                                         SelectionDAG &DAG) {
+  unsigned Opcode = N->getOpcode();
+  SDValue N0 = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+
+  assert((Opcode == ISD::SIGN_EXTEND || Opcode == ISD::ZERO_EXTEND ||
+          Opcode == ISD::ANY_EXTEND) &&
+         "Expected EXTEND dag node in input!");
+
+  // Only fold a single-use select so the select is not duplicated.
+  if (N0->getOpcode() != ISD::SELECT || !N0.hasOneUse())
+    return SDValue();
+
+  // Both select operands must be single-use loads for the fold to pay off.
+  SDValue Op1 = N0->getOperand(1);
+  SDValue Op2 = N0->getOperand(2);
+  if (!isa<LoadSDNode>(Op1) || !isa<LoadSDNode>(Op2) || !Op1.hasOneUse() ||
+      !Op2.hasOneUse())
+    return SDValue();
+
+  auto ExtLoadOpcode = ISD::EXTLOAD;
+  if (Opcode == ISD::SIGN_EXTEND)
+    ExtLoadOpcode = ISD::SEXTLOAD;
+  else if (Opcode == ISD::ZERO_EXTEND)
+    ExtLoadOpcode = ISD::ZEXTLOAD;
+
+  // The transform is only profitable if both loads can later be folded into
+  // extending loads; their memory VTs may differ, so check each one.
+  LoadSDNode *Load1 = cast<LoadSDNode>(Op1);
+  LoadSDNode *Load2 = cast<LoadSDNode>(Op2);
+  if (!TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load1->getMemoryVT()) ||
+      !TLI.isLoadExtLegal(ExtLoadOpcode, VT, Load2->getMemoryVT()))
+    return SDValue();
+
+  SDValue Ext1 = DAG.getNode(Opcode, DL, VT, Op1);
+  SDValue Ext2 = DAG.getNode(Opcode, DL, VT, Op2);
+  return DAG.getSelect(DL, VT, N0->getOperand(0), Ext1, Ext2);
+}
+
 /// Try to fold a sext/zext/aext dag node into a ConstantSDNode or
 /// a build_vector of constants.
 /// This function is called by the DAGCombiner when visiting sext/zext/aext
@@ -10733,6 +10780,9 @@
     return DAG.getNode(ISD::ADD, DL, VT, Zext, DAG.getAllOnesConstant(DL, VT));
   }
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11045,6 +11095,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
@@ -11197,6 +11250,9 @@
   if (SDValue NewCtPop = widenCtPop(N, DAG))
     return NewCtPop;
 
+  if (SDValue Res = tryToFoldExtendSelectLoad(N, TLI, DAG))
+    return Res;
+
   return SDValue();
 }
 
Index: llvm/test/CodeGen/X86/select-ext.ll
===================================================================
--- llvm/test/CodeGen/X86/select-ext.ll
+++ llvm/test/CodeGen/X86/select-ext.ll
@@ -1,16 +1,14 @@
-; RUN: llc -mtriple=x86_64-unknown-unknown < %s | FileCheck %s
+; RUN: llc -mtriple=x86_64-unknown-unknown -mattr=+sse4.1 < %s | FileCheck %s
 
-; TODO: (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
-
-; CHECK-LABEL: foo
-; CHECK: movzbl (%rdi), %eax
-; CHECK-NEXT: movzbl 1(%rdi), %ecx
+; (zext(select c, load1, load2)) -> (select c, zextload1, zextload2)
+; CHECK-LABEL: zext_scalar
+; CHECK: movzbl (%rdi), %ecx
+; CHECK-NEXT: movzbl 1(%rdi), %eax
 ; CHECK-NEXT: testl %esi, %esi
-; CHECK-NEXT: cmovel %eax, %ecx
-; CHECK-NEXT: movzbl %cl, %eax
+; CHECK-NEXT: cmoveq %rcx, %rax
 ; CHECK-NEXT: retq
 
-define i64 @foo(i8* %p, i1 zeroext %c) {
+define i64 @zext_scalar(i8* %p, i1 zeroext %c) {
   %ld1 = load volatile i8, i8* %p
   %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
   %ld2 = load volatile i8, i8* %arrayidx1
@@ -19,3 +17,60 @@
   ret i64 %cond
 }
 
+
+; (sext(select c, load1, load2)) -> (select c, sextload1, sextload2)
+; CHECK-LABEL: sext_scalar
+; CHECK: movsbq (%rdi), %rcx
+; CHECK-NEXT: movsbq 1(%rdi), %rax
+; CHECK-NEXT: testl %esi, %esi
+; CHECK-NEXT: cmoveq %rcx, %rax
+; CHECK-NEXT: retq
+
+define i64 @sext_scalar(i8* %p, i1 zeroext %c) {
+  %ld1 = load volatile i8, i8* %p
+  %arrayidx1 = getelementptr inbounds i8, i8* %p, i64 1
+  %ld2 = load volatile i8, i8* %arrayidx1
+  %cond.v = select i1 %c, i8 %ld2, i8 %ld1
+  %cond = sext i8 %cond.v to i64
+  ret i64 %cond
+}
+
+
+; Same as zext_scalar, but operate on vectors.
+; CHECK-LABEL: zext_vector
+; CHECK: pmovzxdq (%rdi), %xmm1
+; CHECK-NEXT: pmovzxdq 8(%rdi), %xmm0
+; CHECK-NEXT: testl %esi, %esi
+; CHECK-NEXT: jne .LBB2_2
+; CHECK: movdqa %xmm1, %xmm0
+; CHECK-NEXT: .LBB2_2:
+; CHECK-NEXT: retq
+
+define <2 x i64> @zext_vector(<2 x i32>* %p, i1 zeroext %c) {
+  %ld1 = load volatile <2 x i32>, <2 x i32>* %p
+  %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
+  %ld2 = load volatile <2 x i32>, <2 x i32>* %arrayidx1
+  %cond.v = select i1 %c, <2 x i32> %ld2, <2 x i32> %ld1
+  %cond = zext <2 x i32> %cond.v to <2 x i64>
+  ret <2 x i64> %cond
+}
+
+
+; Same as sext_scalar, but operate on vectors.
+; CHECK-LABEL: sext_vector
+; CHECK: pmovsxdq (%rdi), %xmm1
+; CHECK-NEXT: pmovsxdq 8(%rdi), %xmm0
+; CHECK-NEXT: testl %esi, %esi
+; CHECK-NEXT: jne .LBB3_2
+; CHECK: movdqa %xmm1, %xmm0
+; CHECK-NEXT: .LBB3_2:
+; CHECK-NEXT: retq
+define <2 x i64> @sext_vector(<2 x i32>* %p, i1 zeroext %c) {
+  %ld1 = load volatile <2 x i32>, <2 x i32>* %p
+  %arrayidx1 = getelementptr inbounds <2 x i32>, <2 x i32>* %p, i64 1
+  %ld2 = load volatile <2 x i32>, <2 x i32>* %arrayidx1
+  %cond.v = select i1 %c, <2 x i32> %ld2, <2 x i32> %ld1
+  %cond = sext <2 x i32> %cond.v to <2 x i64>
+  ret <2 x i64> %cond
+}
+