diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -293,13 +293,25 @@
 bool ISD::isVectorShrinkable(const SDNode *N, unsigned NewEltSize,
                              bool Signed) {
-  if (N->getOpcode() != ISD::BUILD_VECTOR)
-    return false;
+  assert(N->getValueType(0).isVector() && "Expected a vector!");
 
   unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
   if (EltSize <= NewEltSize)
     return false;
 
+  if (N->getOpcode() == ISD::ZERO_EXTEND) {
+    return (N->getOperand(0).getValueType().getScalarSizeInBits() <=
+            NewEltSize) &&
+           !Signed;
+  }
+  if (N->getOpcode() == ISD::SIGN_EXTEND) {
+    return (N->getOperand(0).getValueType().getScalarSizeInBits() <=
+            NewEltSize) &&
+           Signed;
+  }
+  if (N->getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
   for (const SDValue &Op : N->op_values()) {
     if (Op.isUndef())
       continue;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17868,7 +17868,9 @@
     return Changed;
 
   // Can indices be trivially shrunk?
-  if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
+  EVT DataVT = N->getOperand(1).getValueType();
+  if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned()) &&
+      !(DataVT.getScalarSizeInBits() == 64 && DataVT.isFixedLengthVector())) {
     EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
     Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
     return true;
diff --git a/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll b/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
--- a/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
+++ b/llvm/test/CodeGen/AArch64/sve-gather-scatter-dag-combine.ll
@@ -69,34 +69,22 @@
   ret %res
 }
 
-define <vscale x 16 x i8> @narrow_i64_gather_index_i8(i8* %out, i8* %in, <vscale x 16 x i8> %d, i64 %ptr){
-; CHECK-LABEL: narrow_i64_gather_index_i8:
+define <vscale x 16 x i8> @narrow_i64_gather_index_i8_zext(i8* %out, i8* %in, <vscale x 16 x i8> %d, i64 %ptr){
+; CHECK-LABEL: narrow_i64_gather_index_i8_zext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    add x8, x1, x2
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ld1b { z0.d }, p0/z, [x1, x2]
-; CHECK-NEXT:    ld1b { z1.d }, p0/z, [x8, #1, mul vl]
-; CHECK-NEXT:    ld1b { z2.d }, p0/z, [x8, #2, mul vl]
-; CHECK-NEXT:    ld1b { z3.d }, p0/z, [x8, #3, mul vl]
-; CHECK-NEXT:    ld1b { z4.d }, p0/z, [x8, #4, mul vl]
-; CHECK-NEXT:    ld1b { z5.d }, p0/z, [x8, #5, mul vl]
-; CHECK-NEXT:    ld1b { z6.d }, p0/z, [x8, #6, mul vl]
-; CHECK-NEXT:    ld1b { z7.d }, p0/z, [x8, #7, mul vl]
-; CHECK-NEXT:    ld1b { z7.d }, p0/z, [x1, z7.d]
-; CHECK-NEXT:    ld1b { z6.d }, p0/z, [x1, z6.d]
-; CHECK-NEXT:    ld1b { z5.d }, p0/z, [x1, z5.d]
-; CHECK-NEXT:    ld1b { z4.d }, p0/z, [x1, z4.d]
-; CHECK-NEXT:    ld1b { z3.d }, p0/z, [x1, z3.d]
-; CHECK-NEXT:    ld1b { z2.d }, p0/z, [x1, z2.d]
-; CHECK-NEXT:    ld1b { z0.d }, p0/z, [x1, z0.d]
-; CHECK-NEXT:    ld1b { z1.d }, p0/z, [x1, z1.d]
-; CHECK-NEXT:    uzp1 z6.s, z6.s, z7.s
-; CHECK-NEXT:    uzp1 z4.s, z4.s, z5.s
-; CHECK-NEXT:    uzp1 z2.s, z2.s, z3.s
-; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT:    uzp1 z1.h, z4.h, z6.h
-; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
-; CHECK-NEXT:    uzp1 z0.b, z0.b, z1.b
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x1, x2]
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x8, #2, mul vl]
+; CHECK-NEXT:    ld1b { z3.s }, p0/z, [x8, #3, mul vl]
+; CHECK-NEXT:    ld1b { z3.s }, p0/z, [x1, z3.s, uxtw]
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x1, z2.s, uxtw]
+; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x1, z0.s, uxtw]
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x1, z1.s, uxtw]
+; CHECK-NEXT:    uzp1 z2.h, z2.h, z3.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    uzp1 z0.b, z0.b, z2.b
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i8, i8* %in, i64 %ptr
   %2 = bitcast i8* %1 to <vscale x 16 x i8>*
@@ -107,22 +95,42 @@
   ret <vscale x 16 x i8> %wide.masked.gather
 }
 
-define <vscale x 8 x i16> @narrow_i64_gather_index_i16(i16* %out, i16* %in, <vscale x 8 x i16> %d, i64 %ptr){
-; CHECK-LABEL: narrow_i64_gather_index_i16:
+define <vscale x 16 x i8> @narrow_i64_gather_index_i8_sext(i8* %out, i8* %in, <vscale x 16 x i8> %d, i64 %ptr){
+; CHECK-LABEL: narrow_i64_gather_index_i8_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x1, x2
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1sb { z0.s }, p0/z, [x1, x2]
+; CHECK-NEXT:    ld1sb { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT:    ld1sb { z2.s }, p0/z, [x8, #2, mul vl]
+; CHECK-NEXT:    ld1sb { z3.s }, p0/z, [x8, #3, mul vl]
+; CHECK-NEXT:    ld1b { z3.s }, p0/z, [x1, z3.s, sxtw]
+; CHECK-NEXT:    ld1b { z2.s }, p0/z, [x1, z2.s, sxtw]
+; CHECK-NEXT:    ld1b { z0.s }, p0/z, [x1, z0.s, sxtw]
+; CHECK-NEXT:    ld1b { z1.s }, p0/z, [x1, z1.s, sxtw]
+; CHECK-NEXT:    uzp1 z2.h, z2.h, z3.h
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    uzp1 z0.b, z0.b, z2.b
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i8, i8* %in, i64 %ptr
+  %2 = bitcast i8* %1 to <vscale x 16 x i8>*
+  %wide.load = load <vscale x 16 x i8>, <vscale x 16 x i8>* %2, align 1
+  %3 = sext <vscale x 16 x i8> %wide.load to <vscale x 16 x i64>
+  %4 = getelementptr inbounds i8, i8* %in, <vscale x 16 x i64> %3
+  %wide.masked.gather = call <vscale x 16 x i8> @llvm.masked.gather.nxv16i8.nxv16p0(<vscale x 16 x i8*> %4, i32 1, <vscale x 16 x i1> shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i32 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer), <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %wide.masked.gather
+}
+
+define <vscale x 8 x i16> @narrow_i64_gather_index_i16_zext(i16* %out, i16* %in, <vscale x 8 x i16> %d, i64 %ptr){
+; CHECK-LABEL: narrow_i64_gather_index_i16_zext:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    add x8, x1, x2, lsl #1
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    ld1h { z0.d }, p0/z, [x1, x2, lsl #1]
-; CHECK-NEXT:    ld1h { z1.d }, p0/z, [x8, #1, mul vl]
-; CHECK-NEXT:    ld1h { z2.d }, p0/z, [x8, #2, mul vl]
-; CHECK-NEXT:    ld1h { z3.d }, p0/z, [x8, #3, mul vl]
-; CHECK-NEXT:    ld1h { z3.d }, p0/z, [x1, z3.d, lsl #1]
-; CHECK-NEXT:    ld1h { z2.d }, p0/z, [x1, z2.d, lsl #1]
-; CHECK-NEXT:    ld1h { z0.d }, p0/z, [x1, z0.d, lsl #1]
-; CHECK-NEXT:    ld1h { z1.d }, p0/z, [x1, z1.d, lsl #1]
-; CHECK-NEXT:    uzp1 z2.s, z2.s, z3.s
-; CHECK-NEXT:    uzp1 z0.s, z0.s, z1.s
-; CHECK-NEXT:    uzp1 z0.h, z0.h, z2.h
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x1, x2, lsl #1]
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x1, z0.s, uxtw #1]
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x1, z1.s, uxtw #1]
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
 ; CHECK-NEXT:    ret
   %1 = getelementptr inbounds i16, i16* %in, i64 %ptr
   %2 = bitcast i16* %1 to <vscale x 8 x i16>*
@@ -133,6 +141,26 @@
   ret <vscale x 8 x i16> %wide.masked.gather
 }
 
+define <vscale x 8 x i16> @narrow_i64_gather_index_i16_sext(i16* %out, i16* %in, <vscale x 8 x i16> %d, i64 %ptr){
+; CHECK-LABEL: narrow_i64_gather_index_i16_sext:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    add x8, x1, x2, lsl #1
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1sh { z0.s }, p0/z, [x1, x2, lsl #1]
+; CHECK-NEXT:    ld1sh { z1.s }, p0/z, [x8, #1, mul vl]
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x1, z0.s, sxtw #1]
+; CHECK-NEXT:    ld1h { z1.s }, p0/z, [x1, z1.s, sxtw #1]
+; CHECK-NEXT:    uzp1 z0.h, z0.h, z1.h
+; CHECK-NEXT:    ret
+  %1 = getelementptr inbounds i16, i16* %in, i64 %ptr
+  %2 = bitcast i16* %1 to <vscale x 8 x i16>*
+  %wide.load = load <vscale x 8 x i16>, <vscale x 8 x i16>* %2, align 1
+  %3 = sext <vscale x 8 x i16> %wide.load to <vscale x 8 x i64>
+  %4 = getelementptr inbounds i16, i16* %in, <vscale x 8 x i64> %3
+  %wide.masked.gather = call <vscale x 8 x i16> @llvm.masked.gather.nxv8i16.nxv8p0(<vscale x 8 x i16*> %4, i32 1, <vscale x 8 x i1> shufflevector (<vscale x 8 x i1> insertelement (<vscale x 8 x i1> poison, i1 true, i32 0), <vscale x 8 x i1> poison, <vscale x 8 x i32> zeroinitializer), <vscale x 8 x i16> undef)
+  ret <vscale x 8 x i16> %wide.masked.gather
+}
+
 define <vscale x 4 x i32> @no_narrow_i64_gather_index_i32(i32* %out, i32* %in, <vscale x 4 x i32> %d, i64 %ptr){
 ; CHECK-LABEL: no_narrow_i64_gather_index_i32:
 ; CHECK:       // %bb.0: