Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16588,14 +16588,110 @@
   return true;
 }
 
+static SDValue tryToCombineToMaskedLoadStore(MaskedGatherScatterSDNode *MGS,
+                                             SelectionDAG &DAG) {
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+
+  EVT MemVT = MGS->getMemoryVT();
+  EVT MaskVT = Mask.getValueType();
+  EVT BaseVT = BasePtr.getValueType();
+
+  auto Step = cast<ConstantSDNode>(Index->getOperand(0))->getSExtValue();
+
+  if (Step != 2)
+    return SDValue();
+
+  unsigned ShiftAmt;
+  if (MemVT == MVT::nxv2i64 || MemVT == MVT::nxv2f64)
+    ShiftAmt = 3;
+  else if (MemVT == MVT::nxv4i32 || MemVT == MVT::nxv4f32)
+    ShiftAmt = 2;
+  else
+    return SDValue();
+
+  // Split the mask into two parts and interleave with false.
+  auto PFalse = SDValue(DAG.getMachineNode(AArch64::PFALSE, DL, MaskVT), 0);
+  auto PredL = DAG.getNode(AArch64ISD::ZIP1, DL, MaskVT, Mask, PFalse);
+  auto PredH = DAG.getNode(AArch64ISD::ZIP2, DL, MaskVT, Mask, PFalse);
+
+  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+    SDValue PassThrough = MGT->getPassThru();
+    if (!PassThrough.isUndef())
+      return SDValue();
+
+    EVT VT = MGT->getValueType(0);
+    auto EltCount =
+        DAG.getVScale(DL, MVT::i64,
+                      APInt(VT.getScalarSizeInBits(),
+                            VT.getVectorMinNumElements() << ShiftAmt));
+
+    // Perform the actual loads, making sure we use the chain value
+    // from the first in the second.
+    auto LoadL = DAG.getMaskedLoad(VT, DL, Chain, BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredL, PassThrough,
+                                   MemVT, MGT->getMemOperand(),
+                                   ISD::UNINDEXED, MGT->getExtensionType());
+
+    BasePtr = DAG.getNode(ISD::ADD, DL, BaseVT, BasePtr, EltCount);
+    auto LoadH = DAG.getMaskedLoad(VT, DL, LoadL.getValue(1), BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredH, PassThrough,
+                                   MemVT, MGT->getMemOperand(),
+                                   ISD::UNINDEXED, MGT->getExtensionType());
+
+    // Combine the loaded lanes into the single vector we'd get from
+    // the original gather.
+    auto Res = DAG.getNode(AArch64ISD::UZP1, DL, VT, LoadL.getValue(0),
+                           LoadH.getValue(0));
+
+    // Make sure we return both the loaded values and the chain from the
+    // last load.
+    return DAG.getMergeValues({Res, LoadH.getValue(1)}, DL);
+  }
+
+  auto *MSC = cast<MaskedScatterSDNode>(MGS);
+  SDValue Data = MSC->getValue();
+  EVT VT = Data.getValueType();
+
+  // Using contiguous stores is not beneficial over 64-bit scatters.
+  if (VT.getVectorElementType().getSizeInBits() == 64)
+    return SDValue();
+
+  auto EltCount =
+      DAG.getVScale(DL, MVT::i64,
+                    APInt(VT.getScalarSizeInBits(),
+                          VT.getVectorMinNumElements() << ShiftAmt));
+
+  // As with the mask, split the data into two parts and interleave;
+  // here we can just use the data itself for the other lanes, since
+  // the inactive lanes won't be stored.
+  auto DataL = DAG.getNode(AArch64ISD::ZIP1, DL, VT, Data, Data);
+  auto DataH = DAG.getNode(AArch64ISD::ZIP2, DL, VT, Data, Data);
+
+  // Perform the actual stores, making sure we use the chain value
+  // from the first in the second.
+  auto StoreL = DAG.getMaskedStore(Chain, DL, DataL, BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredL, MemVT,
+                                   MSC->getMemOperand(), ISD::UNINDEXED,
+                                   MSC->isTruncatingStore());
+
+  BasePtr = DAG.getNode(ISD::ADD, DL, BaseVT, BasePtr, EltCount);
+  auto StoreH = DAG.getMaskedStore(StoreL, DL, DataH, BasePtr,
+                                   DAG.getUNDEF(BaseVT), PredH, MemVT,
+                                   MSC->getMemOperand(), ISD::UNINDEXED,
+                                   MSC->isTruncatingStore());
+
+  return StoreH;
+}
+
 static SDValue performMaskedGatherScatterCombine(
     SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
   MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
   assert(MGS && "Can only combine gather load or scatter store nodes");
 
-  if (!DCI.isBeforeLegalize())
-    return SDValue();
-
   SDLoc DL(MGS);
   SDValue Chain = MGS->getChain();
   SDValue Scale = MGS->getScale();
@@ -16603,25 +16699,34 @@
   SDValue Mask = MGS->getMask();
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
+  EVT MemVT = MGS->getMemoryVT();
+
+  if (DCI.isBeforeLegalize() &&
+      findMoreOptimalIndexType(MGS, BasePtr, Index, DAG)) {
+    // Here we catch such cases early and change MGATHER's IndexType to allow
+    // the use of an Index that's more legalisation friendly.
+    if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+      SDValue PassThru = MGT->getPassThru();
+      SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+      return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                                 MemVT, DL, Ops, MGT->getMemOperand(),
+                                 IndexType, MGT->getExtensionType());
+    }
+    auto *MSC = cast<MaskedScatterSDNode>(MGS);
+    SDValue Data = MSC->getValue();
+    SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL, Ops,
+                                MSC->getMemOperand(), IndexType,
+                                MSC->isTruncatingStore());
+  }
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
-    return SDValue();
+  // Attempt to transform a gather or scatter with a stride of 2 into
+  // a pair of contiguous loads or stores.
+  if (Index->getOpcode() == ISD::STEP_VECTOR &&
+      IndexType == ISD::SIGNED_SCALED)
+    return tryToCombineToMaskedLoadStore(MGS, DAG);
 
-  // Here we catch such cases early and change MGATHER's IndexType to allow
-  // the use of an Index that's more legalisation friendly.
-  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
-    SDValue PassThru = MGT->getPassThru();
-    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
-    return DAG.getMaskedGather(
-        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
-        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
-  }
-  auto *MSC = cast<MaskedScatterSDNode>(MGS);
-  SDValue Data = MSC->getValue();
-  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
-  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
-                              Ops, MSC->getMemOperand(), IndexType,
-                              MSC->isTruncatingStore());
+  return SDValue();
 }
 
 /// Target-specific DAG combine function for NEON load/store intrinsics
Index: llvm/test/CodeGen/AArch64/sve-gather-scatter-to-contiguous.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-gather-scatter-to-contiguous.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+; Gather -> Contiguous Load
+
+define <vscale x 4 x i32> @gather_stride2_4i32(i32* %addr, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: gather_stride2_4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 p0.s, p0.s, p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p2/z, [x0]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 4 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 4 x i32> %1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %stepvector = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %indices = mul <vscale x 4 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 4 x i32> %indices
+  %vals = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %vals
+}
+
+define <vscale x 2 x i64> @gather_stride2_2i64(i64* %addr, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: gather_stride2_2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 p2.d, p0.d, p1.d
+; CHECK-NEXT:    zip2 p0.d, p0.d, p1.d
+; CHECK-NEXT:    ld1d { z0.d }, p2/z, [x0]
+; CHECK-NEXT:    ld1d { z1.d }, p0/z, [x0, #1, mul vl]
+; CHECK-NEXT:    uzp1 z0.d, z0.d, z1.d
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i64> undef, i64 2, i32 0
+  %2 = shufflevector <vscale x 2 x i64> %1, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %indices = mul <vscale x 2 x i64> %2, %stepvector
+  %ptrs = getelementptr i64, i64* %addr, <vscale x 2 x i64> %indices
+  %vals = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %ptrs, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %vals
+}
+
+; Scatter -> Contiguous Store
+
+define void @scatter_stride2_4i32(i32* %addr, <vscale x 4 x i32> %data, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: scatter_stride2_4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    pfalse p1.b
+; CHECK-NEXT:    zip1 z1.s, z0.s, z0.s
+; CHECK-NEXT:    zip1 p2.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 p0.s, p0.s, p1.s
+; CHECK-NEXT:    zip2 z0.s, z0.s, z0.s
+; CHECK-NEXT:    st1w { z1.s }, p2, [x0]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, #1, mul vl]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 4 x i32> undef, i32 2, i32 0
+  %2 = shufflevector <vscale x 4 x i32> %1, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
+  %stepvector = call <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+  %indices = mul <vscale x 4 x i32> %2, %stepvector
+  %ptrs = getelementptr i32, i32* %addr, <vscale x 4 x i32> %indices
+  call void @llvm.masked.scatter.nxv4i32(<vscale x 4 x i32> %data, <vscale x 4 x i32*> %ptrs, i32 4, <vscale x 4 x i1> %mask)
+  ret void
+}
+
+; Contiguous stores are not beneficial over 64-bit scatters here.
+define void @scatter_stride2_2i64(i64* %addr, <vscale x 2 x i64> %data, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: scatter_stride2_2i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #2
+; CHECK-NEXT:    st1d { z0.d }, p0, [x0, z1.d, lsl #3]
+; CHECK-NEXT:    ret
+  %1 = insertelement <vscale x 2 x i64> undef, i64 2, i64 0
+  %2 = shufflevector <vscale x 2 x i64> %1, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %stepvector = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %indices = mul <vscale x 2 x i64> %2, %stepvector
+  %ptrs = getelementptr i64, i64* %addr, <vscale x 2 x i64> %indices
+  call void @llvm.masked.scatter.nxv2i64(<vscale x 2 x i64> %data, <vscale x 2 x i64*> %ptrs, i32 4, <vscale x 2 x i1> %mask)
+  ret void
+}
+
+declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
+declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+
+declare <vscale x 4 x i32> @llvm.masked.gather.nxv4i32(<vscale x 4 x i32*>, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+
+declare void @llvm.masked.scatter.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv2i64(<vscale x 2 x i64>, <vscale x 2 x i64*>, i32, <vscale x 2 x i1>)