// Review notes for this patch (free text placed ahead of the first `diff --git` header):
// - Adds VP_LOAD / VP_STORE cases to the switch statements that call the PromoteIntRes_* and PromoteIntOp_* handlers,
//   defines PromoteIntRes_VP_LOAD, PromoteIntOp_VP_STORE and PromoteIntOp_VP_LOAD, and declares them in LegalizeTypes.h.
// - NOTE(review): template arguments look stripped in this copy (`cast(N)`, bare `SmallVector NewOps`);
//   presumably cast<VPLoadSDNode>/cast<VPStoreSDNode> and a SmallVector of SDValue -- restore before applying.
// - NOTE(review): PromoteIntOp_VP_STORE rebuilds the store through getTruncStoreVP without an offset argument, while
//   PromoteIntRes_VP_LOAD forwards getAddressingMode()/getOffset() -- confirm indexed VP stores cannot reach this path.
// - NOTE(review): both operand handlers pass operand OpNo to PromoteTargetBoolean; per their own comments OpNo may be
//   the vector length as well as the mask -- confirm that treating the length as a target boolean is intended.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -70,6 +70,9 @@ case ISD::EXTRACT_VECTOR_ELT: Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break; case ISD::LOAD: Res = PromoteIntRes_LOAD(cast(N)); break; + case ISD::VP_LOAD: + Res = PromoteIntRes_VP_LOAD(cast(N)); + break; case ISD::MLOAD: Res = PromoteIntRes_MLOAD(cast(N)); break; case ISD::MGATHER: Res = PromoteIntRes_MGATHER(cast(N)); @@ -708,6 +711,20 @@ return Res; } +SDValue DAGTypeLegalizer::PromoteIntRes_VP_LOAD(VPLoadSDNode *N) { + EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); + + SDLoc dl(N); + SDValue Res = DAG.getLoadVP(N->getAddressingMode(), ISD::EXTLOAD, NVT, dl, + N->getChain(), N->getBasePtr(), N->getOffset(), + N->getMask(), N->getVectorLength(), + N->getMemoryVT(), N->getMemOperand()); + // Legalize the chain result: redirect every user of the old chain + // (result 1 of N) to the chain produced by the new extending load. 
+ ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); + return Res; +} + SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) { EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); SDValue ExtPassThru = GetPromotedInteger(N->getPassThru()); @@ -1519,8 +1536,14 @@ case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break; case ISD::STORE: Res = PromoteIntOp_STORE(cast(N), OpNo); break; + case ISD::VP_STORE: + Res = PromoteIntOp_VP_STORE(cast(N), OpNo); + break; case ISD::MSTORE: Res = PromoteIntOp_MSTORE(cast(N), OpNo); break; + case ISD::VP_LOAD: + Res = PromoteIntOp_VP_LOAD(cast(N), OpNo); + break; case ISD::MLOAD: Res = PromoteIntOp_MLOAD(cast(N), OpNo); break; case ISD::MGATHER: Res = PromoteIntOp_MGATHER(cast(N), @@ -1840,6 +1863,44 @@ N->getMemoryVT(), N->getMemOperand()); } +SDValue DAGTypeLegalizer::PromoteIntOp_VP_STORE(VPStoreSDNode *N, + unsigned OpNo) { + SDValue DataOp = N->getValue(); + + if (OpNo >= 4) { + // OpNo >= 4 selects the mask or the EVL operand; promote it and update N in place. 
+ EVT DataVT = DataOp.getValueType(); + SDValue PromotedOperand = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = PromotedOperand; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); + } + + assert(OpNo == 1 && "Unexpected operand for promotion"); + DataOp = GetPromotedInteger(DataOp); + + return DAG.getTruncStoreVP(N->getChain(), SDLoc(N), DataOp, N->getBasePtr(), + N->getMask(), N->getVectorLength(), + N->getMemoryVT(), N->getMemOperand(), + N->isCompressingStore()); +} + +SDValue DAGTypeLegalizer::PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo) { + assert(OpNo >= 3 && "Only know how to promote the mask or length!"); + EVT DataVT = N->getValueType(0); + SDValue PromotedOperand = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = PromotedOperand; + SDNode *Res = DAG.UpdateNodeOperands(N, NewOps); + if (Res == N) + return SDValue(Res, 0); + + // The update returned a different node (CSE), so replace both results of N here since the caller cannot. 
+ ReplaceValueWith(SDValue(N, 0), SDValue(Res, 0)); + ReplaceValueWith(SDValue(N, 1), SDValue(Res, 1)); + return SDValue(); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo) { SDValue DataOp = N->getValue(); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -324,6 +324,7 @@ SDValue PromoteIntRes_FREEZE(SDNode *N); SDValue PromoteIntRes_INT_EXTEND(SDNode *N); SDValue PromoteIntRes_LOAD(LoadSDNode *N); + SDValue PromoteIntRes_VP_LOAD(VPLoadSDNode *N); SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N); SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N); SDValue PromoteIntRes_Overflow(SDNode *N); @@ -385,7 +386,9 @@ SDValue PromoteIntOp_UINT_TO_FP(SDNode *N); SDValue PromoteIntOp_STRICT_UINT_TO_FP(SDNode *N); SDValue PromoteIntOp_ZERO_EXTEND(SDNode *N); + SDValue PromoteIntOp_VP_STORE(VPStoreSDNode *N, unsigned OpNo); SDValue PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo); + SDValue PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo); SDValue PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo); SDValue PromoteIntOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo); SDValue PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo);