Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -782,6 +782,7 @@
   if (Subtarget->supportsAddressTopByteIgnored())
     setTargetDAGCombine(ISD::LOAD);
 
+  setTargetDAGCombine(ISD::MGATHER);
   setTargetDAGCombine(ISD::MSCATTER);
 
   setTargetDAGCombine(ISD::MUL);
@@ -13996,20 +13997,19 @@
   return SDValue();
 }
 
-static SDValue performMSCATTERCombine(SDNode *N,
+static SDValue performMaskedGatherScatterCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI,
                                       SelectionDAG &DAG) {
-  MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(N);
-  assert(MSC && "Can only combine scatter store nodes");
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  SDLoc DL(MSC);
-  SDValue Chain = MSC->getChain();
-  SDValue Scale = MSC->getScale();
-  SDValue Index = MSC->getIndex();
-  SDValue Data = MSC->getValue();
-  SDValue Mask = MSC->getMask();
-  SDValue BasePtr = MSC->getBasePtr();
-  ISD::MemIndexType IndexType = MSC->getIndexType();
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
 
   EVT IdxVT = Index.getValueType();
 
@@ -14019,16 +14019,27 @@
   if ((IdxVT.getVectorElementType() == MVT::i8) ||
       (IdxVT.getVectorElementType() == MVT::i16)) {
     EVT NewIdxVT = IdxVT.changeVectorElementType(MVT::i32);
-    if (MSC->isIndexSigned())
+    if (MGS->isIndexSigned())
       Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
     else
       Index = DAG.getNode(ISD::ZERO_EXTEND, DL, NewIdxVT, Index);
 
-    SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
-    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
-                                MSC->getMemoryVT(), DL, Ops,
-                                MSC->getMemOperand(), IndexType,
-                                MSC->isTruncatingStore());
+    if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+      SDValue PassThru = MGT->getPassThru();
+      SDValue Ops[] = { Chain, PassThru, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                 PassThru.getValueType(), DL, Ops,
+                                 MGT->getMemOperand(),
+                                 MGT->getIndexType(), MGT->getExtensionType());
+    } else {
+      auto *MSC = cast<MaskedScatterSDNode>(MGS);
+      SDValue Data = MSC->getValue();
+      SDValue Ops[] = { Chain, Data, Mask, BasePtr, Index, Scale };
+      return DAG.getMaskedScatter(DAG.getVTList(MVT::Other),
+                                  MSC->getMemoryVT(), DL, Ops,
+                                  MSC->getMemOperand(), IndexType,
+                                  MSC->isTruncatingStore());
+    }
   }
 }
 
@@ -15005,9 +15016,6 @@
 static SDValue performSignExtendInRegCombine(SDNode *N,
                                              TargetLowering::DAGCombinerInfo &DCI,
                                              SelectionDAG &DAG) {
-  if (DCI.isBeforeLegalizeOps())
-    return SDValue();
-
   SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
@@ -15042,6 +15050,9 @@
     return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
   }
 
+  if (DCI.isBeforeLegalizeOps())
+    return SDValue();
+
   // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
   // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
   unsigned NewOpc;
@@ -15226,8 +15237,9 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMSCATTERCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case AArch64ISD::BRCOND:
     return performBRCONDCombine(N, DCI, DAG);
   case AArch64ISD::TBNZ:
Index: llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -0,0 +1,52 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Tests that exercise various type legalisation scenarios for ISD::MGATHER.
+
+; Code generate the worst case scenario when all vector types are legal.
+define <vscale x 16 x i8> @masked_gather_nxv16i8(i8* %base, <vscale x 16 x i8> %indices, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv16i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, {{z[0-9]+}}.s, sxtw]
+; CHECK: ret
+  %ptrs = getelementptr i8, i8* %base, <vscale x 16 x i8> %indices
+  %data = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*> %ptrs, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %data
+}
+
+; Code generate the worst case scenario when all vector types are illegal.
+define <vscale x 32 x i32> @masked_gather_nxv32i32(i32* %base, <vscale x 32 x i32> %indices, <vscale x 32 x i1> %mask) {
+; CHECK-LABEL: masked_gather_nxv32i32:
+; CHECK-NOT: unpkhi
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z0.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z1.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z2.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z3.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z4.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z5.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z6.s, sxtw #2]
+; CHECK-DAG: ld1w { {{z[0-9]+}}.s }, {{p[0-9]+}}/z, [x0, z7.s, sxtw #2]
+; CHECK: ret
+  %ptrs = getelementptr i32, i32* %base, <vscale x 32 x i32> %indices
+  %data = call <vscale x 32 x i32> @llvm.masked.gather.nxv32i32(<vscale x 32 x i32*> %ptrs, i32 4, <vscale x 32 x i1> %mask, <vscale x 32 x i32> undef)
+  ret <vscale x 32 x i32> %data
+}
+
+; TODO: Currently, the sign extend gets applied to the values after a 'uzp1' of two
+; registers, so it doesn't get folded away. Same for any other vector-of-pointers
+; style gathers which don't fit in a single register. Better folding
+; is required before we can check those off.
+define <vscale x 4 x i32> @masked_sgather_nxv4i8(<vscale x 4 x i8*> %ptrs, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sgather_nxv4i8:
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.d }, {{p[0-9]+}}/z, [x8, {{z[0-9]+}}.d]
+; CHECK-DAG: ld1sb { {{z[0-9]+}}.d }, {{p[0-9]+}}/z, [x8, {{z[0-9]+}}.d]
+  %vals = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %ptrs, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %svals = sext <vscale x 4 x i8> %vals to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %svals
+}
+
+declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+
+declare <vscale x 16 x i8> @llvm.masked.gather.nxv16i8(<vscale x 16 x i8*>, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+declare <vscale x 32 x i32> @llvm.masked.gather.nxv32i32(<vscale x 32 x i32*>, i32, <vscale x 32 x i1>, <vscale x 32 x i32>)
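
A minimal sketch of the base-plus-offsets gather shape the new combine targets (not part of the patch; the function name and constants are invented, and the intrinsic is declared in the same style as the test above). Because the getelementptr offsets have i16 elements, performMaskedGatherScatterCombine rebuilds the MGATHER with its index vector widened to i32 elements (sign- or zero-extended to match the node's index type) before SVE lowering picks an addressing mode.

; Illustrative only (assumed names; not from the patch): i16 offsets that the
; combine widens to i32 before the gather is lowered.
define <vscale x 4 x i16> @sketch_gather_i16_offsets(i16* %base, <vscale x 4 x i16> %offsets, <vscale x 4 x i1> %mask) {
  %ptrs = getelementptr i16, i16* %base, <vscale x 4 x i16> %offsets
  %vals = call <vscale x 4 x i16> @llvm.masked.gather.nxv4i16(<vscale x 4 x i16*> %ptrs, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
  ret <vscale x 4 x i16> %vals
}

declare <vscale x 4 x i16> @llvm.masked.gather.nxv4i16(<vscale x 4 x i16*>, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)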