diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -70,6 +70,9 @@
   case ISD::EXTRACT_VECTOR_ELT:
                          Res = PromoteIntRes_EXTRACT_VECTOR_ELT(N); break;
   case ISD::LOAD:        Res = PromoteIntRes_LOAD(cast<LoadSDNode>(N)); break;
+  case ISD::VP_LOAD:
+    Res = PromoteIntRes_VP_LOAD(cast<VPLoadSDNode>(N));
+    break;
   case ISD::MLOAD:       Res = PromoteIntRes_MLOAD(cast<MaskedLoadSDNode>(N));
     break;
   case ISD::MGATHER:     Res = PromoteIntRes_MGATHER(cast<MaskedGatherSDNode>(N));
@@ -738,6 +741,31 @@
   return Res;
 }
 
+/// Promote the integer result of a VP_LOAD by re-emitting it as an extending
+/// load to the promoted type. A NON_EXTLOAD becomes an EXTLOAD; any other
+/// extension kind is kept. Users of the old chain result are redirected to
+/// the chain of the new node.
+SDValue DAGTypeLegalizer::PromoteIntRes_VP_LOAD(VPLoadSDNode *N) {
+  assert((N->getAddressingMode() == ISD::UNINDEXED) &&
+         "Indexed vp_load during type legalization!");
+  EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
+  ISD::LoadExtType ExtType = (N->getExtensionType() == ISD::NON_EXTLOAD)
+                                 ? ISD::EXTLOAD
+                                 : N->getExtensionType();
+  SDLoc dl(N);
+  // Forward the expanding flag explicitly: omitting it would fall back to the
+  // parameter's default instead of preserving the original node's flag.
+  SDValue Res =
+      DAG.getLoadVP(N->getAddressingMode(), ExtType, NVT, dl, N->getChain(),
+                    N->getBasePtr(), N->getOffset(), N->getMask(),
+                    N->getVectorLength(), N->getMemoryVT(),
+                    N->getMemOperand(), N->isExpandingLoad());
+  // Legalize the chain result - switch anything that used the old chain to
+  // use the new one.
+  ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
+  return Res;
+}
+
 SDValue DAGTypeLegalizer::PromoteIntRes_MLOAD(MaskedLoadSDNode *N) {
   EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
   SDValue ExtPassThru = GetPromotedInteger(N->getPassThru());
@@ -1564,8 +1592,14 @@
   case ISD::STRICT_SINT_TO_FP: Res = PromoteIntOp_STRICT_SINT_TO_FP(N); break;
   case ISD::STORE:        Res = PromoteIntOp_STORE(cast<StoreSDNode>(N),
                                                    OpNo); break;
+  case ISD::VP_STORE:
+    Res = PromoteIntOp_VP_STORE(cast<VPStoreSDNode>(N), OpNo);
+    break;
   case ISD::MSTORE:       Res = PromoteIntOp_MSTORE(cast<MaskedStoreSDNode>(N),
                                                     OpNo); break;
+  case ISD::VP_LOAD:
+    Res = PromoteIntOp_VP_LOAD(cast<VPLoadSDNode>(N), OpNo);
+    break;
   case ISD::MLOAD:        Res = PromoteIntOp_MLOAD(cast<MaskedLoadSDNode>(N),
                                                     OpNo); break;
   case ISD::MGATHER:  Res = PromoteIntOp_MGATHER(cast<MaskedGatherSDNode>(N),
@@ -1885,6 +1919,55 @@
                            N->getMemoryVT(), N->getMemOperand());
 }
 
+/// Promote an illegal integer operand of a VP_STORE. Operand 1 is the stored
+/// value, which is promoted and written back with a truncating store. Operand
+/// 4 is the mask and any later operand is the explicit vector length; those
+/// are promoted and the node is updated in place.
+SDValue DAGTypeLegalizer::PromoteIntOp_VP_STORE(VPStoreSDNode *N,
+                                                unsigned OpNo) {
+  SDValue DataOp = N->getValue();
+  SDValue Operand = N->getOperand(OpNo);
+
+  if (OpNo >= 4) {
+    // The Mask or EVL. Update in place.
+    EVT DataVT = DataOp.getValueType();
+    SDValue PromotedOperand = OpNo == 4 ? PromoteTargetBoolean(Operand, DataVT)
+                                        : GetPromotedInteger(Operand);
+    SmallVector<SDValue, 6> NewOps(N->op_begin(), N->op_end());
+    NewOps[OpNo] = PromotedOperand;
+    return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
+  }
+
+  assert(OpNo == 1 && "Unexpected operand for promotion");
+  DataOp = GetPromotedInteger(DataOp);
+
+  return DAG.getTruncStoreVP(N->getChain(), SDLoc(N), DataOp, N->getBasePtr(),
+                             N->getMask(), N->getVectorLength(),
+                             N->getMemoryVT(), N->getMemOperand(),
+                             N->isCompressingStore());
+}
+
+/// Promote the mask (operand 3) or the explicit vector length of a VP_LOAD in
+/// place. Returns a null SDValue when updating the operands CSE'd onto an
+/// existing node, after replacing both results of N itself.
+SDValue DAGTypeLegalizer::PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo) {
+  assert(OpNo >= 3 && "Only know how to promote the mask or length!");
+  EVT DataVT = N->getValueType(0);
+  SDValue Operand = N->getOperand(OpNo);
+  SDValue PromotedOperand = OpNo == 3 ? PromoteTargetBoolean(Operand, DataVT)
+                                      : GetPromotedInteger(Operand);
+  SmallVector<SDValue, 5> NewOps(N->op_begin(), N->op_end());
+  NewOps[OpNo] = PromotedOperand;
+  SDNode *Res = DAG.UpdateNodeOperands(N, NewOps);
+  if (Res == N)
+    return SDValue(Res, 0);
+
+  // Update triggered CSE, do our own replacement since caller can't.
+  ReplaceValueWith(SDValue(N, 0), SDValue(Res, 0));
+  ReplaceValueWith(SDValue(N, 1), SDValue(Res, 1));
+  return SDValue();
+}
+
 SDValue DAGTypeLegalizer::PromoteIntOp_MSTORE(MaskedStoreSDNode *N,
                                               unsigned OpNo) {
   SDValue DataOp = N->getValue();
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -324,6 +324,7 @@
   SDValue PromoteIntRes_FREEZE(SDNode *N);
   SDValue PromoteIntRes_INT_EXTEND(SDNode *N);
   SDValue PromoteIntRes_LOAD(LoadSDNode *N);
+  SDValue PromoteIntRes_VP_LOAD(VPLoadSDNode *N);
   SDValue PromoteIntRes_MLOAD(MaskedLoadSDNode *N);
   SDValue PromoteIntRes_MGATHER(MaskedGatherSDNode *N);
   SDValue PromoteIntRes_Overflow(SDNode *N);
@@ -384,7 +385,9 @@
   SDValue PromoteIntOp_UINT_TO_FP(SDNode *N);
   SDValue PromoteIntOp_STRICT_UINT_TO_FP(SDNode *N);
   SDValue PromoteIntOp_ZERO_EXTEND(SDNode *N);
+  SDValue PromoteIntOp_VP_STORE(VPStoreSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MSTORE(MaskedStoreSDNode *N, unsigned OpNo);
+  SDValue PromoteIntOp_VP_LOAD(VPLoadSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MSCATTER(MaskedScatterSDNode *N, unsigned OpNo);
   SDValue PromoteIntOp_MGATHER(MaskedGatherSDNode *N, unsigned OpNo);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7371,7 +7371,7 @@
         MachinePointerInfo(PtrOperand), MachineMemOperand::MOLoad,
         VT.getStoreSize().getKnownMinSize(), *Alignment, AAInfo, Ranges);
     LD = DAG.getLoadVP(VT, DL, InChain, OpValues[0], OpValues[1], OpValues[2],
-                       MMO, false /*IsExpanding */);
+                       MMO);
   } else {
     unsigned AS =
         PtrOperand->getType()->getScalarType()->getPointerAddressSpace();