Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16770,14 +16770,117 @@
   return true;
 }
 
+static SDValue tryToCombineToMaskedLoadStore(MaskedGatherScatterSDNode *MGS,
+                                             SelectionDAG &DAG,
+                                             const AArch64Subtarget *Subtarget) {
+  auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS);
+  auto *MSC = dyn_cast<MaskedScatterSDNode>(MGS);
+  if (MGT && !MGT->getPassThru().isUndef())
+    return SDValue();
+
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+
+  EVT MemVT = MGS->getMemoryVT();
+  EVT MaskVT = Mask.getValueType();
+  EVT BaseVT = BasePtr.getValueType();
+
+  EVT VT = MGT ? MGT->getValueType(0) : MSC->getValue().getValueType();
+
+  // The caller guarantees Index is a STEP_VECTOR, so its step operand is a
+  // ConstantSDNode.
+  auto Step = cast<ConstantSDNode>(Index->getOperand(0))->getSExtValue();
+
+  if (Step != 2 ||
+      Subtarget->preferGatherScatter(MGS->getOpcode(), MemVT, Step))
+    return SDValue();
+
+  unsigned ShiftAmt;
+  if (MemVT == MVT::nxv2i64 || MemVT == MVT::nxv2f64)
+    ShiftAmt = 3;
+  else if (MemVT == MVT::nxv4i32 || MemVT == MVT::nxv4f32)
+    ShiftAmt = 2;
+  else
+    return SDValue();
+
+  MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+      MGS->getPointerInfo(), MGS->getMemOperand()->getFlags(),
+      MemoryLocation::UnknownSize, MGS->getOriginalAlign(), MGS->getAAInfo(),
+      MGS->getRanges());
+
+  auto MPI = MachinePointerInfo(MGS->getPointerInfo().getAddrSpace());
+
+  // Split the mask into two parts and interleave with false.
+  auto PFalse = SDValue(DAG.getMachineNode(AArch64::PFALSE, DL, MaskVT), 0);
+  auto PredL = DAG.getNode(AArch64ISD::ZIP1, DL, MaskVT, Mask, PFalse);
+  auto PredH = DAG.getNode(AArch64ISD::ZIP2, DL, MaskVT, Mask, PFalse);
+
+  auto EltCount =
+      DAG.getVScale(DL, MVT::i64,
+                    APInt(VT.getScalarSizeInBits(),
+                          VT.getVectorMinNumElements() << ShiftAmt));
+
+  if (MGT) {
+    // Perform the actual loads, making sure we use the chain value
+    // from the first in the second.
+    auto LoadL = DAG.getMaskedLoad(VT, DL, Chain, BasePtr, DAG.getUNDEF(BaseVT),
+                                   PredL, MGT->getPassThru(), MemVT, MMO,
+                                   ISD::UNINDEXED, MGT->getExtensionType());
+
+    MMO = DAG.getMachineFunction().getMachineMemOperand(
+        MPI, MGT->getMemOperand()->getFlags(), MemoryLocation::UnknownSize,
+        MGT->getOriginalAlign(), MGT->getAAInfo(), MGT->getRanges());
+
+    BasePtr = DAG.getNode(ISD::ADD, DL, BaseVT, BasePtr, EltCount);
+    auto LoadH = DAG.getMaskedLoad(VT, DL, LoadL.getValue(1), BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredH,
+                                   MGT->getPassThru(), MemVT, MMO,
+                                   ISD::UNINDEXED, MGT->getExtensionType());
+
+    // Combine the loaded lanes into the single vector we'd get from
+    // the original gather.
+    auto Res = DAG.getNode(AArch64ISD::UZP1, DL, VT, LoadL.getValue(0),
+                           LoadH.getValue(0));
+
+    // Make sure we return both the loaded values and the chain from the
+    // last load.
+    return DAG.getMergeValues({Res, LoadH.getValue(1)}, DL);
+  }
+
+  SDValue Data = MSC->getValue();
+
+  // As with the mask, split the data into two parts and interleave.
+  // Here we can just use the data itself for the other lanes, since
+  // the inactive lanes won't be stored.
+  auto DataL = DAG.getNode(AArch64ISD::ZIP1, DL, VT, Data, Data);
+  auto DataH = DAG.getNode(AArch64ISD::ZIP2, DL, VT, Data, Data);
+
+  // Perform the actual stores, making sure we use the chain value
+  // from the first in the second.
+  auto StoreL = DAG.getMaskedStore(Chain, DL, DataL, BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredL, MemVT, MMO,
+                                   ISD::UNINDEXED, MSC->isTruncatingStore());
+
+  MMO = DAG.getMachineFunction().getMachineMemOperand(
+      MPI, MSC->getMemOperand()->getFlags(), MemoryLocation::UnknownSize,
+      MSC->getOriginalAlign(), MSC->getAAInfo(), MSC->getRanges());
+
+  BasePtr = DAG.getNode(ISD::ADD, DL, BaseVT, BasePtr, EltCount);
+  auto StoreH = DAG.getMaskedStore(StoreL, DL, DataH, BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredH, MemVT, MMO,
+                                   ISD::UNINDEXED, MSC->isTruncatingStore());
+
+  return StoreH;
+}
+
 static SDValue performMaskedGatherScatterCombine(
-    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG,
+    const AArch64Subtarget *Subtarget) {
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
   assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  if (!DCI.isBeforeLegalize())
-    return SDValue();
-
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
@@ -16785,25 +16888,34 @@
   SDValue Mask = MGS->getMask();
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
+  EVT MemVT = MGS->getMemoryVT();
+
+  if (DCI.isBeforeLegalize() &&
+      findMoreOptimalIndexType(MGS, BasePtr, Index, DAG)) {
+    // Here we catch such cases early and change MGATHER's IndexType to allow
+    // the use of an Index that's more legalisation friendly.
+    if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+      SDValue PassThru = MGT->getPassThru();
+      SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+      return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                 MemVT, DL, Ops, MGT->getMemOperand(),
+                                 IndexType, MGT->getExtensionType());
+    }
+    auto *MSC = cast<MaskedScatterSDNode>(MGS);
+    SDValue Data = MSC->getValue();
+    SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
-    return SDValue();
+  // Attempt to transform a gather or scatter with a stride of 2 into
+  // a pair of contiguous loads or stores.
+  if (Index->getOpcode() == ISD::STEP_VECTOR &&
+      IndexType == ISD::SIGNED_SCALED)
+    return tryToCombineToMaskedLoadStore(MGS, DAG, Subtarget);
 
-  // Here we catch such cases early and change MGATHER's IndexType to allow
-  // the use of an Index that's more legalisation friendly.
-  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
-    SDValue PassThru = MGT->getPassThru();
-    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedGather(
-        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
-        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
-  }
-  auto *MSC = cast<MaskedScatterSDNode>(MGS);
-  SDValue Data = MSC->getValue();
-  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
-                              Ops, MSC->getMemOperand(), IndexType,
-                              MSC->isTruncatingStore());
+
+  return SDValue();
 }
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
@@ -18354,7 +18466,7 @@
     return performSTORECombine(N, DCI, DAG, Subtarget);
   case ISD::MGATHER:
   case ISD::MSCATTER:
-    return performMaskedGatherScatterCombine(N, DCI, DAG);
+    return performMaskedGatherScatterCombine(N, DCI, DAG, Subtarget);
   case ISD::VECTOR_SPLICE:
     return performSVESpliceCombine(N, DAG);
   case ISD::FP_EXTEND:
Index: llvm/lib/Target/AArch64/AArch64Subtarget.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -666,6 +666,11 @@
 
   void mirFileLoaded(MachineFunction &MF) const override;
 
+  // Returns true if it is preferable to keep the given gather or scatter on
+  // the current subtarget; otherwise we may attempt to replace it with a
+  // pair of contiguous loads/stores instead.
+  bool preferGatherScatter(unsigned Opcode, EVT MemTy, unsigned Stride) const;
+
   // Return the known range for the bit length of SVE data registers. A value
   // of 0 means nothing is known about that particular limit beyond what's
   // implied by the architecture.
Index: llvm/lib/Target/AArch64/AArch64Subtarget.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64Subtarget.cpp
+++ llvm/lib/Target/AArch64/AArch64Subtarget.cpp
@@ -361,6 +361,27 @@
   return false;
 }
 
+bool AArch64Subtarget::preferGatherScatter(unsigned Opcode, EVT MemTy,
+                                           unsigned Stride) const {
+  if (Stride != 2 || !MemTy.isSimple())
+    return true;
+
+  if (ARMProcFamily == NeoverseV1) {
+    switch (MemTy.getSimpleVT().SimpleTy) {
+    case MVT::nxv4i32:
+    case MVT::nxv4f32:
+      return false;
+    case MVT::nxv2i64:
+    case MVT::nxv2f64:
+      return Opcode == ISD::MSCATTER;
+    default:
+      return true;
+    }
+  }
+
+  return true;
+}
+
 std::unique_ptr<PBQPRAConstraint>
 AArch64Subtarget::getCustomPBQPConstraints() const {
   return balanceFPOps() ? std::make_unique<A57ChainingConstraint>() : nullptr;
Index: llvm/test/CodeGen/AArch64/sve-gather-scatter-to-contiguous.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-gather-scatter-to-contiguous.ll
@@ -0,0 +1,168 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mcpu=neoverse-v1 < %s | FileCheck %s
+
+; Gather -> Contiguous Load
+
+define <vscale x 4 x i32> @gather_stride2_4i32(i32* %addr, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: gather_stride2_4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p2/z, [x0]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 4 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 4 x i32> %1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %stepvector = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %indices = mul <vscale x 4 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 4 x i32> %indices
+  %vals = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %vals
+}
+
+define <vscale x 2 x i64> @gather_stride2_2i64(i64* %addr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: gather_stride2_2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.d, p0.d, p1.d
+; CHECK-NEXT:    zip2 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p2/z, [x0]
+; CHECK-NEXT:    ld1d { z1.d }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    uzp1 z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i64> undef, i64 2, i32 0
+  %2 = shufflevector <vscale x 2 x i64> %1, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %indices = mul <vscale x 2 x i64> %2, %stepvector
+  %ptrs = getelementptr i64, i64* %addr, <vscale x 2 x i64> %indices
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
+define <vscale x 2 x double> @gather_stride2_2f64(double* %addr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: gather_stride2_2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.d, p0.d, p1.d
+; CHECK-NEXT:    zip2 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p2/z, [x0]
+; CHECK-NEXT:    ld1d { z1.d }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    uzp1 z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i64> undef, i64 2, i32 0
+  %2 = shufflevector <vscale x 2 x i64> %1, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %indices = mul <vscale x 2 x i64> %2, %stepvector
+  %ptrs = getelementptr double, double* %addr, <vscale x 2 x i64> %indices
+  %vals = call <vscale x 2 x double> @llvm.masked.gather.nxv2f64(<vscale x 2 x double*> %ptrs, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+  ret <vscale x 2 x double> %vals
+}
+
+; Do not combine an extending gather to a pair of contiguous loads
+define <vscale x 2 x i64> @extending_gather_stride2_2i32(i32* %addr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: extending_gather_stride2_2i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z0.d, #0, #2
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, sxtw #2]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 2 x i32> %1, <vscale x 2 x i32> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i32> @llvm.experimental.stepvector.nxv2i32()
+  %indices = mul <vscale x 2 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 2 x i32> %indices
+  %vals = call <vscale x 2 x i32> @llvm.masked.gather.nxv2i32(<vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %vals.zext = zext <vscale x 2 x i32> %vals to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %vals.zext
+}
+
+; Scatter -> Contiguous Store
+
+define void @scatter_stride2_4i32(i32* %addr, <vscale x 4 x i32> %data, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: scatter_stride2_4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 z1.s, z0.s, z0.s
+; CHECK-NEXT:    zip2 z0.s, z0.s, z0.s
+; CHECK-NEXT:    zip1 p2.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 p0.s, p0.s, p1.s
+; CHECK-NEXT:    st1w { z1.s }, p2, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 4 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 4 x i32> %1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %stepvector = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %indices = mul <vscale x 4 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 4 x i32> %indices
+  call void @llvm.masked.scatter.nxv4i32(<vscale x 4 x i32> %data, <vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+define void @scatter_stride2_4f32(float* %addr, <vscale x 4 x float> %data, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: scatter_stride2_4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 z1.s, z0.s, z0.s
+; CHECK-NEXT:    zip2 z0.s, z0.s, z0.s
+; CHECK-NEXT:    zip1 p2.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 p0.s, p0.s, p1.s
+; CHECK-NEXT:    st1w { z1.s }, p2, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 4 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 4 x i32> %1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %stepvector = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %indices = mul <vscale x 4 x i32> %2, %stepvector
+  %ptrs = getelementptr float, float* %addr, <vscale x 4 x i32> %indices
+  call void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float> %data, <vscale x 4 x float*> %ptrs, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+; Contiguous stores are not beneficial over 64 bit scatters here
+define void @scatter_stride2_2i64(i64* %addr, <vscale x 2 x i64> %data, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: scatter_stride2_2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #2
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d, lsl #3]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i64> undef, i64 2, i64 0
+  %2 = shufflevector <vscale x 2 x i64> %1, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %indices = mul <vscale x 2 x i64> %2, %stepvector
+  %ptrs = getelementptr i64, i64* %addr, <vscale x 2 x i64> %indices
+  call void @llvm.masked.scatter.nxv2i64(<vscale x 2 x i64> %data, <vscale x 2 x i64*> %ptrs, i32 4, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+; Do not combine a truncating scatter to a pair of contiguous stores
+define void @truncating_scatter_stride2_2i32(i32* %addr, <vscale x 2 x i32> %data, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: truncating_scatter_stride2_2i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #2
+; CHECK-NEXT:    st1w { z0.d }, p0, [x0, z1.d, sxtw #2]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 2 x i32> %1, <vscale x 2 x i32> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i32> @llvm.experimental.stepvector.nxv2i32()
+  %indices = mul <vscale x 2 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 2 x i32> %indices
+  call void @llvm.masked.scatter.nxv2i32(<vscale x 2 x i32> %data, <vscale x 2 x i32*> %ptrs, i32 4, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+declare <vscale x 2 x i32> @llvm.experimental.stepvector.nxv2i32()
+declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+declare <vscale x 4 x float> @llvm.experimental.stepvector.nxv4f32()
+declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+declare <vscale x 2 x double> @llvm.experimental.stepvector.nxv2f64()
+
+declare <vscale x 2 x i32> @llvm.masked.gather.nxv2i32(<vscale x 2 x i32*>, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <vscale x 2 x double> @llvm.masked.gather.nxv2f64(<vscale x 2 x double*>, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
+
+declare void @llvm.masked.scatter.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32*>, i32, <vscale x 2 x i1>)
+declare void @llvm.masked.scatter.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64*>, i32, <vscale x 2 x i1>)