diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1318,6 +1318,10 @@
            getIndexedMaskedStoreAction(IdxMode, VT.getSimpleVT()) == Custom);
   }
 
+  /// Returns true if the index type for a masked gather/scatter requires
+  /// extending
+  virtual bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const { return false; }
+
   // Returns true if VT is a legal index type for masked gathers/scatters
   // on this target
   virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4339,6 +4339,14 @@
     IndexType = ISD::SIGNED_UNSCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
+
+  EVT IdxVT = Index.getValueType();
+  EVT EltTy = IdxVT.getVectorElementType();
+  if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+    EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+    Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+  }
+
   SDValue Ops[] = { getMemoryRoot(), Src0, Mask, Base, Index, Scale };
   SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl,
                                          Ops, MMO, IndexType, false);
@@ -4450,6 +4458,14 @@
     IndexType = ISD::SIGNED_UNSCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
+
+  EVT IdxVT = Index.getValueType();
+  EVT EltTy = IdxVT.getVectorElementType();
+  if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+    EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+    Index = DAG.getNode(ISD::SIGN_EXTEND, sdl, NewIdxVT, Index);
+  }
+
   SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
   SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
                                        Ops, MMO, IndexType, ISD::NON_EXTLOAD);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -996,6 +996,7 @@
     return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
   }
 
+  bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
   bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -873,9 +873,6 @@
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
-  setTargetDAGCombine(ISD::MGATHER);
-  setTargetDAGCombine(ISD::MSCATTER);
-
   setTargetDAGCombine(ISD::MUL);
 
   setTargetDAGCombine(ISD::SELECT);
@@ -3825,6 +3822,15 @@
   }
 }
 
+bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
+  if (VT.getVectorElementType() == MVT::i8 ||
+      VT.getVectorElementType() == MVT::i16) {
+    EltTy = MVT::i32;
+    return true;
+  }
+  return false;
+}
+
 bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
   if (VT.getVectorElementType() == MVT::i32 &&
       VT.getVectorElementCount().getKnownMinValue() >= 4)
@@ -14395,55 +14401,6 @@
   return SDValue();
 }
 
-static SDValue performMaskedGatherScatterCombine(SDNode *N,
-                                        TargetLowering::DAGCombinerInfo &DCI,
-                                                 SelectionDAG &DAG) {
-  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
-  assert(MGS && "Can only combine gather load or scatter store nodes");
-
-  SDLoc DL(MGS);
-  SDValue Chain = MGS->getChain();
-  SDValue Scale = MGS->getScale();
-  SDValue Index = MGS->getIndex();
-  SDValue Mask = MGS->getMask();
-  SDValue BasePtr = MGS->getBasePtr();
-  ISD::MemIndexType IndexType = MGS->getIndexType();
-
-  EVT IdxVT = Index.getValueType();
-
-  if (DCI.isBeforeLegalize()) {
-    // SVE gather/scatter requires indices of i32/i64. Promote anything smaller
-    // prior to legalisation so the result can be split if required.
-    if ((IdxVT.getVectorElementType() == MVT::i8) ||
-        (IdxVT.getVectorElementType() == MVT::i16)) {
-      EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-      if (MGS->isIndexSigned())
-        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
-      else
-        Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
-
-      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
-        SDValue PassThru = MGT->getPassThru();
-        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
-        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
-                                   PassThru.getValueType(), DL, Ops,
-                                   MGT->getMemOperand(),
-                                   MGT->getIndexType(), MGT->getExtensionType());
-      } else {
-        auto *MSC = cast<MaskedScatterSDNode>(MGS);
-        SDValue Data = MSC->getValue();
-        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                    MSC->getMemoryVT(), DL, Ops,
-                                    MSC->getMemOperand(), IndexType,
-                                    MSC->isTruncatingStore());
-      }
-    }
-  }
-
-  return SDValue();
-}
-
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -15638,9 +15595,6 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
-  case ISD::MGATHER:
-  case ISD::MSCATTER:
-    return performMaskedGatherScatterCombine(N, DCI, DAG);
  case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
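
Note (illustrative, not part of the patch): with this change a target opts in
to gather/scatter index widening by overriding the new TargetLowering hook, and
SelectionDAGBuilder sign-extends the index vector before it builds the
MGATHER/MSCATTER node. A minimal sketch for a hypothetical out-of-tree
"MyTargetLowering", mirroring the AArch64 override added above:

  // Hypothetical override: request widening of i8/i16 index vectors and
  // report the element type to widen to through the EltTy out-parameter.
  // SelectionDAGBuilder performs the actual ISD::SIGN_EXTEND.
  bool MyTargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
    if (VT.getVectorElementType() == MVT::i8 ||
        VT.getVectorElementType() == MVT::i16) {
      EltTy = MVT::i32; // Widen sub-i32 indices before type legalisation.
      return true;
    }
    return false;
  }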
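Design note: performing the extension when the node is first created, rather
than in a target DAG combine, means the promoted index exists before type
legalisation and can still be split if required (as the removed combine's
comment explains). AArch64 therefore no longer needs to register
MGATHER/MSCATTER combines purely for index promotion, which is why
performMaskedGatherScatterCombine and its setTargetDAGCombine registrations
are deleted.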