diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -118,6 +118,10 @@
 /// ConstantFPSDNode or undef.
 bool isBuildVectorOfConstantFPSDNodes(const SDNode *N);
 
+/// Returns true if the specified node is a vector where all elements can
+/// be truncated to the specified element size without a loss in meaning.
+bool isVectorShrinkable(const SDNode *N, unsigned NewEltSize, bool Signed);
+
 /// Return true if the node has at least one operand and all operands of the
 /// specified node are ISD::UNDEF.
 bool allOperandsUndef(const SDNode *N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -291,6 +291,31 @@
   return true;
 }
 
+bool ISD::isVectorShrinkable(const SDNode *N, unsigned NewEltSize,
+                             bool Signed) {
+  if (N->getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  unsigned EltSize = N->getValueType(0).getScalarSizeInBits();
+  if (EltSize <= NewEltSize)
+    return false;
+
+  for (const SDValue &Op : N->op_values()) {
+    if (Op.isUndef())
+      continue;
+    if (!isa<ConstantSDNode>(Op))
+      return false;
+
+    APInt C = cast<ConstantSDNode>(Op)->getAPIntValue().trunc(EltSize);
+    if (Signed && C.trunc(NewEltSize).sext(EltSize) != C)
+      return false;
+    if (!Signed && C.trunc(NewEltSize).zext(EltSize) != C)
+      return false;
+  }
+
+  return true;
+}
+
 bool ISD::allOperandsUndef(const SDNode *N) {
   // Return false if the node has no operands.
   // This is "logically inconsistent" with the definition of "all" but
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17827,12 +17827,19 @@
   if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
     return Changed;
 
+  // Can indices be trivially shrunk?
+  if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
+    EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
+    Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
+    return true;
+  }
+
   // Match:
   //   Index = step(const)
   int64_t Stride = 0;
-  if (Index.getOpcode() == ISD::STEP_VECTOR)
+  if (Index.getOpcode() == ISD::STEP_VECTOR) {
     Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();
-
+  }
   // Match:
   //   Index = step(const) << shift(const)
   else if (Index.getOpcode() == ISD::SHL &&
@@ -17866,8 +17873,7 @@
     EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
     // Stride does not scale explicitly by 'Scale', because it happens in
     // the gather/scatter addressing mode.
-    Index = DAG.getNode(ISD::STEP_VECTOR, SDLoc(N), NewIndexVT,
-                        DAG.getTargetConstant(Stride, SDLoc(N), MVT::i32));
+    Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride));
     return true;
   }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-addressing-modes.ll
@@ -1,35 +1,17 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_256
-; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s -check-prefixes=CHECK,VBITS_GE_512
+; RUN: llc -aarch64-sve-vector-bits-min=256 < %s | FileCheck %s
+; RUN: llc -aarch64-sve-vector-bits-min=512 < %s | FileCheck %s
 
 target triple = "aarch64-unknown-linux-gnu"
 
 define void @masked_gather_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
-; VBITS_GE_256-LABEL: masked_gather_base_plus_stride_v8f32:
-; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    index z0.d, #0, #7
-; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
-; VBITS_GE_256-NEXT:    mov z1.d, z0.d
-; VBITS_GE_256-NEXT:    ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
-; VBITS_GE_256-NEXT:    add z1.d, z1.d, #28 // =0x1c
-; VBITS_GE_256-NEXT:    ld1w { z1.d }, p0/z, [x1, z1.d, lsl #2]
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl4
-; VBITS_GE_256-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_256-NEXT:    uzp1 z1.s, z1.s, z1.s
-; VBITS_GE_256-NEXT:    splice z0.s, p0, z0.s, z1.s
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_256-NEXT:    st1w { z0.s }, p0, [x0]
-; VBITS_GE_256-NEXT:    ret
-;
-; VBITS_GE_512-LABEL: masked_gather_base_plus_stride_v8f32:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    index z0.d, #0, #7
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    ld1w { z0.d }, p0/z, [x1, z0.d, lsl #2]
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_512-NEXT:    uzp1 z0.s, z0.s, z0.s
-; VBITS_GE_512-NEXT:    st1w { z0.s }, p0, [x0]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_gather_base_plus_stride_v8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    index z0.s, #0, #7
+; CHECK-NEXT:    ptrue p0.s, vl8
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x1, z0.s, sxtw #2]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0]
+; CHECK-NEXT:    ret
   %ptrs = getelementptr float, ptr %src, <8 x i64> <i64 0, i64 7, i64 14, i64 21, i64 28, i64 35, i64 42, i64 49>
   %data = tail call <8 x float> @llvm.masked.gather.v8f32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
   store <8 x float> %data, ptr %dst, align 4
@@ -52,30 +34,13 @@
 }
 
 define void @masked_scatter_base_plus_stride_v8f32(ptr %dst, ptr %src) #0 {
-; VBITS_GE_256-LABEL: masked_scatter_base_plus_stride_v8f32:
-; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_256-NEXT:    mov z1.d, #-28 // =0xffffffffffffffe4
-; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x1]
-; VBITS_GE_256-NEXT:    index z2.d, #0, #-7
-; VBITS_GE_256-NEXT:    add z1.d, z2.d, z1.d
-; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
-; VBITS_GE_256-NEXT:    uunpklo z3.d, z0.s
-; VBITS_GE_256-NEXT:    ext z0.b, z0.b, z0.b, #16
-; VBITS_GE_256-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_256-NEXT:    st1w { z3.d }, p0, [x0, z2.d, lsl #2]
-; VBITS_GE_256-NEXT:    st1w { z0.d }, p0, [x0, z1.d, lsl #2]
-; VBITS_GE_256-NEXT:    ret
-;
-; VBITS_GE_512-LABEL: masked_scatter_base_plus_stride_v8f32:
-; VBITS_GE_512:       // %bb.0:
-; VBITS_GE_512-NEXT:    ptrue p0.s, vl8
-; VBITS_GE_512-NEXT:    index z1.d, #0, #-7
-; VBITS_GE_512-NEXT:    ld1w { z0.s }, p0/z, [x1]
-; VBITS_GE_512-NEXT:    ptrue p0.d, vl8
-; VBITS_GE_512-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_512-NEXT:    st1w { z0.d }, p0, [x0, z1.d, lsl #2]
-; VBITS_GE_512-NEXT:    ret
+; CHECK-LABEL: masked_scatter_base_plus_stride_v8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl8
+; CHECK-NEXT:    index z1.s, #0, #-7
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x1]
+; CHECK-NEXT:    st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
+; CHECK-NEXT:    ret
   %data = load <8 x float>, ptr %src, align 4
   %ptrs = getelementptr float, ptr %dst, <8 x i64> <i64 0, i64 -7, i64 -14, i64 -21, i64 -28, i64 -35, i64 -42, i64 -49>
   tail call void @llvm.masked.scatter.v8f32.v8p0(<8 x float> %data, <8 x ptr> %ptrs, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
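
For context, the per-element check that the new ISD::isVectorShrinkable helper performs is a truncate-then-extend round trip: a constant index can be narrowed when sign- (or zero-) extending the truncated value reproduces the original. The standalone C++ sketch below is not part of the patch and deliberately uses no LLVM APIs; it only mirrors that round-trip check for the 64-bit index constants from the tests, narrowed to 32 bits as in the AArch64 gather/scatter path.

// Standalone illustration of the truncate/re-extend round-trip check used to
// decide whether 64-bit index constants can be narrowed to 32 bits.
// This mirrors the idea behind ISD::isVectorShrinkable; it is not LLVM code.
#include <cstdint>
#include <cstdio>
#include <vector>

// Returns true if every element survives truncation to 32 bits, i.e. the
// value is unchanged after truncating and sign- or zero-extending back.
static bool isShrinkableTo32(const std::vector<int64_t> &Elts, bool Signed) {
  for (int64_t C : Elts) {
    int64_t RoundTrip =
        Signed ? static_cast<int64_t>(static_cast<int32_t>(C))   // sext back
               : static_cast<int64_t>(static_cast<uint32_t>(C)); // zext back
    if (RoundTrip != C)
      return false;
  }
  return true;
}

int main() {
  // Strides matching the test file: 0, 7, 14, ... and 0, -7, -14, ...
  std::vector<int64_t> Pos{0, 7, 14, 21, 28, 35, 42, 49};
  std::vector<int64_t> Neg{0, -7, -14, -21, -28, -35, -42, -49};
  std::printf("pos signed:   %d\n", isShrinkableTo32(Pos, /*Signed=*/true));  // 1
  std::printf("neg signed:   %d\n", isShrinkableTo32(Neg, /*Signed=*/true));  // 1
  std::printf("neg unsigned: %d\n", isShrinkableTo32(Neg, /*Signed=*/false)); // 0
  return 0;
}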