diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -16476,55 +16476,90 @@
   return SDValue();
 }
 
-// Analyse the specified address returning true if a more optimal addressing
-// mode is available. When returning true all parameters are updated to reflect
-// their recommended values.
-static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
-                                     SDValue &BasePtr, SDValue &Index,
-                                     ISD::MemIndexType &IndexType,
-                                     SelectionDAG &DAG) {
-  // Only consider element types that are pointer sized as smaller types can
-  // be easily promoted.
+/// \return true if part of the index was folded into the Base.
+static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
+                              SDLoc DL, SelectionDAG &DAG) {
+  // This function assumes a vector of i64 indices.
   EVT IndexVT = Index.getValueType();
-  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+  if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
     return false;
 
-  int64_t Stride = 0;
-  SDLoc DL(N);
-  // Index = step(const) + splat(offset)
-  if (Index.getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
-    SDValue StepVector = Index.getOperand(0);
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = X + splat(Offset)
+  // ->
+  //   BasePtr = Ptr + Offset * scale.
+  //   Index = X
+  if (Index.getOpcode() == ISD::ADD) {
     if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
-      Stride = cast<ConstantSDNode>(StepVector.getOperand(0))->getSExtValue();
-      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
+      Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
       BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+      Index = Index.getOperand(0);
+      return true;
     }
   }
 
-  // Index = shl((step(const) + splat(offset))), splat(shift))
+  // Simplify:
+  //   BasePtr = Ptr
+  //   Index = (X + splat(Offset)) << splat(Shift)
+  // ->
+  //   BasePtr = Ptr + (Offset << Shift) * scale
+  //   Index = X << splat(shift)
   if (Index.getOpcode() == ISD::SHL &&
-      Index.getOperand(0).getOpcode() == ISD::ADD &&
-      Index.getOperand(0).getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+      Index.getOperand(0).getOpcode() == ISD::ADD) {
     SDValue Add = Index.getOperand(0);
     SDValue ShiftOp = Index.getOperand(1);
-    SDValue StepOp = Add.getOperand(0);
     SDValue OffsetOp = Add.getOperand(1);
-    if (auto *Shift =
-            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(ShiftOp)))
+    if (auto Shift = DAG.getSplatValue(ShiftOp))
       if (auto Offset = DAG.getSplatValue(OffsetOp)) {
-        int64_t Step =
-            cast<ConstantSDNode>(StepOp.getOperand(0))->getSExtValue();
-        // Stride does not scale explicitly by 'Scale', because it happens in
-        // the gather/scatter addressing mode.
-        Stride = Step << Shift->getSExtValue();
-        // BasePtr = BasePtr + ((Offset * Scale) << Shift)
-        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, N->getScale());
-        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, SDValue(Shift, 0));
+        Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
+        Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
         BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
+        Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
+                            Add.getOperand(0), ShiftOp);
+        return true;
       }
   }
 
+  return false;
+}
+
+// Analyse the specified address returning true if a more optimal addressing
+// mode is available. When returning true all parameters are updated to reflect
+// their recommended values.
+static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
+                                     SDValue &BasePtr, SDValue &Index,
+                                     SelectionDAG &DAG) {
+  // Only consider element types that are pointer sized as smaller types can
+  // be easily promoted.
+  EVT IndexVT = Index.getValueType();
+  if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
+    return false;
+
+  // Try to iteratively fold parts of the index into the base pointer to
+  // simplify the index as much as possible.
+  SDValue NewBasePtr = BasePtr, NewIndex = Index;
+  while (foldIndexIntoBase(NewBasePtr, NewIndex, N->getScale(), SDLoc(N), DAG))
+    ;
+
+  // Match:
+  //   Index = step(const)
+  int64_t Stride = 0;
+  if (NewIndex.getOpcode() == ISD::STEP_VECTOR)
+    Stride = cast<ConstantSDNode>(NewIndex.getOperand(0))->getSExtValue();
+
+  // Match:
+  //   Index = step(const) << shift(const)
+  else if (NewIndex.getOpcode() == ISD::SHL &&
+           NewIndex.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue RHS = NewIndex.getOperand(1);
+    if (auto *Shift =
+            dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
+      int64_t Step = (int64_t)NewIndex.getOperand(0).getConstantOperandVal(0);
+      Stride = Step << Shift->getZExtValue();
+    }
+  }
+
   // Return early because no supported pattern is found.
   if (Stride == 0)
     return false;
@@ -16545,8 +16580,11 @@
     return false;
 
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
-  Index = DAG.getNode(ISD::STEP_VECTOR, DL, NewIndexVT,
-                      DAG.getTargetConstant(Stride, DL, MVT::i32));
+  // Stride does not scale explicitly by 'Scale', because it happens in
+  // the gather/scatter addressing mode.
+  Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
+                      DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+  BasePtr = NewBasePtr;
 
   return true;
 }
@@ -16566,7 +16604,7 @@
   SDValue BasePtr = MGS->getBasePtr();
   ISD::MemIndexType IndexType = MGS->getIndexType();
 
-  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, IndexType, DAG))
+  if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
     return SDValue();
 
   // Here we catch such cases early and change MGATHER's IndexType to allow
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -283,7 +283,54 @@
   ret void
 }
 
+; stepvector is hidden further behind GEP and two adds.
+define void @scatter_f16_index_add_add([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #16
+; CHECK-NEXT:    add x9, x0, x2, lsl #4
+; CHECK-NEXT:    add x9, x9, x1, lsl #4
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %add2
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
+; stepvector is hidden further behind GEP, two adds and a shift.
+define void @scatter_f16_index_add_add_mul([8 x half]* %base, i64 %offset, i64 %offset2, <vscale x 4 x i1> %pg, <vscale x 4 x half> %data) #0 {
+; CHECK-LABEL: scatter_f16_index_add_add_mul:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, #128
+; CHECK-NEXT:    add x9, x0, x2, lsl #7
+; CHECK-NEXT:    add x9, x9, x1, lsl #7
+; CHECK-NEXT:    index z1.s, #0, w8
+; CHECK-NEXT:    st1h { z0.s }, p0, [x9, z1.s, sxtw]
+; CHECK-NEXT:    ret
+  %splat.offset.ins = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
+  %splat.offset = shufflevector <vscale x 4 x i64> %splat.offset.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %splat.offset2.ins = insertelement <vscale x 4 x i64> undef, i64 %offset2, i32 0
+  %splat.offset2 = shufflevector <vscale x 4 x i64> %splat.offset2.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %step = call <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
+  %add1 = add <vscale x 4 x i64> %splat.offset, %step
+  %add2 = add <vscale x 4 x i64> %add1, %splat.offset2
+  %splat.const8.ins = insertelement <vscale x 4 x i64> undef, i64 8, i32 0
+  %splat.const8 = shufflevector <vscale x 4 x i64> %splat.const8.ins, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
+  %mul = mul <vscale x 4 x i64> %add2, %splat.const8
+  %gep = getelementptr [8 x half], [8 x half]* %base, <vscale x 4 x i64> %mul
+  %gep.bc = bitcast <vscale x 4 x [8 x half]*> %gep to <vscale x 4 x half*>
+  call void @llvm.masked.scatter.nxv4f16(<vscale x 4 x half> %data, <vscale x 4 x half*> %gep.bc, i32 2, <vscale x 4 x i1> %pg)
+  ret void
+}
 attributes #0 = { "target-features"="+sve" vscale_range(1, 16) }
 
 declare <vscale x 4 x float> @llvm.masked.gather.nxv4f32(<vscale x 4 x float*>, i32, <vscale x 4 x i1>, <vscale x 4 x float>)