diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -849,6 +849,7 @@
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::MSCATTER);
 
   setTargetDAGCombine(ISD::MUL);
@@ -14063,20 +14064,19 @@
   return SDValue();
 }
 
-static SDValue performMSCATTERCombine(SDNode *N,
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       SelectionDAG &DAG) {
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  assert(MSC && "Can only combine scatter store nodes");
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  SDLoc DL(MSC);
-  SDValue Chain = MSC->getChain();
-  SDValue Scale = MSC->getScale();
-  SDValue Index = MSC->getIndex();
-  SDValue Data = MSC->getValue();
-  SDValue Mask = MSC->getMask();
-  SDValue BasePtr = MSC->getBasePtr();
-  ISD::MemIndexType IndexType = MSC->getIndexType();
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
 
   EVT IdxVT = Index.getValueType();
 
@@ -14086,16 +14086,27 @@
     if ((IdxVT.getVectorElementType() == MVT::i8) ||
         (IdxVT.getVectorElementType() == MVT::i16)) {
       EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-      if (MSC->isIndexSigned())
+      if (MGS->isIndexSigned())
         Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
       else
         Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
 
-      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                  MSC->getMemoryVT(), DL, Ops,
-                                  MSC->getMemOperand(), IndexType,
-                                  MSC->isTruncatingStore());
+      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+        SDValue PassThru = MGT->getPassThru();
+        SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                   PassThru.getValueType(), DL, Ops,
+                                   MGT->getMemOperand(),
+                                   MGT->getIndexType(), MGT->getExtensionType());
+      } else {
+        auto *MSC = cast<MaskedScatterSDNode>(MGS);
+        SDValue Data = MSC->getValue();
+        SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+        return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                    MSC->getMemoryVT(), DL, Ops,
+                                    MSC->getMemOperand(), IndexType,
+                                    MSC->isTruncatingStore());
+      }
     }
   }
 
@@ -15072,9 +15083,6 @@
 static SDValue performSignExtendInRegCombine(SDNode *N,
                                              TargetLowering::DAGCombinerInfo &DCI,
                                              SelectionDAG &DAG) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -15109,6 +15117,9 @@
     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   if (!EnableCombineMGatherIntrinsics)
     return SDValue();
 
@@ -15296,8 +15307,9 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMSCATTERCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -54,6 +54,19 @@
   ret %data
 }
 
+; Code generate the worst case scenario when all vector types are legal.
+define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv16i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices
+  %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %data
+}
+
 ; Code generate the worst case scenario when all vector types are illegal.
 define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) {
 ; CHECK-LABEL: masked_gather_nxv32i32: