diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10426,14 +10426,19 @@
                       TopHalf->isZero() ? RHS->getOperand(1) : LHS->getOperand(1));
 }
 
-bool refineUniformBase(SDValue &BasePtr, SDValue &Index, SelectionDAG &DAG) {
+bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
+                       SelectionDAG &DAG) {
   if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
     return false;
 
+  // Only perform the transformation when existing operands can be reused.
+  if (IndexIsScaled)
+    return false;
+
   // For now we check only the LHS of the add.
   SDValue LHS = Index.getOperand(0);
   SDValue SplatVal = DAG.getSplatValue(LHS);
-  if (!SplatVal)
+  if (!SplatVal || SplatVal.getValueType() != BasePtr.getValueType())
     return false;
 
   BasePtr = SplatVal;
@@ -10481,7 +10486,7 @@
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(
         DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
@@ -10576,7 +10581,7 @@
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
                                MGT->getMemoryVT(), DL, Ops,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4656,10 +4656,10 @@
 // VECTOR + IMMEDIATE:
 //    getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
 // -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, EVT MemVT,
-                                 unsigned &Opcode, bool IsGather,
-                                 SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr))
+void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
+                                 bool IsScaled, EVT MemVT, unsigned &Opcode,
+                                 bool IsGather, SelectionDAG &DAG) {
+  if (!isNullConstant(BasePtr) || IsScaled)
     return;
 
   // FIXME: This will not match for fixed vector type codegen as the nodes in
@@ -4789,7 +4789,7 @@
     Index = Index.getOperand(0);
 
   unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/true, DAG);
 
   if (ExtType == ISD::SEXTLOAD)
@@ -4898,7 +4898,7 @@
     Index = Index.getOperand(0);
 
   unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, MemVT, Opcode,
+  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
                               /*isGather=*/false, DAG);
 
   if (IsFixedLength) {
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -343,12 +343,13 @@
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z1.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z1.d
+; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -358,12 +359,11 @@
   ret <vscale x 2 x i64> %data
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
 define <vscale x 2 x i64> @masked_gather_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg) #0 {
 ; CHECK-LABEL: masked_gather_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    ld1d { z0.d }, p0/z, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
@@ -425,12 +425,13 @@
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; %scalar_offset when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with_vec_plus_scalar_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    st1d { z1.d }, p0, [x0, z0.d, lsl #3]
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    mov z2.d, x0
+; CHECK-NEXT:    add z0.d, z0.d, z2.d
+; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
   %scalar_offset.splat = shufflevector <vscale x 2 x i64> %scalar_offset.ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
@@ -440,12 +441,11 @@
   ret void
 }
 
-; TODO: The generated code is wrong because we've lost the scaling applied to
-; constant scalar offset (i.e. i64 1) when it's used to calculate %ptrs.
 define void @masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets(<vscale x 2 x i64> %vector_offsets, <vscale x 2 x i1> %pg, <vscale x 2 x i64> %data) #0 {
 ; CHECK-LABEL: masked_scatter_nxv2i64_null_with__vec_plus_imm_offsets:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    mov w8, #1
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    add z0.d, z0.d, #1 // =0x1
 ; CHECK-NEXT:    st1d { z1.d }, p0, [x8, z0.d, lsl #3]
 ; CHECK-NEXT:    ret
   %scalar_offset.ins = insertelement <vscale x 2 x i64> undef, i64 1, i64 0
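
For reference, the pattern the new IndexIsScaled bail-out guards against is the one exercised by the tests above: the gather/scatter base is null, the index is splat(%scalar_offset) + %vector_offsets, and the whole index is then scaled by the element size (the lsl #3 in the ld1d/st1d forms). The previous output folded the splat into the scalar base without re-applying that scale, so %scalar_offset was added in bytes instead of elements. Below is a minimal sketch of the gather case, adapted from masked_gather_nxv2i64_null_with_vec_plus_scalar_offsets; the @scaled_index_gather_sketch name is illustrative and the function is not part of the patch.

; The addresses are (splat(%scalar_offset) + %vector_offsets) * 8. The old
; combine rewrote the base to %scalar_offset and kept the *8 only on
; %vector_offsets, dropping the scaling on the scalar part.
define <vscale x 2 x i64> @scaled_index_gather_sketch(<vscale x 2 x i64> %vector_offsets, i64 %scalar_offset, <vscale x 2 x i1> %pg) {
  %ins = insertelement <vscale x 2 x i64> undef, i64 %scalar_offset, i64 0
  %splat = shufflevector <vscale x 2 x i64> %ins, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
  %offsets = add <vscale x 2 x i64> %vector_offsets, %splat
  ; getelementptr over i64 scales every lane by 8 bytes.
  %ptrs = getelementptr i64, i64* null, <vscale x 2 x i64> %offsets
  %data = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %ptrs, i32 8, <vscale x 2 x i1> %pg, <vscale x 2 x i64> undef)
  ret <vscale x 2 x i64> %data
}

declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)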