Add a truncating-store flag to ISD::MSCATTER nodes.

MaskedScatterSDNode now records IsTruncating in StoreSDNodeBits and exposes
it through isTruncatingStore(). SelectionDAG::getMaskedScatter takes the flag
as a new trailing parameter (defaulting to false in the header) and folds it
into the node's FoldingSetNodeID, so truncating and non-truncating scatters
are not CSE'd together. In PromoteIntOp_MSCATTER, promoting an operand other
than the mask or the index sets the flag and rebuilds the node instead of
updating it in place. SplitVecOp_MSCATTER splits the memory VT (instead of
reusing the split data types) and keeps the flag, as do WidenVecOp_MSCATTER
and the X86 gather/scatter rebuild. The DAG dumper prints "trunc to <VT>"
for truncating scatters, plus the index signedness and scaling.

diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1361,7 +1361,8 @@
                           ISD::MemIndexType IndexType);
   SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                            ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
-                           ISD::MemIndexType IndexType);
+                           ISD::MemIndexType IndexType,
+                           bool IsTruncating = false);
 
   /// Construct a node to track a Value* through the backend.
   SDValue getSrcValue(const Value *v);
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -523,6 +523,7 @@
   class StoreSDNodeBitfields {
     friend class StoreSDNode;
     friend class MaskedStoreSDNode;
+    friend class MaskedScatterSDNode;
 
     uint16_t : NumLSBaseSDNodeBits;
 
@@ -2441,9 +2442,16 @@
 
   MaskedScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
                       EVT MemVT, MachineMemOperand *MMO,
-                      ISD::MemIndexType IndexType)
+                      ISD::MemIndexType IndexType, bool IsTrunc)
       : MaskedGatherScatterSDNode(ISD::MSCATTER, Order, dl, VTs, MemVT, MMO,
-                                  IndexType) {}
+                                  IndexType) {
+    StoreSDNodeBits.IsTruncating = IsTrunc;
+  }
+
+  /// Return true if the op does a truncation before store.
+  /// For integers this is the same as doing a TRUNCATE and storing the result.
+  /// For floats, it is the same as doing an FP_ROUND and storing the result.
+  bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; }
 
   const SDValue &getValue() const { return getOperand(1); }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1851,6 +1851,7 @@
 
 SDValue DAGTypeLegalizer::PromoteIntOp_MSCATTER(MaskedScatterSDNode *N,
                                                 unsigned OpNo) {
+  bool TruncateStore = N->isTruncatingStore();
   SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
   if (OpNo == 2) {
     // The Mask
@@ -1863,9 +1864,15 @@
       NewOps[OpNo] = SExtPromotedInteger(N->getOperand(OpNo));
     else
       NewOps[OpNo] = ZExtPromotedInteger(N->getOperand(OpNo));
-  } else
+
+  } else {
     NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
-  return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+    TruncateStore = true;
+  }
+
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), N->getMemoryVT(),
+                              SDLoc(N), NewOps, N->getMemOperand(),
+                              N->getIndexType(), TruncateStore);
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_TRUNCATE(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -2496,11 +2496,15 @@
   SDValue Index = N->getIndex();
   SDValue Scale = N->getScale();
   SDValue Data = N->getValue();
+  EVT MemoryVT = N->getMemoryVT();
   Align Alignment = N->getOriginalAlign();
   SDLoc DL(N);
 
   // Split all operands
 
+  EVT LoMemVT, HiMemVT;
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
   SDValue DataLo, DataHi;
   if (getTypeAction(Data.getValueType()) == TargetLowering::TypeSplitVector)
     // Split Data operand
@@ -2531,15 +2535,17 @@
       MemoryLocation::UnknownSize, Alignment, N->getAAInfo(), N->getRanges());
 
   SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale};
-  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(),
-                            DL, OpsLo, MMO, N->getIndexType());
+  Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), LoMemVT,
+                            DL, OpsLo, MMO, N->getIndexType(),
+                            N->isTruncatingStore());
 
   // The order of the Scatter operation after split is well defined. The "Hi"
   // part comes after the "Lo". So these two operations should be chained one
   // after another.
   SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(),
-                              DL, OpsHi, MMO, N->getIndexType());
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), HiMemVT,
+                              DL, OpsHi, MMO, N->getIndexType(),
+                              N->isTruncatingStore());
 }
 
 SDValue DAGTypeLegalizer::SplitVecOp_STORE(StoreSDNode *N, unsigned OpNo) {
@@ -4717,7 +4723,8 @@
                    Scale};
   return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
                               MSC->getMemoryVT(), SDLoc(N), Ops,
-                              MSC->getMemOperand(), MSC->getIndexType());
+                              MSC->getMemOperand(), MSC->getIndexType(),
+                              MSC->isTruncatingStore());
 }
 
 SDValue DAGTypeLegalizer::WidenVecOp_SETCC(SDNode *N) {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7340,22 +7340,24 @@
 SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                                        ArrayRef<SDValue> Ops,
                                        MachineMemOperand *MMO,
-                                       ISD::MemIndexType IndexType) {
+                                       ISD::MemIndexType IndexType,
+                                       bool IsTrunc) {
   assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedScatterSDNode>(
-      dl.getIROrder(), VTs, VT, MMO, IndexType));
+      dl.getIROrder(), VTs, VT, MMO, IndexType, IsTrunc));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
     cast<MaskedScatterSDNode>(E)->refineAlignment(MMO);
     return SDValue(E, 0);
   }
+
   auto *N = newSDNode<MaskedScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(),
-                                           VTs, VT, MMO, IndexType);
+                                           VTs, VT, MMO, IndexType, IsTrunc);
   createOperands(N, Ops);
 
   assert(N->getMask().getValueType().getVectorNumElements() ==
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4302,7 +4302,7 @@
   }
   SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
-                                         Ops, MMO, IndexType);
+                                         Ops, MMO, IndexType, false);
   DAG.setRoot(Scatter);
   setValue(&I, Scatter);
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -735,7 +735,19 @@
       OS << ", compressing";
 
     OS << ">";
-  } else if (const MemSDNode* M = dyn_cast<MemSDNode>(this)) {
+  } else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) {
+    OS << "<";
+    printMemOperand(OS, *MScatter->getMemOperand(), G);
+
+    if (MScatter->isTruncatingStore())
+      OS << ", trunc to " << MScatter->getMemoryVT().getEVTString();
+
+    auto Signed = MScatter->isIndexSigned() ? "signed" : "unsigned";
+    auto Scaled = MScatter->isIndexScaled() ? "scaled" : "unscaled";
+    OS << ", " << Signed << " " << Scaled << " offset";
+
+    OS << ">";
+  } else if (const MemSDNode *M = dyn_cast<MemSDNode>(this)) {
     OS << "<";
     printMemOperand(OS, *M->getMemOperand(), G);
     OS << ">";
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47560,7 +47560,8 @@
   return DAG.getMaskedScatter(Scatter->getVTList(),
                               Scatter->getMemoryVT(), DL,
                               Ops, Scatter->getMemOperand(),
-                              Scatter->getIndexType());
+                              Scatter->getIndexType(),
+                              Scatter->isTruncatingStore());
 }
 
 static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,