diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1390,7 +1390,9 @@
 
   // Returns true if VT is a legal index type for masked gathers/scatters
   // on this target
-  virtual bool shouldRemoveExtendFromGSIndex(EVT VT) const { return false; }
+  virtual bool shouldRemoveExtendFromGSIndex(EVT IndexVT, EVT DataVT) const {
+    return false;
+  }
 
   /// Return how the condition code should be treated: either it is legal, needs
   /// to be expanded to some other code sequence, or the target has a custom
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10441,14 +10441,14 @@
 }
 
 // Fold sext/zext of index into index type.
-bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType,
+bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType, EVT DataVT,
                      SelectionDAG &DAG) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
   // It's always safe to look through zero extends.
   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
     SDValue Op = Index.getOperand(0);
-    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
+    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
       IndexType = ISD::UNSIGNED_SCALED;
       Index = Op;
       return true;
@@ -10463,7 +10463,7 @@
   if (Index.getOpcode() == ISD::SIGN_EXTEND &&
       ISD::isIndexTypeSigned(IndexType)) {
     SDValue Op = Index.getOperand(0);
-    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
+    if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType(), DataVT)) {
       Index = Op;
       return true;
     }
@@ -10494,7 +10494,7 @@
                                 MSC->isTruncatingStore());
   }
 
-  if (refineIndexType(Index, IndexType, DAG)) {
+  if (refineIndexType(Index, IndexType, StoreVal.getValueType(), DAG)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                 DL, Ops, MSC->getMemOperand(), IndexType,
@@ -10590,7 +10590,7 @@
         Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
   }
 
-  if (refineIndexType(Index, IndexType, DAG)) {
+  if (refineIndexType(Index, IndexType, N->getValueType(0), DAG)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1090,7 +1090,7 @@
   }
 
   bool shouldExtendGSIndex(EVT VT, EVT &EltTy) const override;
-  bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
+  bool shouldRemoveExtendFromGSIndex(EVT IndexVT, EVT DataVT) const override;
   bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
   bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4520,13 +4520,19 @@
   return false;
 }
 
-bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
-  if (VT.getVectorElementType() == MVT::i32 &&
-      VT.getVectorElementCount().getKnownMinValue() >= 4 &&
-      !VT.isFixedLengthVector())
-    return true;
+bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(EVT IndexVT,
+                                                          EVT DataVT) const {
+  // SVE only supports implicit extension of 32-bit indices.
+  if (!Subtarget->hasSVE() || IndexVT.getVectorElementType() != MVT::i32)
+    return false;
 
-  return false;
+  // Indices cannot be smaller than the main data type.
+  if (IndexVT.getScalarSizeInBits() < DataVT.getScalarSizeInBits())
+    return false;
+
+  // Scalable vectors with "vscale * 2" or fewer elements sit within a 64-bit
+  // element container type, which would violate the previous clause.
+  return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
 }
 
 bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -560,7 +560,7 @@
                            const RISCVRegisterInfo *TRI);
   MVT getContainerForFixedLengthVector(MVT VT) const;
 
-  bool shouldRemoveExtendFromGSIndex(EVT VT) const override;
+  bool shouldRemoveExtendFromGSIndex(EVT IndexVT, EVT DataVT) const override;
 
   bool isLegalElementTypeForRVV(Type *ScalarTy) const;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -11656,7 +11656,8 @@
   return Result;
 }
 
-bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT VT) const {
+bool RISCVTargetLowering::shouldRemoveExtendFromGSIndex(EVT IndexVT,
+                                                        EVT DataVT) const {
   return false;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -959,19 +959,16 @@
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_scaled_sext_f16:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1h { z0.d }, p1/z, [x2, z1.d, lsl #1]
-; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
+; VBITS_GE_2048-NEXT:    ld1h { z0.s }, p1/z, [x2, z1.s, sxtw #1]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; VBITS_GE_2048-NEXT:    st1h { z0.h }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
@@ -985,18 +982,14 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_scaled_sext_f32:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.s, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p0/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1w { z0.d }, p1/z, [x2, z1.d, lsl #2]
-; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
+; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p1/z, [x2, z1.s, sxtw #2]
 ; VBITS_GE_2048-NEXT:    st1w { z0.s }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
   %cvals = load <32 x float>, <32 x float>* %a
@@ -1009,7 +1002,6 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_scaled_sext_f64:
 ; VBITS_GE_2048:       // %bb.0:
@@ -1030,19 +1022,16 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_scaled_zext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1w { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1h { z0.d }, p1/z, [x2, z1.d, lsl #1]
-; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
+; VBITS_GE_2048-NEXT:    ld1h { z0.s }, p1/z, [x2, z1.s, uxtw #1]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; VBITS_GE_2048-NEXT:    st1h { z0.h }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
@@ -1056,19 +1045,16 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_unscaled_sext(<32 x half>* %a, <32 x i32>* %b, i8* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_unscaled_sext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1h { z0.d }, p1/z, [x2, z1.d]
-; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
+; VBITS_GE_2048-NEXT:    ld1h { z0.s }, p1/z, [x2, z1.s, sxtw]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; VBITS_GE_2048-NEXT:    st1h { z0.h }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
@@ -1083,19 +1069,16 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_gather_32b_unscaled_zext(<32 x half>* %a, <32 x i32>* %b, i8* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_gather_32b_unscaled_zext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1w { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p1.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    punpklo p1.h, p1.b
-; VBITS_GE_2048-NEXT:    ld1h { z0.d }, p1/z, [x2, z1.d]
-; VBITS_GE_2048-NEXT:    uzp1 z0.s, z0.s, z0.s
+; VBITS_GE_2048-NEXT:    ld1h { z0.s }, p1/z, [x2, z1.s, uxtw]
 ; VBITS_GE_2048-NEXT:    uzp1 z0.h, z0.h, z0.h
 ; VBITS_GE_2048-NEXT:    st1h { z0.h }, p0, [x0]
 ; VBITS_GE_2048-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -871,20 +871,17 @@
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_scaled_sext_f16:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    uunpklo z0.s, z0.h
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1h { z0.d }, p0, [x2, z1.d, lsl #1]
+; VBITS_GE_2048-NEXT:    st1h { z0.s }, p0, [x2, z1.s, sxtw #1]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x half>, <32 x half>* %a
   %idxs = load <32 x i32>, <32 x i32>* %b
@@ -895,18 +892,14 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_scaled_sext_f32:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.s, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
 ; VBITS_GE_2048-NEXT:    ld1w { z0.s }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p0/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.s, p0/z, z0.s, #0.0
-; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1w { z0.d }, p0, [x2, z1.d, lsl #2]
+; VBITS_GE_2048-NEXT:    st1w { z0.s }, p0, [x2, z1.s, sxtw #2]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x float>, <32 x float>* %a
   %idxs = load <32 x i32>, <32 x i32>* %b
@@ -917,7 +910,6 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_scaled_sext_f64:
 ; VBITS_GE_2048:       // %bb.0:
@@ -936,20 +928,17 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_scaled_zext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1w { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    uunpklo z0.s, z0.h
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1h { z0.d }, p0, [x2, z1.d, lsl #1]
+; VBITS_GE_2048-NEXT:    st1h { z0.s }, p0, [x2, z1.s, uxtw #1]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x half>, <32 x half>* %a
   %idxs = load <32 x i32>, <32 x i32>* %b
@@ -960,20 +949,17 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_unscaled_sext(<32 x half>* %a, <32 x i32>* %b, i8* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_unscaled_sext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1sw { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    uunpklo z0.s, z0.h
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1h { z0.d }, p0, [x2, z1.d]
+; VBITS_GE_2048-NEXT:    st1h { z0.s }, p0, [x2, z1.s, sxtw]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x half>, <32 x half>* %a
   %idxs = load <32 x i32>, <32 x i32>* %b
@@ -985,20 +971,17 @@
   ret void
 }
 
-; NOTE: This produces an non-optimal addressing mode due to a temporary workaround
 define void @masked_scatter_32b_unscaled_zext(<32 x half>* %a, <32 x i32>* %b, i8* %base) #0 {
 ; VBITS_GE_2048-LABEL: masked_scatter_32b_unscaled_zext:
 ; VBITS_GE_2048:       // %bb.0:
 ; VBITS_GE_2048-NEXT:    ptrue p0.h, vl32
-; VBITS_GE_2048-NEXT:    ptrue p1.d, vl32
+; VBITS_GE_2048-NEXT:    ptrue p1.s, vl32
 ; VBITS_GE_2048-NEXT:    ld1h { z0.h }, p0/z, [x0]
-; VBITS_GE_2048-NEXT:    ld1w { z1.d }, p1/z, [x1]
+; VBITS_GE_2048-NEXT:    ld1w { z1.s }, p1/z, [x1]
 ; VBITS_GE_2048-NEXT:    fcmeq p0.h, p0/z, z0.h, #0.0
 ; VBITS_GE_2048-NEXT:    uunpklo z0.s, z0.h
 ; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    uunpklo z0.d, z0.s
-; VBITS_GE_2048-NEXT:    punpklo p0.h, p0.b
-; VBITS_GE_2048-NEXT:    st1h { z0.d }, p0, [x2, z1.d]
+; VBITS_GE_2048-NEXT:    st1h { z0.s }, p0, [x2, z1.s, uxtw]
 ; VBITS_GE_2048-NEXT:    ret
   %vals = load <32 x half>, <32 x half>* %a
   %idxs = load <32 x i32>, <32 x i32>* %b