diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -889,6 +889,9 @@
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
   setTargetDAGCombine(ISD::STEP_VECTOR);
 
+  setTargetDAGCombine(ISD::MGATHER);
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -16358,6 +16361,93 @@
   return SDValue();
 }
 
+// Analyse the specified address, returning true if a more optimal addressing
+// mode is available. When returning true, all parameters are updated to
+// reflect their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     ISD::MemIndexType &IndexType,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized, as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  int64_t Stride = 0;
+  SDLoc DL(N);
+  // Index = step(const) + splat(offset)
+  if (Index.getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue StepVector = Index.getOperand(0);
+    if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
+      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+    }
+  }
+
+  // Return early if no supported pattern was found.
+  if (Stride == 0)
+    return false;
+
+  // The index will be rewritten as a 32-bit STEP_VECTOR, so both the stride
+  // and the offset of the last accessed element must fit in a signed 32-bit
+  // value.
+  if (Stride < std::numeric_limits<int32_t>::min() ||
+      Stride > std::numeric_limits<int32_t>::max())
+    return false;
+
+  const auto &Subtarget =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+  unsigned MaxVScale =
+      Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
+  int64_t LastElementOffset =
+      IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
+
+  if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
+      LastElementOffset > std::numeric_limits<int32_t>::max())
+    return false;
+
+  EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
+                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  return true;
+}
+
+static SDValue performMaskedGatherScatterCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
+
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
+
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+    return SDValue();
+
+  // A more optimal index type has been found, so recreate the gather/scatter
+  // node with an Index that is more legalisation friendly.
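+  // The chain, mask, scale and memory operand are carried over unchanged;
+  // only the rewritten BasePtr, Index and IndexType are substituted, and any
+  // extending (gather) or truncating (scatter) behaviour is preserved.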
+  if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+    SDValue PassThru = MGT->getPassThru();
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedGather(
+        DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
+        Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
+  }
+  auto *MSC = cast<MaskedScatterSDNode>(MGS);
+  SDValue Data = MSC->getValue();
+  SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+  return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
+                              Ops, MSC->getMemOperand(), IndexType,
+                              MSC->isTruncatingStore());
+}
+
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -17820,6 +17910,9 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
+  case ISD::MSCATTER:
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::VECTOR_SPLICE:
     return performSVESpliceCombine(N, DAG);
   case ISD::FP_EXTEND:
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -0,0 +1,212 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-linux-unknown | FileCheck %s
+
+
+; Ensure we use a "vscale x 4" wide scatter for the maximum supported offset.
+define void @scatter_i8_index_offset_maximum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #33554431
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554431, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we use a "vscale x 4" wide scatter for the minimum supported offset.
+define void @scatter_i16_index_offset_minimum(i16* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i16> %data) #0 {
+; CHECK-LABEL: scatter_i16_index_offset_minimum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-33554432
+; CHECK-NEXT:    add x9, x0, x1, lsl #1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw #1]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i16, i16* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16> %data, <vscale x 4 x i16*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we use a "vscale x 4" gather for an offset within the limits of 32 bits.
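+; With a unit stride, the offset of the last element is at most
+; 4 (min elements) * 16 (max vscale) * 1 (stride) = 64, which comfortably fits
+; in a signed 32-bit value, so the splatted %offset can be folded into the
+; scalar base and the index narrowed to a 32-bit step vector.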
+define <vscale x 4 x i8> @gather_i8_index_offset_8(i8* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1
+; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.insert0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat0 = shufflevector <vscale x 4 x i64> %splat.insert0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %splat.insert1 = insertelement <vscale x 4 x i64> undef, i64 1, i32 0
+  %splat1 = shufflevector <vscale x 4 x i64> %splat.insert1, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t1 = mul <vscale x 4 x i64> %splat1, %step
+  %t2 = add <vscale x 4 x i64> %splat0, %t1
+  %t3 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t2
+  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t3, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+  ret <vscale x 4 x i8> %load
+}
+
+;; Negative tests
+
+; Ensure we don't use a "vscale x 4" scatter. Cannot prove that the variable
+; stride will not wrap when shrunk to be i32 based.
+define void @scatter_f16_index_offset_var(half* %base, i64 %offset, i64 %scale, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_var:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    mov z2.d, z1.d
+; CHECK-NEXT:    mov z4.d, z3.d
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    incd z2.d
+; CHECK-NEXT:    mla z3.d, p1/m, z1.d, z3.d
+; CHECK-NEXT:    mla z4.d, p1/m, z2.d, z4.d
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1h { z1.d }, p1, [x0, z3.d, lsl #1]
+; CHECK-NEXT:    st1h { z0.d }, p0, [x0, z4.d, lsl #1]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 %scale, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr half, half* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too big.
+define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum_plus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov w9, #67108864
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov w9, #33554432
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too small.
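+; With a stride of -33554433 the offset of the last element is
+; 4 (min elements) * 16 (max vscale) * -33554433 = -2147483712, which is below
+; the smallest signed 32-bit value, so the i64 index has to be kept.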
+define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_minimum_minus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-2
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    movk x9, #64511, lsl #16
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #-33554433
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554433, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the stride is too big.
+define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_stride_too_big:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-9223372036854775808
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #4611686018427387904
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z2.d, z1.d
+; CHECK-NEXT:    add z1.d, z3.d, z1.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 4611686018427387904, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+
+attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
+
+
+declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half*>, i32, <vscale x 4 x i1>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()