RISCV: custom-lower ISD::MLOAD/ISD::MSTORE for scalable as well as fixed-length vectors

What the hunks below do:

* RISCVISelLowering.h: the declarations of
  lowerFixedLengthVectorMaskedLoadToRVV and
  lowerFixedLengthVectorMaskedStoreToRVV are replaced by lowerMLOAD and
  lowerMSTORE (same SDValue(SDValue Op, SelectionDAG &DAG) const shape).

* RISCVISelLowering.cpp, operation-action hunks: ISD::MLOAD and
  ISD::MSTORE are marked Custom for `VT`, next to the existing
  MGATHER/MSCATTER Custom actions, in two setOperationAction groups (one
  following the VECREDUCE_UMAX/UMIN actions, one following FCOPYSIGN;
  presumably the integer and floating-point scalable-vector loops, judging
  by those neighbours -- confirm).

* RISCVISelLowering.cpp, fixed-length dispatch switch: the ISD::MLOAD and
  ISD::MSTORE cases now call lowerMLOAD / lowerMSTORE.

* lowerMLOAD / lowerMSTORE: when VT.isFixedLengthVector(), the mask and
  the pass-through value (load) or stored value (store) are converted to
  the scalable container type, and VL is the constant element count, as
  before. Otherwise the operands are used unchanged and VL is the register
  RISCV::X0 (presumably interpreted as VLMAX when the intrinsic is
  selected -- confirm in RISCVISelDAGToDAG). Both functions then build a
  memory-intrinsic node for riscv_vle_mask / riscv_vse_mask. The load
  result is converted back from the container type only in the
  fixed-length case, and it is returned merged with the node's chain.

* RISCVInstrInfoVSDPatterns.td: removes the masked_load/masked_store
  PatFrags, the VPatUSLoadStoreSDNodeMask multiclass, and its
  `foreach vti = AllVectors` instantiation. Presumably these patterns are
  no longer needed because the masked nodes are now custom-lowered before
  instruction selection.

NOTE(review):
* The removed PatFrags only matched non-expanding, non-extending,
  unindexed masked loads, and non-truncating, non-compressing, unindexed
  masked stores. Nothing visible in lowerMLOAD/lowerMSTORE checks those
  properties. Consider asserting them, e.g. !Load->isExpandingLoad(),
  Load->getExtensionType() == ISD::NON_EXTLOAD, Load->isUnindexed(),
  !Store->isTruncatingStore(), !Store->isCompressingStore(), and
  Store->isUnindexed(). An unexpected node would then fail loudly instead
  of being emitted as a plain vle/vse.
* In both functions the `if (VT.isFixedLengthVector())` branch is braced
  but the `else` branch is not. The LLVM coding standards favour braces
  on both branches in that situation.
* This copy of the patch is damaged and will not apply as-is:
  - Its line breaks have been collapsed.
  - Every `<...>` span appears to have been stripped. For example,
    `cast(Op)` was presumably `cast<MaskedLoadSDNode>(Op)` or
    `cast<MaskedStoreSDNode>(Op)`, `!cast("PseudoVLE"...)` lost its
    template argument, and `ComplexPattern;` lost its arguments.
  - Text between `multiclass VPatUSLoadStoreWholeVRSDNode` and
    `-foreach vti = AllVectors in` is missing, apparently including a
    hunk header.
  Regenerate the patch from the source tree rather than editing it by
  hand.

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -475,15 +475,13 @@ SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const; SDValue lowerVECTOR_REVERSE(SDValue Op, SelectionDAG &DAG) const; SDValue lowerABS(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerMLOAD(SDValue Op, SelectionDAG &DAG) const; + SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorFCOPYSIGNToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerMGATHERMSCATTER(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const; - SDValue lowerFixedLengthVectorMaskedLoadToRVV(SDValue Op, - SelectionDAG &DAG) const; - SDValue lowerFixedLengthVectorMaskedStoreToRVV(SDValue Op, - SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const; SDValue lowerFixedLengthVectorLogicOpToRVV(SDValue Op, SelectionDAG &DAG, unsigned MaskOpc, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -474,6 +474,8 @@ setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT, Custom); @@ -517,6 +519,8 @@ setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); setOperationAction(ISD::FCOPYSIGN, VT, Legal); + setOperationAction(ISD::MLOAD, VT, Custom); + setOperationAction(ISD::MSTORE, VT, Custom); setOperationAction(ISD::MGATHER, VT, Custom); setOperationAction(ISD::MSCATTER, VT,
Custom); @@ -1653,9 +1657,9 @@ case ISD::STORE: return lowerFixedLengthVectorStoreToRVV(Op, DAG); case ISD::MLOAD: - return lowerFixedLengthVectorMaskedLoadToRVV(Op, DAG); + return lowerMLOAD(Op, DAG); case ISD::MSTORE: - return lowerFixedLengthVectorMaskedStoreToRVV(Op, DAG); + return lowerMSTORE(Op, DAG); case ISD::SETCC: return lowerFixedLengthVectorSetccToRVV(Op, DAG); case ISD::ADD: @@ -3192,50 +3196,63 @@ Store->getMemoryVT(), Store->getMemOperand()); } -SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedLoadToRVV( - SDValue Op, SelectionDAG &DAG) const { +SDValue RISCVTargetLowering::lowerMLOAD(SDValue Op, SelectionDAG &DAG) const { auto *Load = cast(Op); SDLoc DL(Op); MVT VT = Op.getSimpleValueType(); - MVT ContainerVT = getContainerForFixedLengthVector(VT); - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); MVT XLenVT = Subtarget.getXLenVT(); - SDValue Mask = - convertToScalableVector(MaskVT, Load->getMask(), DAG, Subtarget); - SDValue PassThru = - convertToScalableVector(ContainerVT, Load->getPassThru(), DAG, Subtarget); - SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + SDValue Mask = Load->getMask(); + SDValue PassThru = Load->getPassThru(); + SDValue VL; + + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(VT); + MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); + + Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); + PassThru = convertToScalableVector(ContainerVT, PassThru, DAG, Subtarget); + VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + } else + VL = DAG.getRegister(RISCV::X0, XLenVT); SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other}); SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vle_mask, DL, XLenVT); SDValue Ops[] = {Load->getChain(), IntID, PassThru, Load->getBasePtr(), Mask, VL}; - SDValue NewLoad = + SDValue Result = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN,
DL, VTs, Ops, Load->getMemoryVT(), Load->getMemOperand()); + SDValue Chain = Result.getValue(1); - SDValue Result = convertFromScalableVector(VT, NewLoad, DAG, Subtarget); - return DAG.getMergeValues({Result, NewLoad.getValue(1)}, DL); + if (VT.isFixedLengthVector()) + Result = convertFromScalableVector(VT, Result, DAG, Subtarget); + + return DAG.getMergeValues({Result, Chain}, DL); } -SDValue RISCVTargetLowering::lowerFixedLengthVectorMaskedStoreToRVV( - SDValue Op, SelectionDAG &DAG) const { +SDValue RISCVTargetLowering::lowerMSTORE(SDValue Op, SelectionDAG &DAG) const { auto *Store = cast(Op); SDLoc DL(Op); SDValue Val = Store->getValue(); + SDValue Mask = Store->getMask(); MVT VT = Val.getSimpleValueType(); - MVT ContainerVT = getContainerForFixedLengthVector(VT); - MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); MVT XLenVT = Subtarget.getXLenVT(); + SDValue VL; - Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget); - SDValue Mask = - convertToScalableVector(MaskVT, Store->getMask(), DAG, Subtarget); + MVT ContainerVT = VT; + if (VT.isFixedLengthVector()) { + ContainerVT = getContainerForFixedLengthVector(VT); + MVT MaskVT = MVT::getVectorVT(MVT::i1, ContainerVT.getVectorElementCount()); - SDValue VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + Val = convertToScalableVector(ContainerVT, Val, DAG, Subtarget); + Mask = convertToScalableVector(MaskVT, Mask, DAG, Subtarget); + VL = DAG.getConstant(VT.getVectorNumElements(), DL, XLenVT); + } else + VL = DAG.getRegister(RISCV::X0, XLenVT); SDValue IntID = DAG.getTargetConstant(Intrinsic::riscv_vse_mask, DL, XLenVT); return DAG.getMemIntrinsicNode( diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -33,21 +33,6 @@ def SplatPat_simm5 : ComplexPattern; def SplatPat_uimm5 :
ComplexPattern; -def masked_load : - PatFrag<(ops node:$ptr, node:$mask, node:$maskedoff), - (masked_ld node:$ptr, undef, node:$mask, node:$maskedoff), [{ - return !cast(N)->isExpandingLoad() && - cast(N)->getExtensionType() == ISD::NON_EXTLOAD && - cast(N)->isUnindexed(); -}]>; -def masked_store : - PatFrag<(ops node:$val, node:$ptr, node:$mask), - (masked_st node:$val, node:$ptr, undef, node:$mask), [{ - return !cast(N)->isTruncatingStore() && - !cast(N)->isCompressingStore() && - cast(N)->isUnindexed(); -}]>; - class SwapHelper { dag Value = !con(Prefix, !if(swap, B, A), !if(swap, A, B), Suffix); } @@ -68,25 +53,6 @@ (store_instr reg_class:$rs2, BaseAddr:$rs1, avl, sew)>; } -multiclass VPatUSLoadStoreSDNodeMask -{ - defvar load_instr = !cast("PseudoVLE"#sew#"_V_"#vlmul.MX#"_MASK"); - defvar store_instr = !cast("PseudoVSE"#sew#"_V_"#vlmul.MX#"_MASK"); - // Load - def : Pat<(type (masked_load BaseAddr:$rs1, (mask_type V0), type:$merge)), - (load_instr reg_class:$merge, BaseAddr:$rs1, (mask_type V0), - avl, sew)>; - // Store - def : Pat<(masked_store type:$rs2, BaseAddr:$rs1, (mask_type V0)), - (store_instr reg_class:$rs2, BaseAddr:$rs1, (mask_type V0), - avl, sew)>; -} - multiclass VPatUSLoadStoreWholeVRSDNode; -foreach vti = AllVectors in - defm "" : VPatUSLoadStoreSDNodeMask; foreach vti = [VI8M1, VI16M1, VI32M1, VI64M1, VF16M1, VF32M1, VF64M1] in defm "" : VPatUSLoadStoreWholeVRSDNode;