diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -10408,9 +10408,10 @@ // Fold sext/zext of index into index type. bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index, - bool Scaled, SelectionDAG &DAG) { + bool Scaled, bool Signed, SelectionDAG &DAG) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + // It's always safe to look through zero extends. if (Index.getOpcode() == ISD::ZERO_EXTEND) { SDValue Op = Index.getOperand(0); MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED); @@ -10420,7 +10421,8 @@ } } - if (Index.getOpcode() == ISD::SIGN_EXTEND) { + // It's only safe to look through sign extends when Index is signed. + if (Index.getOpcode() == ISD::SIGN_EXTEND && Signed) { SDValue Op = Index.getOperand(0); MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED); if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) { @@ -10453,7 +10455,8 @@ MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore()); } - if (refineIndexType(MSC, Index, MSC->isIndexScaled(), DAG)) { + if (refineIndexType(MSC, Index, MSC->isIndexScaled(), MSC->isIndexSigned(), + DAG)) { SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale}; return DAG.getMaskedScatter( DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops, @@ -10549,7 +10552,8 @@ MGT->getExtensionType()); } - if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) { + if (refineIndexType(MGT, Index, MGT->isIndexScaled(), MGT->isIndexSigned(), + DAG)) { SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale}; return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL, Ops, diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll --- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll +++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll @@ -399,14 +399,12 @@ ret %data } -; TODO: The generated code is wrong because we're replicating offset[31] across -; offset[32:63] even though the IR has explicitly zero'd those bits. define @masked_gather_nxv4i32_u32s8_offsets(i32* %base, %offsets, %mask) #0 { ; CHECK-LABEL: masked_gather_nxv4i32_u32s8_offsets: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p1.s ; CHECK-NEXT: sxtb z0.s, p1/m, z0.s -; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0, z0.s, sxtw #2] +; CHECK-NEXT: ld1w { z0.s }, p0/z, [x0, z0.s, uxtw #2] ; CHECK-NEXT: ret %offsets.sext = sext %offsets to %offsets.sext.zext = zext %offsets.sext to @@ -482,14 +480,12 @@ ret void } -; TODO: The generated code is wrong because we're replicating offset[31] across -; offset[32:63] even though the IR has explicitly zero'd those bits. define void @masked_scatter_nxv4i32_u32s8_offsets(i32* %base, %offsets, %mask, %data) #0 { ; CHECK-LABEL: masked_scatter_nxv4i32_u32s8_offsets: ; CHECK: // %bb.0: ; CHECK-NEXT: ptrue p1.s ; CHECK-NEXT: sxtb z0.s, p1/m, z0.s -; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, sxtw #2] +; CHECK-NEXT: st1w { z1.s }, p0, [x0, z0.s, uxtw #2] ; CHECK-NEXT: ret %offsets.sext = sext %offsets to %offsets.sext.zext = zext %offsets.sext to