diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -44249,6 +44249,80 @@
   return DAG.getBitcast(VT, Res);
 }
 
+/// Try to rewrite a constant-masked vXi64 select whose operands are bitcast
+/// from vXi32 vectors into a select on the vXi32 sources:
+///
+///   select (vNi1 bitcast (iN C)), (vNi64 bitcast A), (vNi64 bitcast B)
+///     -->
+///   vNi64 bitcast (select (v2Ni1 bitcast (i2N C')), A, B)
+///
+/// Each bit of C is duplicated into two adjacent bits of C', so both i32
+/// halves of an i64 lane follow that lane's original condition. Selecting on
+/// vXi32 creates opportunities for mask instructions with AVX512.
+///
+/// \returns the replacement value, or an empty SDValue if the pattern does
+/// not match.
+static SDValue combineSelectVxi64(SDNode *N, SelectionDAG &DAG,
+                                  const X86Subtarget &Subtarget) {
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  SDValue Cond = N->getOperand(0);
+  SDValue LHS = N->getOperand(1);
+  SDValue RHS = N->getOperand(2);
+  EVT VT = LHS.getValueType();
+  EVT CondVT = Cond.getValueType();
+
+  // The condition must be a vXi1 mask bitcast from an integer constant.
+  if (!CondVT.isVector() || CondVT.getVectorElementType() != MVT::i1)
+    return SDValue();
+  if (Cond.getOpcode() != ISD::BITCAST)
+    return SDValue();
+  auto *ConstCond = dyn_cast<ConstantSDNode>(Cond.getOperand(0));
+  if (!ConstCond)
+    return SDValue();
+
+  if (!VT.isVector() || VT.getVectorElementType() != MVT::i64)
+    return SDValue();
+
+  // Both select operands must be bitcasts from vectors of i32. Test
+  // isVector() first: getVectorElementType() is only valid on vector types.
+  auto IsBitcastFromVXi32 = [](SDValue V) {
+    if (V.getOpcode() != ISD::BITCAST)
+      return false;
+    EVT SrcVT = V.getOperand(0).getValueType();
+    return SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i32;
+  };
+  if (!IsBitcastFromVXi32(LHS) || !IsBitcastFromVXi32(RHS))
+    return SDValue();
+
+  if (!Cond.hasOneUse() || !LHS.hasOneUse() || !RHS.hasOneUse())
+    return SDValue();
+
+  unsigned NumElts = VT.getVectorNumElements();
+  const APInt &Mask = ConstCond->getAPIntValue();
+  if (Mask.getBitWidth() != NumElts)
+    return SDValue();
+
+  // Widen the mask: bit I of the i64 mask controls i32 lanes 2*I and 2*I+1,
+  // so every set bit becomes a pair of adjacent set bits.
+  APInt WideMask(NumElts * 2, 0);
+  for (unsigned I = 0; I != NumElts; ++I)
+    if (Mask[I])
+      WideMask.setBits(2 * I, 2 * I + 2);
+
+  SDLoc DL(N);
+  LLVMContext &Ctx = *DAG.getContext();
+  EVT ExpandCondVT = EVT::getVectorVT(Ctx, MVT::i1, NumElts * 2);
+  EVT ExpandVT = EVT::getVectorVT(Ctx, MVT::i32, NumElts * 2);
+  SDValue MaskVal =
+      DAG.getConstant(WideMask, DL, EVT::getIntegerVT(Ctx, NumElts * 2));
+  SDValue NewCond = DAG.getBitcast(ExpandCondVT, MaskVal);
+  SDValue NewSelect = DAG.getSelect(DL, ExpandVT, NewCond, LHS.getOperand(0),
+                                    RHS.getOperand(0));
+  return DAG.getBitcast(VT, NewSelect);
+}
+
 /// Do target-specific dag combines on SELECT and VSELECT nodes.
 static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
                              TargetLowering::DAGCombinerInfo &DCI,
@@ -44473,6 +44547,9 @@
       return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
   }
 
+  if (SDValue V = combineSelectVxi64(N, DAG, Subtarget))
+    return V;
+
   // Some mask scalar intrinsics rely on checking if only one bit is set
   // and implement it in C code like this:
   // A[0] = (U & 1) ? A[0] : W[0];