Index: include/llvm/CodeGen/ISDOpcodes.h
===================================================================
--- include/llvm/CodeGen/ISDOpcodes.h
+++ include/llvm/CodeGen/ISDOpcodes.h
@@ -687,9 +687,16 @@
     ATOMIC_LOAD_UMIN,
     ATOMIC_LOAD_UMAX,
 
-    // Masked load and store
+    // Masked load and store - consecutive vector load and store operations
+    // with additional mask operand that prevents memory accesses to the
+    // masked-off lanes.
     MLOAD, MSTORE,
 
+    // Masked gather and scatter - load and store operations for a vector of
+    // random addresses with additional mask operand that prevents memory
+    // accesses to the masked-off lanes.
+    MGATHER, MSCATTER,
+
     /// This corresponds to the llvm.lifetime.* intrinsics. The first operand
     /// is the chain and the second operand is the alloca pointer.
     LIFETIME_START, LIFETIME_END,
Index: include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- include/llvm/CodeGen/SelectionDAG.h
+++ include/llvm/CodeGen/SelectionDAG.h
@@ -872,6 +872,10 @@
   SDValue getMaskedStore(SDValue Chain, SDLoc dl, SDValue Val,
                          SDValue Ptr, SDValue Mask, EVT MemVT,
                          MachineMemOperand *MMO, bool IsTrunc);
+  SDValue getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
+                          ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
+  SDValue getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
+                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO);
   /// getSrcValue - Construct a node to track a Value* through the backend.
   SDValue getSrcValue(const Value *v);
 
Index: include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- include/llvm/CodeGen/SelectionDAGNodes.h
+++ include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1147,6 +1147,8 @@
            N->getOpcode() == ISD::ATOMIC_STORE    ||
            N->getOpcode() == ISD::MLOAD           ||
            N->getOpcode() == ISD::MSTORE          ||
+           N->getOpcode() == ISD::MGATHER         ||
+           N->getOpcode() == ISD::MSCATTER        ||
            N->isMemIntrinsic()                    ||
            N->isTargetMemoryOpcode();
   }
@@ -1983,6 +1985,82 @@
   }
 };
 
+/// This is a base class used to represent
+/// MGATHER and MSCATTER nodes
+///
+class MaskedGatherScatterSDNode : public MemSDNode {
+  // Operands
+  SDUse Ops[5];
+public:
+  friend class SelectionDAG;
+  MaskedGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, DebugLoc dl,
+                            ArrayRef<SDValue> Operands, SDVTList VTs, EVT MemVT,
+                            MachineMemOperand *MMO)
+    : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+    assert(Operands.size() == 5 && "Incompatible number of operands");
+    InitOperands(Ops, Operands.data(), Operands.size());
+  }
+
+  // In both nodes, mask is Op2, base pointer is Op3 and index is Op4:
+  // MaskedGatherSDNode  (Chain, src0, mask, base, index), src0 is a passthru value
+  // MaskedScatterSDNode (Chain, value, mask, base, index)
+  // Mask is a vector of i1 elements
+  const SDValue &getBasePtr() const { return getOperand(3); }
+  const SDValue &getIndex()   const { return getOperand(4); }
+  const SDValue &getMask()    const { return getOperand(2); }
+  const SDValue &getValue()   const { return getOperand(1); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER ||
+           N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
+/// This class is used to represent an MGATHER node
+///
+class MaskedGatherSDNode : public MaskedGatherScatterSDNode {
+public:
+  friend class SelectionDAG;
+  MaskedGatherSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
+                     SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getValue().getValueType() == getValueType(0) &&
+           "Incompatible type of the PassThru value in MaskedGatherSDNode");
+    assert(getMask().getValueType().getVectorNumElements() ==
+           getValueType(0).getVectorNumElements() &&
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 &&
+           "Mask must be a vector of i1 elements");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MGATHER;
+  }
+};
+
+/// This class is used to represent an MSCATTER node
+///
+class MaskedScatterSDNode : public MaskedGatherScatterSDNode {
+
+public:
+  friend class SelectionDAG;
+  MaskedScatterSDNode(unsigned Order, DebugLoc dl, ArrayRef<SDValue> Operands,
+                      SDVTList VTs, EVT MemVT, MachineMemOperand *MMO)
+    : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, Operands, VTs, MemVT,
+                                MMO) {
+    assert(getMask().getValueType().getVectorNumElements() ==
+           getValue().getValueType().getVectorNumElements() &&
+           "Vector width mismatch between mask and data");
+    assert(getMask().getValueType().getScalarType() == MVT::i1 &&
+           "Mask must be a vector of i1 elements");
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::MSCATTER;
+  }
+};
+
 /// An SDNode that represents everything that will be needed
 /// to construct a MachineInstr. These nodes are created during the
 /// instruction selection proper phase.
@@ -2074,7 +2152,7 @@
 };
 
 /// The largest SDNode class.
-typedef AtomicSDNode LargestSDNode;
+typedef MaskedGatherScatterSDNode LargestSDNode;
 
 /// The SDNode class with the greatest alignment requirement.
 typedef GlobalAddressSDNode MostAlignedSDNode;
Index: include/llvm/Target/TargetSelectionDAG.td
===================================================================
--- include/llvm/Target/TargetSelectionDAG.td
+++ include/llvm/Target/TargetSelectionDAG.td
@@ -208,6 +208,16 @@
   SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>
 ]>;
 
+def SDTMaskedGather: SDTypeProfile<2, 3, [       // masked gather
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<1, 3>,
+  SDTCisPtrTy<4>, SDTCVecEltisVT<1, i1>, SDTCisSameNumEltsAs<0, 1>
+]>;
+
+def SDTMaskedScatter: SDTypeProfile<1, 3, [       // masked scatter
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<0, 2>, SDTCisSameNumEltsAs<0, 1>,
+  SDTCVecEltisVT<0, i1>, SDTCisPtrTy<3>
+]>;
+
 def SDTVecShuffle : SDTypeProfile<1, 2, [
   SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2>
 ]>;
@@ -480,6 +490,10 @@
                         [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
 def masked_load  : SDNode<"ISD::MLOAD",  SDTMaskedLoad,
                         [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def masked_scatter : SDNode<"ISD::MSCATTER",  SDTMaskedScatter,
+                       [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+def masked_gather  : SDNode<"ISD::MGATHER",  SDTMaskedGather,
+                       [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
 // Do not use ld, st directly. Use load, extload, sextload, zextload, store,
 // and truncst (see below).
Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -308,6 +308,8 @@
     SDValue visitINSERT_SUBVECTOR(SDNode *N);
     SDValue visitMLOAD(SDNode *N);
     SDValue visitMSTORE(SDNode *N);
+    SDValue visitMGATHER(SDNode *N);
+    SDValue visitMSCATTER(SDNode *N);
 
     SDValue XformToShuffleWithZero(SDNode *N);
     SDValue ReassociateOps(unsigned Opc, SDLoc DL, SDValue LHS, SDValue RHS);
@@ -1374,7 +1376,9 @@
   case ISD::VECTOR_SHUFFLE:     return visitVECTOR_SHUFFLE(N);
   case ISD::SCALAR_TO_VECTOR:   return visitSCALAR_TO_VECTOR(N);
   case ISD::INSERT_SUBVECTOR:   return visitINSERT_SUBVECTOR(N);
+  case ISD::MGATHER:            return visitMGATHER(N);
   case ISD::MLOAD:              return visitMLOAD(N);
+  case ISD::MSCATTER:           return visitMSCATTER(N);
   case ISD::MSTORE:             return visitMSTORE(N);
   }
   return SDValue();
@@ -4989,6 +4993,67 @@
       TopHalf->isNullValue() ? RHS->getOperand(1) : LHS->getOperand(1));
 }
 
+SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
+
+  if (Level >= AfterLegalizeTypes)
+    return SDValue();
+
+  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
+  SDValue Mask = MSC->getMask();
+  SDValue Data = MSC->getValue();
+  SDLoc DL(N);
+
+  // If the MSCATTER data type requires splitting and the mask is provided by a
+  // SETCC, then split both nodes and its operands before legalization. This
+  // prevents the type legalizer from unrolling SETCC into scalar comparisons
+  // and enables future optimizations (e.g. min/max pattern matching on X86).
+  if (Mask.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  // Check if any splitting is required.
+  if (TLI.getTypeAction(*DAG.getContext(), Data.getValueType()) !=
+      TargetLowering::TypeSplitVector)
+    return SDValue();
+  SDValue MaskLo, MaskHi, Lo, Hi;
+  std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
+
+  // MSCATTER has no vector result (only a chain), so unlike MGATHER there is
+  // no result type to split; only the memory type and the operands are split.
+
+  SDValue Chain = MSC->getChain();
+
+  EVT MemoryVT = MSC->getMemoryVT();
+  unsigned Alignment = MSC->getOriginalAlignment();
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue DataLo, DataHi;
+  std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
+
+  SDValue BasePtr = MSC->getBasePtr();
+  SDValue IndexLo, IndexHi;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(MSC->getIndex(), DL);
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MSC->getPointerInfo(),
+                          MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
+                          Alignment, MSC->getAAInfo(), MSC->getRanges());
+
+  SDValue OpsLo[] = { Chain, DataLo, MaskLo, BasePtr, IndexLo };
+  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
+                            DL, OpsLo, MMO);
+
+  SDValue OpsHi[] = {Chain, DataHi, MaskHi, BasePtr, IndexHi};
+  Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
+                            DL, OpsHi, MMO);
+
+  AddToWorklist(Lo.getNode());
+  AddToWorklist(Hi.getNode());
+
+  return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi);
+}
+
 SDValue DAGCombiner::visitMSTORE(SDNode *N) {
 
   if (Level >= AfterLegalizeTypes)
     return SDValue();
@@ -5063,6 +5128,83 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::visitMGATHER(SDNode *N) {
+
+  if (Level >= AfterLegalizeTypes)
+    return SDValue();
+
+  MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
+  SDValue Mask = MGT->getMask();
+  SDLoc DL(N);
+
+  // If the MGATHER result requires splitting and the mask is provided by a
+  // SETCC, then split both nodes and its operands before legalization. This
+  // prevents the type legalizer from unrolling SETCC into scalar comparisons
+  // and enables future optimizations (e.g. min/max pattern matching on X86).
+
+  if (Mask.getOpcode() != ISD::SETCC)
+    return SDValue();
+
+  EVT VT = N->getValueType(0);
+
+  // Check if any splitting is required.
+  if (TLI.getTypeAction(*DAG.getContext(), VT) !=
+      TargetLowering::TypeSplitVector)
+    return SDValue();
+
+  SDValue MaskLo, MaskHi, Lo, Hi;
+  std::tie(MaskLo, MaskHi) = SplitVSETCC(Mask.getNode(), DAG);
+
+  SDValue Src0 = MGT->getValue();
+  SDValue Src0Lo, Src0Hi;
+  std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(Src0, DL);
+
+  EVT LoVT, HiVT;
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(VT);
+
+  SDValue Chain = MGT->getChain();
+  EVT MemoryVT = MGT->getMemoryVT();
+  unsigned Alignment = MGT->getOriginalAlignment();
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue BasePtr = MGT->getBasePtr();
+  SDValue Index = MGT->getIndex();
+  SDValue IndexLo, IndexHi;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(Index, DL);
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MGT->getPointerInfo(),
+                          MachineMemOperand::MOLoad,  LoMemVT.getStoreSize(),
+                          Alignment, MGT->getAAInfo(), MGT->getRanges());
+
+  SDValue OpsLo[] = { Chain, Src0Lo, MaskLo, BasePtr, IndexLo };
+  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, DL, OpsLo,
+                            MMO);
+
+  SDValue OpsHi[] = {Chain, Src0Hi, MaskHi, BasePtr, IndexHi};
+  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, DL, OpsHi,
+                            MMO);
+
+  AddToWorklist(Lo.getNode());
+  AddToWorklist(Hi.getNode());
+
+  // Build a factor node to remember that this load is independent of the
+  // other one.
+  Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo.getValue(1),
+                      Hi.getValue(1));
+
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(MGT, 1), Chain);
+
+  SDValue GatherRes = DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+
+  SDValue RetOps[] = { GatherRes, Chain };
+  return DAG.getMergeValues(RetOps, DL);
+}
+
 SDValue DAGCombiner::visitMLOAD(SDNode *N) {
 
   if (Level >= AfterLegalizeTypes)
Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -582,6 +582,7 @@
   void SplitVecRes_INSERT_VECTOR_ELT(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_LOAD(LoadSDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_MLOAD(MaskedLoadSDNode *N, SDValue &Lo, SDValue &Hi);
+  void SplitVecRes_MGATHER(MaskedGatherSDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_SCALAR_TO_VECTOR(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_SIGN_EXTEND_INREG(SDNode *N, SDValue &Lo, SDValue &Hi);
   void SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi);
@@ -599,6 +600,8 @@
   SDValue SplitVecOp_EXTRACT_VECTOR_ELT(SDNode *N);
   SDValue SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo);
   SDValue SplitVecOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo);
   SDValue SplitVecOp_CONCAT_VECTORS(SDNode *N);
   SDValue SplitVecOp_TRUNCATE(SDNode *N);
   SDValue SplitVecOp_VSETCC(SDNode *N);
Index: lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -196,6 +196,7 @@
 
   SDValue Result = SDValue(DAG.UpdateNodeOperands(Op.getNode(), Ops), 0);
 
+  bool HasVectorValue = false;
   if (Op.getOpcode() == ISD::LOAD) {
     LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
     ISD::LoadExtType ExtType = LD->getExtensionType();
@@ -243,9 +244,11 @@
       Changed = true;
       return LegalizeOp(ExpandStore(Op));
     }
-  }
+  } else if (Op.getOpcode() == ISD::MSCATTER)
+    // MSCATTER may produce only a chain, yet its operands are vectors; flag it
+    // so it still goes through the operation legalization below.
+    HasVectorValue = true;
 
-  bool HasVectorValue = false;
   for (SDNode::value_iterator J = Node->value_begin(), E = Node->value_end();
        J != E;
        ++J)
@@ -330,6 +333,9 @@
   case ISD::UINT_TO_FP:
     QueryType = Node->getOperand(0).getValueType();
     break;
+  case ISD::MSCATTER:
+    QueryType = cast<MaskedScatterSDNode>(Node)->getValue().getValueType();
+    break;
   }
 
   switch (TLI.getOperationAction(Node->getOpcode(), QueryType)) {
Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -600,6 +600,9 @@
   case ISD::MLOAD:
     SplitVecRes_MLOAD(cast<MaskedLoadSDNode>(N), Lo, Hi);
     break;
+  case ISD::MGATHER:
+    SplitVecRes_MGATHER(cast<MaskedGatherSDNode>(N), Lo, Hi);
+    break;
   case ISD::SETCC:
     SplitVecRes_SETCC(N, Lo, Hi);
     break;
@@ -1043,6 +1046,54 @@
 }
 
 
+void DAGTypeLegalizer::SplitVecRes_MGATHER(MaskedGatherSDNode *MGT,
+                                           SDValue &Lo, SDValue &Hi) {
+  EVT LoVT, HiVT;
+  SDLoc dl(MGT);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MGT->getValueType(0));
+
+  SDValue Ch = MGT->getChain();
+  SDValue Ptr = MGT->getBasePtr();
+  SDValue Mask = MGT->getMask();
+  unsigned Alignment = MGT->getOriginalAlignment();
+
+  SDValue MaskLo, MaskHi;
+  std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
+
+  EVT MemoryVT = MGT->getMemoryVT();
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue Src0Lo, Src0Hi;
+  std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(MGT->getValue(), dl);
+
+  SDValue IndexHi, IndexLo;
+  std::tie(IndexLo, IndexHi) = DAG.SplitVector(MGT->getIndex(), dl);
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MGT->getPointerInfo(),
+                         MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
+                         Alignment, MGT->getAAInfo(), MGT->getRanges());
+
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
+                           MMO);
+
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
+                           MMO);
+
+  // Build a factor node to remember that this load is independent of the
+  // other one.
+  Ch = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
+                   Hi.getValue(1));
+
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(MGT, 1), Ch);
+}
+
+
 void DAGTypeLegalizer::SplitVecRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
   assert(N->getValueType(0).isVector() &&
          N->getOperand(0).getValueType().isVector() &&
@@ -1300,6 +1351,12 @@
       break;
     case ISD::MSTORE:
       Res = SplitVecOp_MSTORE(cast<MaskedStoreSDNode>(N), OpNo);
+      break;
+    case ISD::MSCATTER:
+      Res = SplitVecOp_MSCATTER(cast<MaskedScatterSDNode>(N), OpNo);
+      break;
+    case ISD::MGATHER:
+      Res = SplitVecOp_MGATHER(cast<MaskedGatherSDNode>(N), OpNo);
       break;
     case ISD::VSELECT:
       Res = SplitVecOp_VSELECT(N, OpNo);
@@ -1462,6 +1519,68 @@
                         MachinePointerInfo(), EltVT, false, false, false, 0);
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_MGATHER(MaskedGatherSDNode *MGT,
+                                             unsigned OpNo) {
+  EVT LoVT, HiVT;
+  SDLoc dl(MGT);
+  std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(MGT->getValueType(0));
+
+  SDValue Ch = MGT->getChain();
+  SDValue Ptr = MGT->getBasePtr();
+  SDValue Index = MGT->getIndex();
+  SDValue Mask = MGT->getMask();
+  unsigned Alignment = MGT->getOriginalAlignment();
+
+  SDValue MaskLo, MaskHi;
+  std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
+
+  EVT MemoryVT = MGT->getMemoryVT();
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  SDValue Src0Lo, Src0Hi;
+  std::tie(Src0Lo, Src0Hi) = DAG.SplitVector(MGT->getValue(), dl);
+
+  SDValue IndexHi, IndexLo;
+  if (Index.getNode())
+    std::tie(IndexLo, IndexHi) = DAG.SplitVector(Index, dl);
+  else
+    IndexLo = IndexHi = Index;
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MGT->getPointerInfo(),
+                         MachineMemOperand::MOLoad, LoMemVT.getStoreSize(),
+                         Alignment, MGT->getAAInfo(), MGT->getRanges());
+
+  SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo};
+  SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
+                                   OpsLo, MMO);
+
+  MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MGT->getPointerInfo(),
+                         MachineMemOperand::MOLoad, HiMemVT.getStoreSize(),
+                         Alignment, MGT->getAAInfo(),
+                         MGT->getRanges());
+
+  SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi};
+  SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
+                                   OpsHi, MMO);
+
+  // Build a factor node to remember that this load is independent of the
+  // other one.
+  Ch = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, Lo.getValue(1),
+                   Hi.getValue(1));
+
+  // Legalized the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(MGT, 1), Ch);
+
+  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, dl, MGT->getValueType(0), Lo,
+                            Hi);
+  ReplaceValueWith(SDValue(MGT, 0), Res);
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_MSTORE(MaskedStoreSDNode *N,
                                             unsigned OpNo) {
   SDValue Ch  = N->getChain();
@@ -1514,6 +1632,69 @@
 }
 
 
+SDValue DAGTypeLegalizer::SplitVecOp_MSCATTER(MaskedScatterSDNode *N,
+                                              unsigned OpNo) {
+  SDValue Ch  = N->getChain();
+  SDValue Ptr = N->getBasePtr();
+  SDValue Mask = N->getMask();
+  SDValue Index = N->getIndex();
+  SDValue Data = N->getValue();
+  EVT MemoryVT = N->getMemoryVT();
+  unsigned Alignment = N->getOriginalAlignment();
+  SDLoc DL(N);
+
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
+  // Only operands with illegal (split) types have been split by the type
+  // legalizer; e.g. when just the index is split, data and mask are legal.
+  SDValue DataLo, DataHi;
+  if (getTypeAction(Data.getValueType()) == TargetLowering::TypeSplitVector)
+    GetSplitVector(Data, DataLo, DataHi);
+  else
+    std::tie(DataLo, DataHi) = DAG.SplitVector(Data, DL);
+  SDValue MaskLo, MaskHi;
+  if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector)
+    GetSplitVector(Mask, MaskLo, MaskHi);
+  else
+    std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, DL);
+
+  SDValue PtrLo, PtrHi;
+  if (Ptr.getValueType().isVector()) // scatter to a vector of pointers
+    std::tie(PtrLo, PtrHi) = DAG.SplitVector(Ptr, DL);
+  else
+    PtrLo = PtrHi = Ptr;
+
+  SDValue IndexHi, IndexLo;
+  if (Index.getNode())
+    std::tie(IndexLo, IndexHi) = DAG.SplitVector(Index, DL);
+  else
+    IndexLo = IndexHi = Index;
+
+  SDValue Lo, Hi;
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(N->getPointerInfo(),
+                         MachineMemOperand::MOStore, LoMemVT.getStoreSize(),
+                         Alignment, N->getAAInfo(), N->getRanges());
+
+  SDValue OpsLo[] = {Ch, DataLo, MaskLo, PtrLo, IndexLo};
+  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
+                            DL, OpsLo, MMO);
+
+  MMO = DAG.getMachineFunction().
+    getMachineMemOperand(N->getPointerInfo(),
+                         MachineMemOperand::MOStore, HiMemVT.getStoreSize(),
+                         Alignment, N->getAAInfo(), N->getRanges());
+
+  SDValue OpsHi[] = {Ch, DataHi, MaskHi, PtrHi, IndexHi};
+  Hi = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
+                            DL, OpsHi, MMO);
+
+  // Build a factor node to remember that this store is independent of the
+  // other one.
+  return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Lo, Hi);
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) {
   assert(N->isUnindexed() && "Indexed store of vector?");
   assert(OpNo == 1 && "Can only split the stored value");
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5038,6 +5038,57 @@
   return SDValue(N, 0);
 }
 
+SDValue
+SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, SDLoc dl,
+                              ArrayRef<SDValue> Ops,
+                              MachineMemOperand *MMO) {
+
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(encodeMemSDNodeFlags(ISD::NON_EXTLOAD, ISD::UNINDEXED,
+                                     MMO->isVolatile(),
+                                     MMO->isNonTemporal(),
+                                     MMO->isInvariant()));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
+    cast<MaskedGatherSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  MaskedGatherSDNode *N =
+    new (NodeAllocator) MaskedGatherSDNode(dl.getIROrder(), dl.getDebugLoc(),
+                                           Ops, VTs, VT, MMO);
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  return SDValue(N, 0);
+}
+
+
+SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, SDLoc dl,
+                                       ArrayRef<SDValue> Ops,
+                                       MachineMemOperand *MMO) {
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(encodeMemSDNodeFlags(false, ISD::UNINDEXED, MMO->isVolatile(),
+                                     MMO->isNonTemporal(),
+                                     MMO->isInvariant()));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = CSEMap.FindNodeOrInsertPos(ID, IP)) {
+    cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  MaskedScatterSDNode *N =
+    new (NodeAllocator) MaskedScatterSDNode(dl.getIROrder(), dl.getDebugLoc(),
+                                            Ops, VTs, VT, MMO);
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  return SDValue(N, 0);
+}
+
+
 SDValue SelectionDAG::getVAArg(EVT VT, SDLoc dl,
                                SDValue Chain, SDValue Ptr,
                                SDValue SV,
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -605,6 +605,8 @@
   // generate the debug data structures now that we've seen its definition.
   void resolveDanglingDebugInfo(const Value *V, SDValue Val);
   SDValue getValue(const Value *V);
+  /// Return true if V is in NodeMap or has a register in FuncInfo.ValueMap.
+  bool findValue(const Value *V) const;
   SDValue getNonRegisterValue(const Value *V);
   SDValue getValueImpl(const Value *V);
 
Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1056,6 +1056,11 @@
   return Val;
 }
 
+bool SelectionDAGBuilder::findValue(const Value *V) const {
+  return NodeMap.count(V) != 0 ||
+         FuncInfo.ValueMap.count(V) != 0;
+}
+
 /// getNonRegisterValue - Return an SDValue for the given Value, but
 /// don't look in FuncInfo.ValueMap for a virtual register.
 SDValue SelectionDAGBuilder::getNonRegisterValue(const Value *V) {
@@ -3634,39 +3639,114 @@
   DAG.setRoot(StoreNode);
 }
 
+// Gather/scatter receive a vector of pointers.
+// When it is a single-index GEP off a splat of one scalar pointer, it can be
+// represented as a scalar Base + a vector Index (valid only if true is
+// returned). Ptr may be overwritten with the scalar base pointer.
+static bool getSingleBase(Value *& Ptr, SDValue& Base, SDValue& Index,
+                          SelectionDAGBuilder* SDB) {
+
+  assert (Ptr->getType()->isVectorTy() && "Unexpected pointer type");
+  GetElementPtrInst *Gep = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!Gep || Gep->getNumOperands() > 2)
+    return false;
+  ShuffleVectorInst *ShuffleInst =
+    dyn_cast<ShuffleVectorInst>(Gep->getPointerOperand());
+  if (!ShuffleInst || !ShuffleInst->getMask()->isNullValue())
+    return false;
+
+  // A zero-mask shuffle broadcasts lane 0, so the base pointer is the scalar
+  // inserted into lane 0. Operand 0 is not necessarily an instruction (it may
+  // be a constant vector), hence dyn_cast rather than cast.
+  InsertElementInst *Insert =
+    dyn_cast<InsertElementInst>(ShuffleInst->getOperand(0));
+  Constant *Lane = Insert ? dyn_cast<Constant>(Insert->getOperand(2)) : nullptr;
+  if (!Lane || !Lane->isNullValue())
+    return false;
+
+  Ptr = Insert->getOperand(1);
+
+  SelectionDAG& DAG = SDB->DAG;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (SDB->findValue(Ptr))
+    Base = SDB->getValue(Ptr);
+  else if (SDB->findValue(ShuffleInst)) {
+    SDValue ShuffleNode = SDB->getValue(ShuffleInst);
+    Base = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SDLoc(ShuffleNode),
+      ShuffleNode.getValueType().getScalarType(), ShuffleNode,
+      DAG.getConstant(0, TLI.getVectorIdxTy()));
+    SDB->setValue(Ptr, Base);
+  }
+  else
+    return false;
+
+  Value *IndexVal = Gep->getOperand(1);
+  if (SDB->findValue(IndexVal)) {
+    Index = SDB->getValue(IndexVal);
+
+    if (SExtInst* Sext = dyn_cast<SExtInst>(IndexVal)) {
+      IndexVal = Sext->getOperand(0);
+      if (SDB->findValue(IndexVal))
+        Index = SDB->getValue(IndexVal);
+    }
+    return true;
+  }
+  return false;
+}
+
+// Masked scatter and masked store are handled in the same visitor
 void SelectionDAGBuilder::visitMaskedStore(const CallInst &I) {
   SDLoc sdl = getCurSDLoc();
 
   // llvm.masked.store.*(Src0, Ptr, alignemt, Mask)
-  Value *PtrOperand = I.getArgOperand(1);
-  SDValue Ptr = getValue(PtrOperand);
+  Value *Ptr = I.getArgOperand(1);
   SDValue Src0 = getValue(I.getArgOperand(0));
   SDValue Mask = getValue(I.getArgOperand(3));
   EVT VT = Src0.getValueType();
   unsigned Alignment = (cast<ConstantInt>(I.getArgOperand(2)))->getZExtValue();
   if (!Alignment)
     Alignment = DAG.getEVTAlignment(VT);
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   AAMDNodes AAInfo;
   I.getAAMetadata(AAInfo);
 
-  MachineMemOperand *MMO =
-    DAG.getMachineFunction().
-      getMachineMemOperand(MachinePointerInfo(PtrOperand),
-                          MachineMemOperand::MOStore, VT.getStoreSize(),
-                          Alignment, AAInfo);
-  SDValue StoreNode = DAG.getMaskedStore(getRoot(), sdl, Src0, Ptr, Mask, VT,
-                                         MMO, false);
-  DAG.setRoot(StoreNode);
-  setValue(&I, StoreNode);
+  // The diff between "store" and "scatter" is in type of base pointer -
+  // "store" has one pointer and "scatter" has a vector of pointers
+  bool isStore = Ptr->getType()->isPointerTy();
+
+  SDValue Base;
+  SDValue Index;
+  Value *BasePtr = Ptr;
+  bool SingleBase = !isStore && getSingleBase(BasePtr, Base, Index, this);
+
+  Value *MemOpBasePtr = (isStore || SingleBase) ? BasePtr : nullptr;
+  MachineMemOperand *MMO = DAG.getMachineFunction().
+    getMachineMemOperand(MachinePointerInfo(MemOpBasePtr),
+                         MachineMemOperand::MOStore,  VT.getStoreSize(),
+                         Alignment, AAInfo);
+  SDValue Store;
+  if (isStore) // Store form
+    Store = DAG.getMaskedStore(getRoot(), sdl, Src0, getValue(BasePtr), Mask,
+                               VT, MMO, false);
+  else { // Scatter
+    if (!SingleBase) {
+      Base = DAG.getTargetConstant(0, TLI.getPointerTy());
+      Index = getValue(Ptr);
+    }
+    SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index };
+    Store = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO);
+  }
+  DAG.setRoot(Store);
+  setValue(&I, Store);
 }
 
+// Masked Load and masked Gather are handled in the same visitor
 void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I) {
   SDLoc sdl = getCurSDLoc();
 
   // @llvm.masked.load.*(Ptr, alignment, Mask, Src0)
-  Value *PtrOperand = I.getArgOperand(0);
-  SDValue Ptr = getValue(PtrOperand);
+  Value *Ptr = I.getArgOperand(0);
   SDValue Src0 = getValue(I.getArgOperand(3));
   SDValue Mask = getValue(I.getArgOperand(2));
 
@@ -3680,25 +3760,48 @@
I.getAAMetadata(AAInfo); const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); - SDValue InChain = DAG.getRoot(); - if (AA->pointsToConstantMemory( - AliasAnalysis::Location(PtrOperand, - AA->getTypeStoreSize(I.getType()), + SDValue Root = DAG.getRoot(); + bool isLoad = Ptr->getType()->isPointerTy(); + + SDValue Base; + SDValue Index; + Value *BasePtr = I.getArgOperand(0); + bool SingleBase = !isLoad && getSingleBase(BasePtr, Base, Index, this); + bool ConstantMemory = false; + if ((isLoad || SingleBase) && AA->pointsToConstantMemory( + AliasAnalysis::Location(BasePtr, + AA->getTypeStoreSize(I.getType()), AAInfo))) { // Do not serialize (non-volatile) loads of constant memory with anything. - InChain = DAG.getEntryNode(); + Root = DAG.getEntryNode(); + ConstantMemory = true; } + Value *MemOpBasePtr = (isLoad || SingleBase) ? BasePtr : NULL; MachineMemOperand *MMO = DAG.getMachineFunction(). - getMachineMemOperand(MachinePointerInfo(PtrOperand), + getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), MachineMemOperand::MOLoad, VT.getStoreSize(), Alignment, AAInfo, Ranges); - SDValue Load = DAG.getMaskedLoad(VT, sdl, InChain, Ptr, Mask, Src0, VT, MMO, - ISD::NON_EXTLOAD); + SDValue Load; + if (isLoad) // Load form + Load = DAG.getMaskedLoad(VT, sdl, Root, getValue(Ptr), Mask, Src0, VT, MMO, + ISD::NON_EXTLOAD); + else { + if (!SingleBase) { + Base = DAG.getTargetConstant(0, TLI.getPointerTy()); + Index = getValue(Ptr); + } + + SDValue Ops[] = { Root, Src0, Mask, Base, Index }; + Load = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, + MMO); + } + SDValue OutChain = Load.getValue(1); - DAG.setRoot(OutChain); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); setValue(&I, Load); } @@ -4849,9 +4945,11 @@ return nullptr; } + case Intrinsic::masked_gather: case Intrinsic::masked_load: visitMaskedLoad(I); return nullptr; + case Intrinsic::masked_scatter: case Intrinsic::masked_store: visitMaskedStore(I); return nullptr; Index: 
lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -273,6 +273,8 @@ case ISD::STORE: return "store"; case ISD::MLOAD: return "masked_load"; case ISD::MSTORE: return "masked_store"; + case ISD::MGATHER: return "masked_gather"; + case ISD::MSCATTER: return "masked_scatter"; case ISD::VAARG: return "vaarg"; case ISD::VACOPY: return "vacopy"; case ISD::VAEND: return "vaend"; Index: lib/Target/X86/X86ISelDAGToDAG.cpp =================================================================== --- lib/Target/X86/X86ISelDAGToDAG.cpp +++ lib/Target/X86/X86ISelDAGToDAG.cpp @@ -204,6 +204,9 @@ bool SelectAddr(SDNode *Parent, SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, SDValue &Segment); + bool SelectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base, + SDValue &Scale, SDValue &Index, SDValue &Disp, + SDValue &Segment); bool SelectMOV64Imm32(SDValue N, SDValue &Imm); bool SelectLEAAddr(SDValue N, SDValue &Base, SDValue &Scale, SDValue &Index, SDValue &Disp, @@ -1316,6 +1319,40 @@ return false; } +bool X86DAGToDAGISel::SelectVectorAddr(SDNode *Parent, SDValue N, SDValue &Base, + SDValue &Scale, SDValue &Index, + SDValue &Disp, SDValue &Segment) { + + MaskedGatherScatterSDNode *Mgs = dyn_cast(Parent); + if (!Mgs) + return false; + X86ISelAddressMode AM; + unsigned AddrSpace = Mgs->getPointerInfo().getAddrSpace(); + // AddrSpace 256 -> GS, 257 -> FS. 
+  if (AddrSpace == 256)
+    AM.Segment = CurDAG->getRegister(X86::GS, MVT::i16);
+  if (AddrSpace == 257)
+    AM.Segment = CurDAG->getRegister(X86::FS, MVT::i16);
+
+  Base = Mgs->getBasePtr();
+  Index = Mgs->getIndex();
+  unsigned ScalarSize = Mgs->getValue().getValueType().getScalarSizeInBits();
+  Scale = getI8Imm(ScalarSize/8);
+
+  // If Base is 0, the whole address is in index and the Scale is 1
+  if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Base)) {
+    assert(C->isNullValue() && "Unexpected base in gather/scatter");
+    Scale = getI8Imm(1);
+    Base = CurDAG->getRegister(0, MVT::i32);
+  }
+  if (AM.Segment.getNode())
+    Segment = AM.Segment;
+  else
+    Segment = CurDAG->getRegister(0, MVT::i32);
+  Disp = CurDAG->getTargetConstant(0, MVT::i32);
+  return true;
+}
+
 /// SelectAddr - returns true if it is able pattern match an addressing mode.
 /// It returns the operands which make up the maximal addressing mode it can
 /// match by reference.
Index: lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- lib/Target/X86/X86ISelLowering.cpp
+++ lib/Target/X86/X86ISelLowering.cpp
@@ -1376,6 +1376,10 @@
     // Custom lower several nodes.
     for (MVT VT : MVT::vector_valuetypes()) {
       unsigned EltSize = VT.getVectorElementType().getSizeInBits();
+      if (EltSize >= 32 && VT.getSizeInBits() <= 512) {
+        setOperationAction(ISD::MGATHER,  VT, Custom);
+        setOperationAction(ISD::MSCATTER, VT, Custom);
+      }
       // Extract subvector is special because the value type
       // (result) is 256/128-bit but the source is 512-bit wide.
if (VT.is128BitVector() || VT.is256BitVector()) { @@ -1388,7 +1392,7 @@ if (!VT.is512BitVector()) continue; - if ( EltSize >= 32) { + if (EltSize >= 32) { setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom); setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom); setOperationAction(ISD::BUILD_VECTOR, VT, Custom); @@ -15162,8 +15166,8 @@ SDValue Index = Op.getOperand(4); SDValue Mask = Op.getOperand(5); SDValue Scale = Op.getOperand(6); - return getGatherNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index, Scale, Chain, - Subtarget); + return getGatherNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index, Scale, + Chain, Subtarget); } case SCATTER: { //scatter(base, mask, index, v1, scale); @@ -15173,7 +15177,8 @@ SDValue Index = Op.getOperand(4); SDValue Src = Op.getOperand(5); SDValue Scale = Op.getOperand(6); - return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index, Scale, Chain); + return getScatterNode(IntrData->Opc0, Op, DAG, Src, Mask, Base, Index, + Scale, Chain); } case PREFETCH: { SDValue Hint = Op.getOperand(6); @@ -15192,7 +15197,8 @@ // Read Time Stamp Counter (RDTSC) and Processor ID (RDTSCP). case RDTSC: { SmallVector Results; - getReadTimeStampCounter(Op.getNode(), dl, IntrData->Opc0, DAG, Subtarget, Results); + getReadTimeStampCounter(Op.getNode(), dl, IntrData->Opc0, DAG, Subtarget, + Results); return DAG.getMergeValues(Results, dl); } // Read Performance Monitoring Counters. 
@@ -17051,6 +17057,74 @@
   return DAG.getNode(ISD::MERGE_VALUES, dl, Tys, SinVal, CosVal);
 }
 
+/// Custom-lower ISD::MSCATTER for AVX-512.
+///
+/// X86 scatter instructions clobber their mask register, so a scatter that
+/// still has only a chain result is rebuilt with the mask type as result 0
+/// and the chain as result 1. Without VLX, a scatter of a non-512-bit value
+/// whose index is not 512 bits wide has its index sign-extended to v8i64.
+static SDValue LowerMSCATTER(SDValue Op, const X86Subtarget *Subtarget,
+                             SelectionDAG &DAG) {
+  assert(Subtarget->hasAVX512() &&
+         "MGATHER/MSCATTER are supported on AVX-512 arch only");
+
+  MaskedScatterSDNode *N = cast<MaskedScatterSDNode>(Op.getNode());
+  EVT VT = N->getValue().getValueType();
+  assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op");
+  SDLoc dl(Op);
+
+  // X86 scatter kills mask register, so its type should be added to
+  // the list of return values
+  if (N->getNumValues() == 1) {
+    SDValue Index = N->getIndex();
+    // NOTE(review): sign-extending to v8i64 assumes an 8-element index.
+    if (!Subtarget->hasVLX() && !VT.is512BitVector() &&
+        !Index.getValueType().is512BitVector())
+      Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
+
+    SDVTList VTs = DAG.getVTList(N->getMask().getValueType(), MVT::Other);
+    SDValue Ops[] = { N->getOperand(0), N->getOperand(1), N->getOperand(2),
+                      N->getOperand(3), Index };
+
+    SDValue NewScatter =
+        DAG.getMaskedScatter(VTs, VT, dl, Ops, N->getMemOperand());
+    // Users of the old chain now use the new node's chain (result 1).
+    DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1));
+    return SDValue(NewScatter.getNode(), 0);
+  }
+  return Op;
+}
+
+/// Custom-lower ISD::MGATHER for AVX-512.
+///
+/// Without VLX, a gather of a non-512-bit type whose index is not 512 bits
+/// wide has its index sign-extended to v8i64; otherwise the node is returned
+/// unchanged.
+static SDValue LowerMGATHER(SDValue Op, const X86Subtarget *Subtarget,
+                            SelectionDAG &DAG) {
+  assert(Subtarget->hasAVX512() &&
+         "MGATHER/MSCATTER are supported on AVX-512 arch only");
+
+  MaskedGatherSDNode *N = cast<MaskedGatherSDNode>(Op.getNode());
+  EVT VT = Op.getValueType();
+  assert(VT.getScalarSizeInBits() >= 32 && "Unsupported gather op");
+  SDLoc dl(Op);
+
+  SDValue Index = N->getIndex();
+  // NOTE(review): sign-extending to v8i64 assumes an 8-element index.
+  if (!Subtarget->hasVLX() && !VT.is512BitVector() &&
+      !Index.getValueType().is512BitVector()) {
+    Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index);
+    SDValue Ops[] = { N->getOperand(0), N->getOperand(1), N->getOperand(2),
+                      N->getOperand(3), Index };
+    // UpdateNodeOperands may return an existing (CSE'd) node instead of
+    // mutating N in place, so its result must be returned rather than Op.
+    SDNode *NewNode = DAG.UpdateNodeOperands(N, Ops);
+    return SDValue(NewNode, Op.getResNo());
+  }
+  return Op;
+}
+
 /// LowerOperation - Provide custom lowering hooks for some operations.
/// SDValue X86TargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { @@ -17138,6 +17194,8 @@ case ISD::ADD: return LowerADD(Op, DAG); case ISD::SUB: return LowerSUB(Op, DAG); case ISD::FSINCOS: return LowerFSINCOS(Op, Subtarget, DAG); + case ISD::MGATHER: return LowerMGATHER(Op, Subtarget, DAG); + case ISD::MSCATTER: return LowerMSCATTER(Op, Subtarget, DAG); } } Index: lib/Target/X86/X86InstrAVX512.td =================================================================== --- lib/Target/X86/X86InstrAVX512.td +++ lib/Target/X86/X86InstrAVX512.td @@ -5059,74 +5059,80 @@ //===----------------------------------------------------------------------===// // GATHER - SCATTER Operations -multiclass avx512_gather opc, string OpcodeStr, RegisterClass KRC, - RegisterClass RC, X86MemOperand memop> { -let mayLoad = 1, - Constraints = "@earlyclobber $dst, $src1 = $dst, $mask = $mask_wb" in - def rm : AVX5128I opc, string OpcodeStr, X86VectorVTInfo _, + X86MemOperand memop, PatFrag GatherNode> { + let Constraints = "@earlyclobber $dst, $src1 = $dst, $mask = $mask_wb" in + def rm : AVX5128I, EVEX, EVEX_K; + [(set _.RC:$dst, _.KRCWM:$mask_wb, + (GatherNode (_.VT _.RC:$src1), _.KRCWM:$mask, + vectoraddr:$src2))]>, EVEX, EVEX_K, + EVEX_CD8<_.EltSize, CD8VT1>; } let ExeDomain = SSEPackedDouble in { -defm VGATHERDPDZ : avx512_gather<0x92, "vgatherdpd", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VGATHERQPDZ : avx512_gather<0x93, "vgatherqpd", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; +defm VGATHERDPDZ : avx512_gather<0x92, "vgatherdpd", v8f64_info, vy64xmem, + mgatherv8i32>, EVEX_V512, VEX_W; +defm VGATHERQPDZ : avx512_gather<0x93, "vgatherqpd", v8f64_info, vz64mem, + mgatherv8i64>, EVEX_V512, VEX_W; } let ExeDomain = SSEPackedSingle in { -defm VGATHERDPSZ : avx512_gather<0x92, "vgatherdps", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; -defm VGATHERQPSZ : avx512_gather<0x93, "vgatherqps", VK8WM, VR256X, 
vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; -} - -defm VPGATHERDQZ : avx512_gather<0x90, "vpgatherdq", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPGATHERDDZ : avx512_gather<0x90, "vpgatherdd", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; - -defm VPGATHERQQZ : avx512_gather<0x91, "vpgatherqq", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPGATHERQDZ : avx512_gather<0x91, "vpgatherqd", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VGATHERDPSZ : avx512_gather<0x92, "vgatherdps", v16f32_info, vz32mem, + mgatherv16i32>, EVEX_V512; +defm VGATHERQPSZ : avx512_gather<0x93, "vgatherqps", v8f32x_info, vz64mem, + mgatherv8i64>, EVEX_V512; +} + +defm VPGATHERDQZ : avx512_gather<0x90, "vpgatherdq", v8i64_info, vy64xmem, + mgatherv8i32>, EVEX_V512, VEX_W; +defm VPGATHERDDZ : avx512_gather<0x90, "vpgatherdd", v16i32_info, vz32mem, + mgatherv16i32>, EVEX_V512; + +defm VPGATHERQQZ : avx512_gather<0x91, "vpgatherqq", v8i64_info, vz64mem, + mgatherv8i64>, EVEX_V512, VEX_W; +defm VPGATHERQDZ : avx512_gather<0x91, "vpgatherqd", v8i32x_info, vz64mem, + mgatherv8i64>, EVEX_V512; + +multiclass avx512_scatter opc, string OpcodeStr, X86VectorVTInfo _, + X86MemOperand memop, PatFrag ScatterNode> { -multiclass avx512_scatter opc, string OpcodeStr, RegisterClass KRC, - RegisterClass RC, X86MemOperand memop> { let mayStore = 1, Constraints = "$mask = $mask_wb" in - def mr : AVX5128I, EVEX, EVEX_K; + "\t{$src, ${dst} {${mask}}|${dst} {${mask}}, $src}"), + [(set _.KRCWM:$mask_wb, (ScatterNode (_.VT _.RC:$src), + _.KRCWM:$mask, vectoraddr:$dst))]>, + EVEX, EVEX_K, EVEX_CD8<_.EltSize, CD8VT1>; } let ExeDomain = SSEPackedDouble in { -defm VSCATTERDPDZ : avx512_scatter<0xA2, "vscatterdpd", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VSCATTERQPDZ : avx512_scatter<0xA3, "vscatterqpd", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; +defm VSCATTERDPDZ : 
avx512_scatter<0xA2, "vscatterdpd", v8f64_info, vy64xmem, + mscatterv8i32>, EVEX_V512, VEX_W; +defm VSCATTERQPDZ : avx512_scatter<0xA3, "vscatterqpd", v8f64_info, vz64mem, + mscatterv8i64>, EVEX_V512, VEX_W; } let ExeDomain = SSEPackedSingle in { -defm VSCATTERDPSZ : avx512_scatter<0xA2, "vscatterdps", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; -defm VSCATTERQPSZ : avx512_scatter<0xA3, "vscatterqps", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VSCATTERDPSZ : avx512_scatter<0xA2, "vscatterdps", v16f32_info, vz32mem, + mscatterv16i32>, EVEX_V512; +defm VSCATTERQPSZ : avx512_scatter<0xA3, "vscatterqps", v8f32x_info, vz64mem, + mscatterv8i64>, EVEX_V512; } -defm VPSCATTERDQZ : avx512_scatter<0xA0, "vpscatterdq", VK8WM, VR512, vy64xmem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPSCATTERDDZ : avx512_scatter<0xA0, "vpscatterdd", VK16WM, VR512, vz32mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; - -defm VPSCATTERQQZ : avx512_scatter<0xA1, "vpscatterqq", VK8WM, VR512, vz64mem>, - EVEX_V512, VEX_W, EVEX_CD8<64, CD8VT1>; -defm VPSCATTERQDZ : avx512_scatter<0xA1, "vpscatterqd", VK8WM, VR256X, vz64mem>, - EVEX_V512, EVEX_CD8<32, CD8VT1>; +defm VPSCATTERDQZ : avx512_scatter<0xA0, "vpscatterdq", v8i64_info, vy64xmem, + mscatterv8i32>, EVEX_V512, VEX_W; +defm VPSCATTERDDZ : avx512_scatter<0xA0, "vpscatterdd", v16i32_info, vz32mem, + mscatterv16i32>, EVEX_V512; + +defm VPSCATTERQQZ : avx512_scatter<0xA1, "vpscatterqq", v8i64_info, vz64mem, + mscatterv8i64>, EVEX_V512, VEX_W; +defm VPSCATTERQDZ : avx512_scatter<0xA1, "vpscatterqd", v8i32x_info, vz64mem, + mscatterv8i64>, EVEX_V512; // prefetch multiclass avx512_gather_scatter_prefetch opc, Format F, string OpcodeStr, Index: lib/Target/X86/X86InstrFragmentsSIMD.td =================================================================== --- lib/Target/X86/X86InstrFragmentsSIMD.td +++ lib/Target/X86/X86InstrFragmentsSIMD.td @@ -526,6 +526,52 @@ return false; }]>; +def mgatherv8i32 : PatFrag<(ops 
node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + return (Mgt->getIndex().getValueType() == MVT::v8i32 || + Mgt->getBasePtr().getValueType() == MVT::v8i32); + return false; +}]>; + +def mgatherv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + return (Mgt->getIndex().getValueType() == MVT::v8i64 || + Mgt->getBasePtr().getValueType() == MVT::v8i64); + return false; +}]>; +def mgatherv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_gather node:$src1, node:$src2, node:$src3) , [{ + if (MaskedGatherSDNode *Mgt = dyn_cast(N)) + return (Mgt->getIndex().getValueType() == MVT::v16i32 || + Mgt->getBasePtr().getValueType() == MVT::v16i32); + return false; +}]>; + +def mscatterv8i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + if (MaskedScatterSDNode *Sc = dyn_cast(N)) + return (Sc->getIndex().getValueType() == MVT::v8i32 || + Sc->getBasePtr().getValueType() == MVT::v8i32); + return false; +}]>; + +def mscatterv8i64 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + if (MaskedScatterSDNode *Sc = dyn_cast(N)) + return (Sc->getIndex().getValueType() == MVT::v8i64 || + Sc->getBasePtr().getValueType() == MVT::v8i64); + return false; +}]>; +def mscatterv16i32 : PatFrag<(ops node:$src1, node:$src2, node:$src3), + (masked_scatter node:$src1, node:$src2, node:$src3) , [{ + if (MaskedScatterSDNode *Sc = dyn_cast(N)) + return (Sc->getIndex().getValueType() == MVT::v16i32 || + Sc->getBasePtr().getValueType() == MVT::v16i32); + return false; +}]>; + // 128-bit bitconvert pattern fragments def bc_v4f32 : PatFrag<(ops node:$in), (v4f32 (bitconvert node:$in))>; def bc_v2f64 : PatFrag<(ops node:$in), (v2f64 (bitconvert node:$in))>; Index: 
lib/Target/X86/X86InstrInfo.td =================================================================== --- lib/Target/X86/X86InstrInfo.td +++ lib/Target/X86/X86InstrInfo.td @@ -716,6 +716,8 @@ def tls64baseaddr : ComplexPattern; +def vectoraddr : ComplexPattern; + //===----------------------------------------------------------------------===// // X86 Instruction Predicate Definitions. def HasCMov : Predicate<"Subtarget->hasCMov()">; Index: test/CodeGen/X86/masked_gather_scatter.ll =================================================================== --- test/CodeGen/X86/masked_gather_scatter.ll +++ test/CodeGen/X86/masked_gather_scatter.ll @@ -0,0 +1,147 @@ +; RUN: llc -mtriple=x86_64-apple-darwin -mcpu=knl < %s | FileCheck %s -check-prefix=KNL + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +; KNL-LABEL: test1 +; KNL: kxnorw %k1, %k1, %k1 +; KNL: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1} +define <16 x float> @test1(float* %base, <16 x i32> %ind) { + + %broadcast.splatinsert = insertelement <16 x float*> undef, float* %base, i32 0 + %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer + + %sext_ind = sext <16 x i32> %ind to <16 x i64> + %gep.random = getelementptr float, <16 x float*> %broadcast.splat, <16 x i64> %sext_ind + + %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> , <16 x float> undef) + ret <16 x float>%res +} + +declare <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*>, i32, <16 x i1>, <16 x i32>) +declare <16 x float> @llvm.masked.gather.v16f32(<16 x float*>, i32, <16 x i1>, <16 x float>) +declare <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> , i32, <8 x i1> , <8 x i32> ) + +; KNL-LABEL: test2 +; KNL: kmovw %esi, %k1 +; KNL: vgatherdps (%rdi,%zmm0,4), %zmm1 {%k1} +define <16 x float> @test2(float* %base, <16 x i32> %ind, i16 %mask) { + + %broadcast.splatinsert = insertelement <16 x 
float*> undef, float* %base, i32 0 + %broadcast.splat = shufflevector <16 x float*> %broadcast.splatinsert, <16 x float*> undef, <16 x i32> zeroinitializer + + %sext_ind = sext <16 x i32> %ind to <16 x i64> + %gep.random = getelementptr float, <16 x float*> %broadcast.splat, <16 x i64> %sext_ind + %imask = bitcast i16 %mask to <16 x i1> + %res = call <16 x float> @llvm.masked.gather.v16f32(<16 x float*> %gep.random, i32 4, <16 x i1> %imask, <16 x float>undef) + ret <16 x float> %res +} + +; KNL-LABEL: test3 +; KNL: kmovw %esi, %k1 +; KNL: vpgatherdd (%rdi,%zmm0,4), %zmm1 {%k1} +define <16 x i32> @test3(i32* %base, <16 x i32> %ind, i16 %mask) { + + %broadcast.splatinsert = insertelement <16 x i32*> undef, i32* %base, i32 0 + %broadcast.splat = shufflevector <16 x i32*> %broadcast.splatinsert, <16 x i32*> undef, <16 x i32> zeroinitializer + + %sext_ind = sext <16 x i32> %ind to <16 x i64> + %gep.random = getelementptr i32, <16 x i32*> %broadcast.splat, <16 x i64> %sext_ind + %imask = bitcast i16 %mask to <16 x i1> + %res = call <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %gep.random, i32 4, <16 x i1> %imask, <16 x i32>undef) + ret <16 x i32> %res +} + +; KNL-LABEL: test4 +; KNL: kmovw %esi, %k1 +; KNL: kmovw +; KNL: vpgatherdd +; KNL: vpgatherdd + +define <16 x i32> @test4(i32* %base, <16 x i32> %ind, i16 %mask) { + + %broadcast.splatinsert = insertelement <16 x i32*> undef, i32* %base, i32 0 + %broadcast.splat = shufflevector <16 x i32*> %broadcast.splatinsert, <16 x i32*> undef, <16 x i32> zeroinitializer + + %gep.random = getelementptr i32, <16 x i32*> %broadcast.splat, <16 x i32> %ind + %imask = bitcast i16 %mask to <16 x i1> + %gt1 = call <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %gep.random, i32 4, <16 x i1> %imask, <16 x i32>undef) + %gt2 = call <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %gep.random, i32 4, <16 x i1> %imask, <16 x i32>%gt1) + %res = add <16 x i32> %gt1, %gt2 + ret <16 x i32> %res +} + +; KNL-LABEL: test5 +; KNL: kmovw %k1, 
%k2 +; KNL: vpscatterdd {{.*}}%k2 +; KNL: vpscatterdd {{.*}}%k1 + +define void @test5(i32* %base, <16 x i32> %ind, i16 %mask, <16 x i32>%val) { + + %broadcast.splatinsert = insertelement <16 x i32*> undef, i32* %base, i32 0 + %broadcast.splat = shufflevector <16 x i32*> %broadcast.splatinsert, <16 x i32*> undef, <16 x i32> zeroinitializer + + %gep.random = getelementptr i32, <16 x i32*> %broadcast.splat, <16 x i32> %ind + %imask = bitcast i16 %mask to <16 x i1> + call void @llvm.masked.scatter.v16i32(<16 x i32>%val, <16 x i32*> %gep.random, i32 4, <16 x i1> %imask) + call void @llvm.masked.scatter.v16i32(<16 x i32>%val, <16 x i32*> %gep.random, i32 4, <16 x i1> %imask) + ret void +} + +declare void @llvm.masked.scatter.v8i32(<8 x i32> , <8 x i32*> , i32 , <8 x i1> ) +declare void @llvm.masked.scatter.v16i32(<16 x i32> , <16 x i32*> , i32 , <16 x i1> ) + +; KNL-LABEL: test6 +; KNL: kxnorw %k1, %k1, %k1 +; KNL: kxnorw %k2, %k2, %k2 +; KNL: vpgatherqd (,%zmm{{.*}}), %ymm{{.*}} {%k2} +; KNL: vpscatterqd %ymm{{.*}}, (,%zmm{{.*}}) {%k1} +define <8 x i32> @test6(<8 x i32>%a1, <8 x i32*> %ptr) { + + %a = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %ptr, i32 4, <8 x i1> , <8 x i32> undef) + + call void @llvm.masked.scatter.v8i32(<8 x i32> %a1, <8 x i32*> %ptr, i32 4, <8 x i1> ) + ret <8 x i32>%a +} + +define <8 x i32> @test7(i32* %base, <8 x i32> %ind, i8 %mask) { + + %broadcast.splatinsert = insertelement <8 x i32*> undef, i32* %base, i32 0 + %broadcast.splat = shufflevector <8 x i32*> %broadcast.splatinsert, <8 x i32*> undef, <8 x i32> zeroinitializer + + %gep.random = getelementptr i32, <8 x i32*> %broadcast.splat, <8 x i32> %ind + %imask = bitcast i8 %mask to <8 x i1> + %gt1 = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.random, i32 4, <8 x i1> %imask, <8 x i32>undef) + %gt2 = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.random, i32 4, <8 x i1> %imask, <8 x i32>%gt1) + %res = add <8 x i32> %gt1, %gt2 + ret <8 x i32> %res +} + +define <8 x 
i32> @test8(i8* %b, <8 x i32> %ind, i8 %mask) { + %base = bitcast i8* %b to i32* + br label %vector.body + +vector.body: + + %broadcast.splatinsert = insertelement <8 x i32*> undef, i32* %base, i32 0 + %broadcast.splat = shufflevector <8 x i32*> %broadcast.splatinsert, <8 x i32*> undef, <8 x i32> zeroinitializer + + %gep.random = getelementptr i32, <8 x i32*> %broadcast.splat, <8 x i32> %ind + %imask = bitcast i8 %mask to <8 x i1> + %gt1 = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.random, i32 4, <8 x i1> %imask, <8 x i32>undef) + %gt2 = call <8 x i32> @llvm.masked.gather.v8i32(<8 x i32*> %gep.random, i32 4, <8 x i1> %imask, <8 x i32>%gt1) + %res = add <8 x i32> %gt1, %gt2 + ret <8 x i32> %res +} + +define <16 x i32> @test9(<16 x i32*> %ptr.random, <16 x i32> %ind, i16 %mask) { + %imask = bitcast i16 %mask to <16 x i1> + %gt1 = call <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %ptr.random, i32 4, <16 x i1> %imask, <16 x i32>undef) + %gt2 = call <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %ptr.random, i32 4, <16 x i1> %imask, <16 x i32>%gt1) + %res = add <16 x i32> %gt1, %gt2 + ret <16 x i32> %res +} + + + + Index: test/CodeGen/X86/masked_memop.ll =================================================================== --- test/CodeGen/X86/masked_memop.ll +++ test/CodeGen/X86/masked_memop.ll @@ -7,8 +7,8 @@ ; AVX512: vmovdqu32 (%rdi), %zmm0 {%k1} {z} ; AVX2-LABEL: test1 -; AVX2: vpmaskmovd 32(%rdi) -; AVX2: vpmaskmovd (%rdi) +; AVX2: vpmaskmovd {{.*}}(%rdi) +; AVX2: vpmaskmovd {{.*}}(%rdi) ; AVX2-NOT: blend ; AVX_SCALAR-LABEL: test1