diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44325,6 +44325,65 @@
   return DAG.getBitcast(VT, Res);
 }
 
+// Combine
+//   vXi64 select (vXi1 bitcast (iX C)), (bitcast a), (bitcast b)
+// into
+//   v2Xi32 select (v2Xi1 C'), a, b
+// to create the opportunity for mask instructions with AVX512.
+static SDValue combineSelectVxi64(SDNode *N, SelectionDAG &DAG,
+                                  const X86Subtarget &Subtarget) {
+  SDLoc DL(N);
+  SDValue Cond = N->getOperand(0);
+  SDValue LHS = N->getOperand(1);
+  SDValue RHS = N->getOperand(2);
+  EVT VT = LHS.getValueType();
+  EVT CondVT = Cond.getValueType();
+
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  // Only handle a vXi1 condition that arrives as a bitcast of an integer
+  // immediate, so the widened mask can be computed at compile time.
+  if (!CondVT.isVector() || CondVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+  if (Cond.getOpcode() != ISD::BITCAST ||
+      !isa<ConstantSDNode>(Cond.getOperand(0)))
+    return SDValue();
+
+  if (VT.getVectorElementType() != MVT::i64)
+    return SDValue();
+
+  // Both select arms must be bitcasts from i32 vectors.
+  if (LHS.getOpcode() != ISD::BITCAST ||
+      LHS.getOperand(0).getValueType().getVectorElementType() != MVT::i32)
+    return SDValue();
+  if (RHS.getOpcode() != ISD::BITCAST ||
+      RHS.getOperand(0).getValueType().getVectorElementType() != MVT::i32)
+    return SDValue();
+
+  if (!Cond.hasOneUse() || !LHS.hasOneUse() || !RHS.hasOneUse())
+    return SDValue();
+
+  int NumElts = VT.getVectorNumElements();
+  EVT ExpandCondVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts * 2);
+  EVT ExpandVT = EVT::getVectorVT(*DAG.getContext(), MVT::i32, NumElts * 2);
+
+  // Each i64 lane covers two i32 lanes, so bit I of the old mask must become
+  // bits 2*I and 2*I+1 of the widened mask. Note "(Mask << 1) | Mask" is NOT
+  // equivalent: it smears set bits into neighboring lanes.
+  uint64_t Mask = cast<ConstantSDNode>(Cond.getOperand(0))->getZExtValue();
+  uint64_t WideMask = 0;
+  for (int i = 0; i != NumElts; ++i)
+    if (Mask & (uint64_t(1) << i))
+      WideMask |= uint64_t(3) << (2 * i);
+  SDValue MaskVal = DAG.getConstant(
+      WideMask, DL, EVT::getIntegerVT(*DAG.getContext(), NumElts * 2));
+  SDValue NewCond = DAG.getBitcast(ExpandCondVT, MaskVal);
+  return DAG.getBitcast(VT,
+                        DAG.getSelect(DL, ExpandVT, NewCond, LHS.getOperand(0),
+                                      RHS.getOperand(0)));
+}
+
 /// Do target-specific dag combines on SELECT and VSELECT nodes.
 static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
                              TargetLowering::DAGCombinerInfo &DCI,
@@ -44549,6 +44608,9 @@
     return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
   }
 
+  if (SDValue V = combineSelectVxi64(N, DAG, Subtarget))
+    return V;
+
   // Some mask scalar intrinsics rely on checking if only one bit is set
   // and implement it in C code like this:
   // A[0] = (U & 1) ? A[0] : W[0];
diff --git a/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-blend.ll b/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-blend.ll
--- a/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-blend.ll
+++ b/llvm/test/CodeGen/X86/avx512-shuffles/shuffle-blend.ll
@@ -5,10 +5,9 @@
 ; CHECK-LABEL: shuffle_v8i64:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vpaddd %zmm1, %zmm0, %zmm2
-; CHECK-NEXT:    vpsubd %zmm1, %zmm0, %zmm0
-; CHECK-NEXT:    movb $-86, %al
+; CHECK-NEXT:    movw $-13108, %ax # imm = 0xCCCC
 ; CHECK-NEXT:    kmovd %eax, %k1
-; CHECK-NEXT:    vmovdqa64 %zmm0, %zmm2 {%k1}
+; CHECK-NEXT:    vpsubd %zmm1, %zmm0, %zmm2 {%k1}
 ; CHECK-NEXT:    vmovdqa64 %zmm2, %zmm0
 ; CHECK-NEXT:    retq
 entry: