diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17655,6 +17655,16 @@
 static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
                                      SDValue &BasePtr, SDValue &Index,
                                      SelectionDAG &DAG) {
+  if (Index.getOpcode() == ISD::ZERO_EXTEND) {
+    SDValue IndexOp0 = Index.getOperand(0);
+    if (IndexOp0.getValueType().getScalarSizeInBits() < 32 &&
+        !N->isIndexSigned()) {
+      Index = DAG.getNode(
+          ISD::ZERO_EXTEND, SDLoc(N),
+          IndexOp0.getValueType().changeVectorElementType(MVT::i32), IndexOp0);
+      return true;
+    }
+  }
   // Try to iteratively fold parts of the index into the base pointer to
   // simplify the index as much as possible.
   bool Changed = false;
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
@@ -73,30 +73,18 @@
 ; CHECK-LABEL: narrow_i64_gather_index_i8:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: add x8, x1, x2
-; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: ld1b { z0.d }, p0/z, [x1, x2]
-; CHECK-NEXT: ld1b { z1.d }, p0/z, [x8, #1, mul vl]
-; CHECK-NEXT: ld1b { z2.d }, p0/z, [x8, #2, mul vl]
-; CHECK-NEXT: ld1b { z3.d }, p0/z, [x8, #3, mul vl]
-; CHECK-NEXT: ld1b { z4.d }, p0/z, [x8, #4, mul vl]
-; CHECK-NEXT: ld1b { z5.d }, p0/z, [x8, #5, mul vl]
-; CHECK-NEXT: ld1b { z6.d }, p0/z, [x8, #6, mul vl]
-; CHECK-NEXT: ld1b { z7.d }, p0/z, [x8, #7, mul vl]
-; CHECK-NEXT: ld1b { z7.d }, p0/z, [x1, z7.d]
-; CHECK-NEXT: ld1b { z6.d }, p0/z, [x1, z6.d]
-; CHECK-NEXT: ld1b { z5.d }, p0/z, [x1, z5.d]
-; CHECK-NEXT: ld1b { z4.d }, p0/z, [x1, z4.d]
-; CHECK-NEXT: ld1b { z3.d }, p0/z, [x1, z3.d]
-; CHECK-NEXT: ld1b { z2.d }, p0/z, [x1, z2.d]
-; CHECK-NEXT: ld1b { z0.d }, p0/z, [x1, z0.d]
-; CHECK-NEXT: ld1b { z1.d }, p0/z, [x1, z1.d]
-; CHECK-NEXT: uzp1 z6.s, z6.s, z7.s
-; CHECK-NEXT: uzp1 z4.s, z4.s, z5.s
-; CHECK-NEXT: uzp1 z2.s, z2.s, z3.s
-; CHECK-NEXT: uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT: uzp1 z1.h, z4.h, z6.h
-; CHECK-NEXT: uzp1 z0.h, z0.h, z2.h
-; CHECK-NEXT: uzp1 z0.b, z0.b, z1.b
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ld1b { z0.s }, p0/z, [x1, x2]
+; CHECK-NEXT: ld1b { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x8, #2, mul vl]
+; CHECK-NEXT: ld1b { z3.s }, p0/z, [x8, #3, mul vl]
+; CHECK-NEXT: ld1b { z3.s }, p0/z, [x1, z3.s, uxtw]
+; CHECK-NEXT: ld1b { z2.s }, p0/z, [x1, z2.s, uxtw]
+; CHECK-NEXT: ld1b { z0.s }, p0/z, [x1, z0.s, uxtw]
+; CHECK-NEXT: ld1b { z1.s }, p0/z, [x1, z1.s, uxtw]
+; CHECK-NEXT: uzp1 z2.h, z2.h, z3.h
+; CHECK-NEXT: uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT: uzp1 z0.b, z0.b, z2.b
 ; CHECK-NEXT: ret
   %1 = getelementptr inbounds i8, i8* %in, i64 %ptr
   %2 = bitcast i8* %1 to <vscale x 16 x i8>*
@@ -111,18 +99,12 @@
 ; CHECK-LABEL: narrow_i64_gather_index_i16:
 ; CHECK: // %bb.0:
 ; CHECK-NEXT: add x8, x1, x2, lsl #1
-; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: ld1h { z0.d }, p0/z, [x1, x2, lsl #1]
-; CHECK-NEXT: ld1h { z1.d }, p0/z, [x8, #1, mul vl]
-; CHECK-NEXT: ld1h { z2.d }, p0/z, [x8, #2, mul vl]
-; CHECK-NEXT: ld1h { z3.d }, p0/z, [x8, #3, mul vl]
-; CHECK-NEXT: ld1h { z3.d }, p0/z, [x1, z3.d, lsl #1]
-; CHECK-NEXT: ld1h { z2.d }, p0/z, [x1, z2.d, lsl #1]
-; CHECK-NEXT: ld1h { z0.d }, p0/z, [x1, z0.d, lsl #1]
-; CHECK-NEXT: ld1h { z1.d }, p0/z, [x1, z1.d, lsl #1]
-; CHECK-NEXT: uzp1 z2.s, z2.s, z3.s
-; CHECK-NEXT: uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT: uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT: ptrue p0.s
+; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1, x2, lsl #1]
+; CHECK-NEXT: ld1h { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1, z0.s, uxtw #1]
+; CHECK-NEXT: ld1h { z1.s }, p0/z, [x1, z1.s, uxtw #1]
+; CHECK-NEXT: uzp1 z0.h, z0.h, z1.h
 ; CHECK-NEXT: ret
   %1 = getelementptr inbounds i16, i16* %in, i64 %ptr
   %2 = bitcast i16* %1 to <vscale x 8 x i16>*