diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4651,63 +4651,29 @@
   return false;
 }
 
-// If the base pointer of a masked gather or scatter is null, we
-// may be able to swap BasePtr & Index and use the vector + register
-// or vector + immediate addressing mode, e.g.
-// VECTOR + REGISTER:
-//    getelementptr nullptr, <vscale N x T> (splat(%offset)) + %indices)
-//    -> getelementptr %offset, <vscale N x T> %indices
+// If the base pointer of a masked gather or scatter is constant, we
+// may be able to swap BasePtr & Index and use the vector + immediate addressing
+// mode, e.g.
 // VECTOR + IMMEDIATE:
 //    getelementptr nullptr, <vscale N x T> (splat(#x)) + %indices)
 //    -> getelementptr #x, <vscale N x T> %indices
 void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
                                  bool IsScaled, EVT MemVT, unsigned &Opcode,
                                  bool IsGather, SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr) || IsScaled)
+  ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(BasePtr);
+  if (!Offset || IsScaled)
     return;
 
-  // FIXME: This will not match for fixed vector type codegen as the nodes in
-  // question will have fixed<->scalable conversions around them. This should be
-  // moved to a DAG combine or complex pattern so that is executes after all of
-  // the fixed vector insert and extracts have been removed. This deficiency
-  // will result in a sub-optimal addressing mode being used, i.e. an ADD not
-  // being folded into the scatter/gather.
-  ConstantSDNode *Offset = nullptr;
-  if (Index.getOpcode() == ISD::ADD)
-    if (auto SplatVal = DAG.getSplatValue(Index.getOperand(1))) {
-      if (isa<ConstantSDNode>(SplatVal))
-        Offset = cast<ConstantSDNode>(SplatVal);
-      else {
-        BasePtr = SplatVal;
-        Index = Index->getOperand(0);
-        return;
-      }
-    }
-
-  unsigned NewOp =
-      IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
-
-  if (!Offset) {
-    std::swap(BasePtr, Index);
-    Opcode = NewOp;
-    return;
-  }
-
   uint64_t OffsetVal = Offset->getZExtValue();
   unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
-  auto ConstOffset = DAG.getConstant(OffsetVal, SDLoc(Index), MVT::i64);
 
-  if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31) {
-    // Index is out of range for the immediate addressing mode
-    BasePtr = ConstOffset;
-    Index = Index->getOperand(0);
+  if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31)
     return;
-  }
 
   // Immediate is in range
-  Opcode = NewOp;
-  BasePtr = Index->getOperand(0);
-  Index = ConstOffset;
+  Opcode =
+      IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
+  std::swap(BasePtr, Index);
 }
 
 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
@@ -17136,43 +17102,43 @@
 static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
                                      SDValue &BasePtr, SDValue &Index,
                                      SelectionDAG &DAG) {
+  // Try to iteratively fold parts of the index into the base pointer to
+  // simplify the index as much as possible.
+  bool Changed = false;
+  while (foldIndexIntoBase(BasePtr, Index, N->getScale(), SDLoc(N), DAG))
+    Changed = true;
+
   // Only consider element types that are pointer sized as smaller types can
   // be easily promoted.
   EVT IndexVT = Index.getValueType();
   if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
-    return false;
-
-  // Try to iteratively fold parts of the index into the base pointer to
-  // simplify the index as much as possible.
-  SDValue NewBasePtr = BasePtr, NewIndex = Index;
-  while (foldIndexIntoBase(NewBasePtr, NewIndex, N->getScale(), SDLoc(N), DAG))
-    ;
+    return Changed;
 
   // Match:
   //   Index = step(const)
   int64_t Stride = 0;
-  if (NewIndex.getOpcode() == ISD::STEP_VECTOR)
-    Stride = cast<ConstantSDNode>(NewIndex.getOperand(0))->getSExtValue();
+  if (Index.getOpcode() == ISD::STEP_VECTOR)
+    Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();
 
   // Match:
   //   Index = step(const) << shift(const)
-  else if (NewIndex.getOpcode() == ISD::SHL &&
-           NewIndex.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
-    SDValue RHS = NewIndex.getOperand(1);
+  else if (Index.getOpcode() == ISD::SHL &&
+           Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
+    SDValue RHS = Index.getOperand(1);
     if (auto *Shift =
             dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
-      int64_t Step = (int64_t)NewIndex.getOperand(0).getConstantOperandVal(1);
+      int64_t Step = (int64_t)Index.getOperand(0).getConstantOperandVal(1);
       Stride = Step << Shift->getZExtValue();
     }
   }
 
   // Return early because no supported pattern is found.
   if (Stride == 0)
-    return false;
+    return Changed;
 
   if (Stride < std::numeric_limits<int32_t>::min() ||
       Stride > std::numeric_limits<int32_t>::max())
-    return false;
+    return Changed;
 
   const auto &Subtarget =
       static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
@@ -17183,14 +17149,13 @@
   if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
       LastElementOffset > std::numeric_limits<int32_t>::max())
-    return false;
+    return Changed;
 
   EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
   // Stride does not scale explicitly by 'Scale', because it happens in
   // the gather/scatter addressing mode.
   Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
                       DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
-  BasePtr = NewBasePtr;
   return true;
 }
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -1155,7 +1155,6 @@
   ret void
 }
 
-; FIXME: This case does not yet codegen well due to deficiencies in opcode selection
 define void @masked_gather_vec_plus_reg(<32 x float>* %a, <32 x i8*>* %b, i64 %off) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_vec_plus_reg:
 ; VBITS_GE_2048:       // %bb.0:
@@ -1163,11 +1162,9 @@
 ; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
 ; VBITS_GE_2048-NEXT:    ld1d { z1.d }, p1/z, [x1]
-; VBITS_GE_2048-NEXT:    mov z2.d, x2
 ; VBITS_GE_2048-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    add z0.d, z1.d, z2.d
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1w { z0.d }, p1/z, [z0.d]
+; VBITS_GE_2048-NEXT:    ld1w { z0.d }, p1/z, [x2, z1.d]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
 ; VBITS_GE_2048-NEXT:    st1w { z0.s }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
@@ -1181,7 +1178,6 @@
   ret void
 }
 
-; FIXME: This case does not yet codegen well due to deficiencies in opcode selection
 define void @masked_gather_vec_plus_imm(<32 x float>* %a, <32 x i8*>* %b) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_vec_plus_imm:
 ; VBITS_GE_2048:       // %bb.0:
@@ -1190,9 +1186,8 @@
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
 ; VBITS_GE_2048-NEXT:    ld1d { z1.d }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    add z1.d, z1.d, #4
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1w { z0.d }, p1/z, [z1.d]
+; VBITS_GE_2048-NEXT:    ld1w { z0.d }, p1/z, [z1.d, #4]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
 ; VBITS_GE_2048-NEXT:    st1w { z0.s }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -1051,7 +1051,6 @@
   ret void
 }
 
-; FIXME: This case does not yet codegen well due to deficiencies in opcode selection
 define void @masked_scatter_vec_plus_reg(<32 x float>* %a, <32 x i8*>* %b, i64 %off) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_vec_plus_reg:
 ; VBITS_GE_2048:       // %bb.0:
@@ -1059,12 +1058,10 @@
 ; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
 ; VBITS_GE_2048-NEXT:    ld1d { z1.d }, p1/z, [x1]
-; VBITS_GE_2048-NEXT:    mov z2.d, x2
 ; VBITS_GE_2048-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    add z1.d, z1.d, z2.d
 ; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1w { z0.d }, p0, [z1.d]
+; VBITS_GE_2048-NEXT:    st1w { z0.d }, p0, [x2, z1.d]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x float>, <32 x float>* %a
   %bases = load <32 x i8*>, <32 x i8*>* %b
@@ -1075,7 +1072,6 @@
   ret void
 }
 
-; FIXME: This case does not yet codegen well due to deficiencies in opcode selection
 define void @masked_scatter_vec_plus_imm(<32 x float>* %a, <32 x i8*>* %b) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_vec_plus_imm:
 ; VBITS_GE_2048:       // %bb.0:
@@ -1084,10 +1080,9 @@
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
 ; VBITS_GE_2048-NEXT:    ld1d { z1.d }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    add z1.d, z1.d, #4
 ; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1w { z0.d }, p0, [z1.d]
+; VBITS_GE_2048-NEXT:    st1w { z0.d }, p0, [z1.d, #4]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x float>, <32 x float>* %a
   %bases = load <32 x i8*>, <32 x i8*>* %b
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -105,20 +105,18 @@
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov w9, #67108864
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    add x10, x0, x1
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpklo z3.d, z0.s
 ; CHECK-NEXT:    mul x8, x8, x9
 ; CHECK-NEXT:    mov w9, #33554432
-; CHECK-NEXT:    index z2.d, #0, x9
-; CHECK-NEXT:    mov z3.d, x8
-; CHECK-NEXT:    add z3.d, z2.d, z3.d
-; CHECK-NEXT:    add z2.d, z2.d, z1.d
-; CHECK-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
-; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    index z1.d, #0, x9
+; CHECK-NEXT:    mov z2.d, x8
+; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -140,20 +138,18 @@
 ; CHECK-NEXT:    mov x9, #-2
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    movk x9, #64511, lsl #16
-; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    add x10, x0, x1
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    mul x8, x8, x9
 ; CHECK-NEXT:    mov x9, #-33554433
-; CHECK-NEXT:    punpkhi p0.h, p0.b
-; CHECK-NEXT:    index z2.d, #0, x9
-; CHECK-NEXT:    mov z3.d, x8
-; CHECK-NEXT:    add z3.d, z2.d, z3.d
-; CHECK-NEXT:    add z2.d, z2.d, z1.d
-; CHECK-NEXT:    add z1.d, z3.d, z1.d
 ; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
-; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    index z1.d, #0, x9
+; CHECK-NEXT:    mov z2.d, x8
+; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -174,20 +170,18 @@
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov x9, #-9223372036854775808
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    mov z1.d, x1
+; CHECK-NEXT:    add x10, x0, x1
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    uunpklo z3.d, z0.s
 ; CHECK-NEXT:    mul x8, x8, x9
 ; CHECK-NEXT:    mov x9, #4611686018427387904
-; CHECK-NEXT:    index z2.d, #0, x9
-; CHECK-NEXT:    mov z3.d, x8
-; CHECK-NEXT:    add z3.d, z2.d, z3.d
-; CHECK-NEXT:    add z2.d, z2.d, z1.d
-; CHECK-NEXT:    add z1.d, z3.d, z1.d
-; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    st1b { z3.d }, p1, [x0, z2.d]
-; CHECK-NEXT:    st1b { z0.d }, p0, [x0, z1.d]
+; CHECK-NEXT:    index z1.d, #0, x9
+; CHECK-NEXT:    mov z2.d, x8
+; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
+; CHECK-NEXT:    add z2.d, z1.d, z2.d
+; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -346,9 +340,7 @@
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, xzr
-; CHECK-NEXT:    mov z1.d, x0
-; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    lsl x8, x0, #3
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
@@ -362,8 +354,7 @@
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, xzr
-; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
+; CHECK-NEXT:    mov w8, #8
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -427,9 +418,7 @@
 define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, xzr
-; CHECK-NEXT:    mov z2.d, x0
-; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    lsl x8, x0, #3
 ; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
@@ -443,8 +432,7 @@
 define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov x8, xzr
-; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
+; CHECK-NEXT:    mov w8, #8
 ; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0