diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -932,6 +932,9 @@
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
   setTargetDAGCombine(ISD::STEP_VECTOR);
 
+  setTargetDAGCombine(ISD::MGATHER);
+  setTargetDAGCombine(ISD::MSCATTER);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -16091,6 +16094,131 @@
   return SDValue();
 }
 
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool FindMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     ISD::MemIndexType &IndexType,
+                                     SelectionDAG &DAG) {
+  bool ScaledOffset =
+      (IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::UNSIGNED_SCALED);
+
+  // Only consider element types that are pointer sized, as smaller types can
+  // be easily promoted.
+  if (Index.getValueType().getVectorElementType() != MVT::i64)
+    return false;
+
+  EVT IndexVT = Index.getValueType();
+  assert(IndexVT.getVectorElementType() == MVT::i64 &&
+         "Unexpected Index type!");
+
+  // Ignore indices that are already type legal.
+  if (IndexVT == MVT::nxv2i64)
+    return false;
+
+  IndexVT = IndexVT.changeVectorElementType(MVT::i32);
+
+  int64_t StepConst = 0;
+  SDLoc DL(N);
+  bool ChangedIndex = false;
+
+  // Index = splat(offset) + step(const)
+  if (Index.getOpcode() == ISD::ADD &&
+      Index.getOperand(1).getOpcode() == ISD::STEP_VECTOR) {
+    auto *C = dyn_cast<ConstantSDNode>(Index.getOperand(1).getOperand(0));
+    if (!C)
+      return false;
+    StepConst = C->getSExtValue();
+    if (auto Offset = DAG.getSplatValue(Index.getOperand(0))) {
+      if (ScaledOffset)
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      ChangedIndex = true;
+    }
+  }
+
+  // Index = shl((splat(offset) + step(const)), splat(shift))
+  if (Index.getOpcode() == ISD::SHL &&
+      Index.getOperand(0).getOpcode() == ISD::ADD &&
+      Index.getOperand(0).getOperand(1).getOpcode() == ISD::STEP_VECTOR)
+    if (auto Shift = DAG.getSplatValue(Index.getOperand(1)))
+      if (auto Offset = DAG.getSplatValue(Index.getOperand(0).getOperand(0))) {
+        auto *C = dyn_cast<ConstantSDNode>(Shift);
+        if (!C)
+          return false;
+        auto *Step = dyn_cast<ConstantSDNode>(
+            Index.getOperand(0).getOperand(1).getOperand(0));
+        if (!Step)
+          return false;
+        // StepConst = const << shift
+        StepConst = Step->getSExtValue() << C->getSExtValue();
+        if (ScaledOffset)
+          Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+        // BasePtr = BasePtr + (Offset << Shift)
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
+        BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+        ChangedIndex = true;
+      }
+
+  // Ensure all elements can be represented as a signed i32 value.
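+  // Both the step constant and the offset of the last vector element (i.e.
+  // minimum element count * maximum vscale * step) must stay within
+  // [INT32_MIN, INT32_MAX], otherwise the narrowed index would wrap and
+  // address different memory.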
+  if (StepConst < std::numeric_limits<int32_t>::min() ||
+      StepConst > std::numeric_limits<int32_t>::max())
+    return false;
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  llvm::Optional<unsigned> MaxVScale = 1;
+  if (MF.getFunction().hasFnAttribute(Attribute::VScaleRange))
+    MaxVScale = MF.getFunction()
+                    .getFnAttribute(Attribute::VScaleRange)
+                    .getVScaleRangeMax();
+  int64_t LastElementOffset =
+      IndexVT.getVectorMinNumElements() * StepConst * MaxVScale.getValue();
+  if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
+      LastElementOffset > std::numeric_limits<int32_t>::max())
+    return false;
+
+  // StepConst is now known not to overflow the narrower index type, so it is
+  // safe to rebuild the step vector with i32 elements for the MGATHER/MSCATTER.
+  Index = DAG.getNode(ISD::STEP_VECTOR, DL, IndexVT,
+                      DAG.getTargetConstant(StepConst, DL, MVT::i32));
+  return ChangedIndex;
+}
+
+static SDValue performMaskedGatherScatterCombine(
+    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
+  MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
+  assert(MGS && "Can only combine gather load or scatter store nodes");
+
+  SDLoc DL(MGS);
+  SDValue Chain = MGS->getChain();
+  SDValue Scale = MGS->getScale();
+  SDValue Index = MGS->getIndex();
+  SDValue Mask = MGS->getMask();
+  SDValue BasePtr = MGS->getBasePtr();
+  ISD::MemIndexType IndexType = MGS->getIndexType();
+
+  if (DCI.isBeforeLegalize()) {
+    // Catch awkward addressing modes early and change the gather/scatter's
+    // IndexType to allow the use of an Index that's more legalisation
+    // friendly.
+    if (FindMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG)) {
+      if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
+        SDValue PassThru = MGT->getPassThru();
+        SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+        return DAG.getMaskedGather(
+            DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(),
+            DL, Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
+      } else {
+        auto *MSC = cast<MaskedScatterSDNode>(MGS);
+        SDValue Data = MSC->getValue();
+        SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
+        return DAG.getMaskedScatter(
+            DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
+            MSC->getMemOperand(), IndexType, MSC->isTruncatingStore());
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 /// Target-specific DAG combine function for NEON load/store intrinsics
 /// to merge base address updates.
 static SDValue performNEONPostLDSTCombine(SDNode *N,
@@ -17547,6 +17675,10 @@
     break;
   case ISD::STORE:
     return performSTORECombine(N, DCI, DAG, Subtarget);
+  case ISD::MGATHER:
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
+  case ISD::MSCATTER:
+    return performMaskedGatherScatterCombine(N, DCI, DAG);
   case ISD::VECTOR_SPLICE:
     return performSVESpliceCombine(N, DAG);
   case ISD::FP_EXTEND:
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -0,0 +1,275 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-linux-unknown | FileCheck %s
+
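+; The gathers and scatters below use an index of the form
+; splat(%offset) + step-vector * constant. The DAG combine folds the splatted
+; offset into the base address and narrows the step vector to i32 elements, so
+; a single 32-bit index form (e.g. [x8, z0.s, sxtw]) can be selected instead of
+; splitting the operation into two 64-bit indexed accesses.
+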
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its offset.
+; gather_i8(base, index(offset, 8 * sizeof(i8)))
+define <vscale x 4 x i8> @gather_i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1, lsl #3
+; CHECK-NEXT:    index z0.s, #0, #8
+; CHECK-NEXT:    ld1sb { z0.s }, p0/z, [x8, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  %load = call <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x i8> undef)
+  ret <vscale x 4 x i8> %load
+}
+
+; Ensure the resulting load is "vscale x 4" wide, despite the offset giving the
+; impression the gather must be split due to its offset.
+; gather_f32(base, index(offset, 8 * sizeof(float)))
+define <vscale x 4 x float> @gather_f32_index_offset_8([8 x float]* %base, i64 %offset, <vscale x 4 x i1> %pg) #0 {
+; CHECK-LABEL: gather_f32_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #32
+; CHECK-NEXT:    add x9, x0, x1, lsl #5
+; CHECK-NEXT:    index z0.s, #0, w8
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x9, z0.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x float], [8 x float]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x float]*> %t3 to <vscale x 4 x float*>
+  %load = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*> %t4, i32 4, <vscale x 4 x i1> %pg, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its offset.
+; scatter_i8(base, index(offset, 8 * sizeof(i8)))
+define void @scatter_i8_index_offset_8([8 x i8]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x0, x1, lsl #3
+; CHECK-NEXT:    index z1.s, #0, #8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x8, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x i8], [8 x i8]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x i8]*> %t3 to <vscale x 4 x i8*>
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure the resulting store is "vscale x 4" wide, despite the offset giving the
+; impression the scatter must be split due to its offset.
+; scatter_f16(base, index(offset, 8 * sizeof(half)))
+define void @scatter_f16_index_offset_8([8 x half]* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t2 = add <vscale x 4 x i64> %t1, %step
+  %t3 = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %t2
+  %t4 = bitcast <vscale x 4 x [8 x half]*> %t3 to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t4, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
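+; The tests below probe the limits of the combine: it must not fire when the
+; stride is not a compile-time constant, and it may only fire for constant
+; strides whose last-lane offset still fits in a signed 32-bit value.
+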
+; As scatter_f16_index_offset_8 but with a variable stride that we cannot prove
+; will not wrap when shrunk to be i32 based.
+define void @scatter_f16_index_offset_var(half* %base, i64 %offset, i64 %scale, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_offset_var:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z1.d, #0, #1
+; CHECK-NEXT:    mov z3.d, x1
+; CHECK-NEXT:    mov z2.d, z1.d
+; CHECK-NEXT:    mov z4.d, z3.d
+; CHECK-NEXT:    ptrue p1.d
+; CHECK-NEXT:    incd z2.d
+; CHECK-NEXT:    mla z3.d, p1/m, z3.d, z1.d
+; CHECK-NEXT:    mla z4.d, p1/m, z4.d, z2.d
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1h { z1.d }, p1, [x0, z3.d, lsl #1]
+; CHECK-NEXT:    st1h { z0.d }, p0, [x0, z4.d, lsl #1]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 %scale, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr half, half* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+
+; Ensure we use a "vscale x 4" wide scatter for the maximum supported offset.
+define void @scatter_i8_index_offset_maximum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #33554431
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554431, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too big.
+define void @scatter_i8_index_offset_maximum_plus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_maximum_plus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov w9, #67108864
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov w9, #33554432
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
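+; With vscale_range(1, 16) and <vscale x 4 x i8> accesses the last lane's
+; offset is 4 * 16 * stride = 64 * stride, so the index can only be narrowed
+; to i32 when 64 * stride fits in a signed 32-bit value. That holds for
+; strides in the range [-33554432, 33554431], which the surrounding tests
+; probe from both sides.
+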
+; Ensure we use a "vscale x 4" wide scatter for the minimum supported offset.
+; DAGCombine changes the addressing mode before legalization.
+define void @scatter_i8_index_offset_minimum(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_minimum:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #-33554432
+; CHECK-NEXT:    add x9, x0, x1
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1b { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554432, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the offset is too small.
+define void @scatter_i8_index_offset_minimum_minus_one(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_offset_minimum_minus_one:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-2
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    movk x9, #64511, lsl #16
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #-33554433
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 -33554433, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+; Ensure we don't use a "vscale x 4" wide scatter when the stride is too big.
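+; A stride of 2^62 does not even fit in an i32, so the combine rejects it when
+; checking the step constant, before the last-lane offset is considered.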
+define void @scatter_i8_index_stride_too_big(i8* %base, i64 %offset, <vscale x 4 x i1> %pg, <vscale x 4 x i8> %data) #0 {
+; CHECK-LABEL: scatter_i8_index_stride_too_big:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    rdvl x8, #1
+; CHECK-NEXT:    mov x9, #-9223372036854775808
+; CHECK-NEXT:    lsr x8, x8, #4
+; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    punpklo p1.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    mul x8, x8, x9
+; CHECK-NEXT:    mov x9, #4611686018427387904
+; CHECK-NEXT:    index z2.d, #0, x9
+; CHECK-NEXT:    mov z3.d, x8
+; CHECK-NEXT:    add z3.d, z2.d, z3.d
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    add z1.d, z1.d, z3.d
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    ret
+  %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %t2 = insertelement <vscale x 4 x i64> undef, i64 4611686018427387904, i32 0
+  %t3 = shufflevector <vscale x 4 x i64> %t2, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %t4 = mul <vscale x 4 x i64> %t3, %step
+  %t5 = add <vscale x 4 x i64> %t1, %t4
+  %t6 = getelementptr i8, i8* %base, <vscale x 4 x i64> %t5
+  call void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8> %data, <vscale x 4 x i8*> %t6, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+
+
+attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
+
+declare <vscale x 4 x i8> @llvm.masked.gather.nxv4i8(<vscale x 4 x i8*>, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+
+declare void @llvm.masked.scatter.nxv4i8(<vscale x 4 x i8>, <vscale x 4 x i8*>, i32, <vscale x 4 x i1>)
+declare void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half>, <vscale x 4 x half*>, i32, <vscale x 4 x i1>)
+declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()