diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -70,6 +70,9 @@
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
   case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
+  case ISD::VP_LOAD:
+    Res = PromoteIntRes_VP_LOAD(cast<VPLoadSDNode>(N));
+    break;
   case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));
     break;
   case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
@@ -708,6 +711,29 @@
   return Res;
 }
 
+// Promote the result of a VP_LOAD. The new load produces the promoted type
+// NVT while still reading the original memory type, so a non-extending load
+// becomes an any-extending (EXTLOAD) one.
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_LOAD(VPLoadSDNode *N) {
+  // ISD::isUNINDEXEDLoad/ISD::isNON_EXTLoad only recognize LoadSDNode
+  // (ISD::LOAD) and would always return false here, so query N directly.
+  assert(N->getAddressingMode() == ISD::UNINDEXED &&
+         "Indexed vp_load during type legalization!");
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  ISD::LoadExtType ExtType = N->getExtensionType() == ISD::NON_EXTLOAD
+                                 ? ISD::EXTLOAD
+                                 : N->getExtensionType();
+  SDLoc dl(N);
+  SDValue Res =
+      DAG.getLoadVP(N->getAddressingMode(), ExtType, NVT, dl, N->getChain(),
+                    N->getBasePtr(), N->getOffset(), N->getMask(),
+                    N->getVectorLength(), N->getMemoryVT(), N->getMemOperand());
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue ExtPassThru = GetPromotedInteger(N->getPassThru());
@@ -1519,8 +1545,14 @@
   case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
   case ISD::STORE:        Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
                                                    OpNo); break;
+  case ISD::VP_STORE:
+    Res = PromoteIntOp_VP_STORE(cast<VPStoreSDNode>(N), OpNo);
+    break;
   case ISD::MSTORE:       Res = PromoteIntOp_MSTORE(cast<MaskedStoreSDNode>(N),
                                                     OpNo); break;
+  case ISD::VP_LOAD:
+    Res = PromoteIntOp_VP_LOAD(cast<VPLoadSDNode>(N), OpNo);
+    break;
   case ISD::MLOAD:        Res = PromoteIntOp_MLOAD(cast<MaskedLoadSDNode>(N),
                                                     OpNo); break;
   case ISD::MGATHER:  Res = PromoteIntOp_MGATHER(cast<MaskedGatherSDNode>(N),
@@ -1840,6 +1872,55 @@
                             N->getMemoryVT(), N->getMemOperand());
 }
 
+// Promote an operand of a VP_STORE. Operand 1 is the stored value; operands
+// 4 and 5 are the mask and the explicit vector length (EVL).
+SDValue DAGTypeLegalizer::PromoteIntOp_VP_STORE(VPStoreSDNode *N,
+                                                unsigned OpNo) {
+  SDValue DataOp = N->getValue();
+  SDValue Operand = N->getOperand(OpNo);
+
+  if (OpNo >= 4) {
+    // The Mask or EVL. Update in place. The EVL is an unsigned element count,
+    // so its promoted high bits must be zero rather than left undefined.
+    EVT DataVT = DataOp.getValueType();
+    SDValue PromotedOperand = OpNo == 4 ? PromoteTargetBoolean(Operand, DataVT)
+                                        : ZExtPromotedInteger(Operand);
+    SmallVector<SDValue, 8> NewOps(N->op_begin(), N->op_end());
+    NewOps[OpNo] = PromotedOperand;
+    return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+  }
+
+  assert(OpNo == 1 && "Unexpected operand for promotion");
+  DataOp = GetPromotedInteger(DataOp);
+
+  // Store the promoted value, truncating back to the original memory type.
+  return DAG.getTruncStoreVP(N->getChain(), SDLoc(N), DataOp, N->getBasePtr(),
+                             N->getMask(), N->getVectorLength(),
+                             N->getMemoryVT(), N->getMemOperand(),
+                             N->isCompressingStore());
+}
+
+// Promote an operand of a VP_LOAD. Only the mask (operand 3) and the EVL
+// (operand 4) are handled; the loaded data is a result, not an operand.
+SDValue DAGTypeLegalizer::PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo) {
+  assert(OpNo >= 3 && "Only know how to promote the mask or length!");
+  EVT DataVT = N->getValueType(0);
+  SDValue Operand = N->getOperand(OpNo);
+  // The EVL is an unsigned element count, so zero its promoted high bits.
+  SDValue PromotedOperand = OpNo == 3 ? PromoteTargetBoolean(Operand, DataVT)
+                                      : ZExtPromotedInteger(Operand);
+  SmallVector<SDValue, 8> NewOps(N->op_begin(), N->op_end());
+  NewOps[OpNo] = PromotedOperand;
+  SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+  if (Res == N)
+    return SDValue(Res, 0);
+
+  // Update triggered CSE, do our own replacement since caller can't.
+  ReplaceValueWith(SDValue(N, 0), SDValue(Res, 0));
+  ReplaceValueWith(SDValue(N, 1), SDValue(Res, 1));
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
                                               unsigned OpNo) {
   SDValue DataOp = N->getValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,6 +324,7 @@
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
   SDValue PromoteIntRes_LOAD(LoadSDNode *N);
+  SDValue PromoteIntRes_VP_LOAD(VPLoadSDNode *N);
   SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N);
   SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N);
   SDValue PromoteIntRes_Overflow(SDNode *N);
@@ -385,7 +386,9 @@
   SDValue PromoteIntOp_UINT_TO_FP(SDNode *N);
   SDValue PromoteIntOp_STRICT_UINT_TO_FP(SDNode *N);
   SDValue PromoteIntOp_ZERO_EXTEND(SDNode *N);
+  SDValue PromoteIntOp_VP_STORE(VPStoreSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo);