diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4360,8 +4360,13 @@
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
-    IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-    MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
+      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    } else {
+      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
+      IndexVT = MemVT.changeTypeToInteger();
+    }
     InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
     Mask = DAG.getNode(
         ISD::ZERO_EXTEND, DL,
@@ -4460,8 +4465,13 @@
   if (IsFixedLength) {
     assert(Subtarget->useSVEForFixedLengthVectors() &&
            "Cannot lower when not using SVE for fixed vectors");
-    IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
-    MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    if (MemVT.getScalarSizeInBits() <= IndexVT.getScalarSizeInBits()) {
+      IndexVT = getContainerForFixedLengthVector(DAG, IndexVT);
+      MemVT = IndexVT.changeVectorElementType(MemVT.getVectorElementType());
+    } else {
+      MemVT = getContainerForFixedLengthVector(DAG, MemVT);
+      IndexVT = MemVT.changeTypeToInteger();
+    }
     InputVT = DAG.getValueType(MemVT.changeTypeToInteger());
 
     StoreVal =
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll
@@ -917,8 +917,8 @@
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
 
-define void @masked_gather_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
-; CHECK-LABEL: masked_gather_32b_scaled_sext:
+define void @masked_gather_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f16:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32
 ; VBITS_GE_2048-NEXT: ld1h { [[VALS:z[0-9]+]].h }, [[PG0]]/z, [x0]
 ; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
@@ -941,6 +941,45 @@
   ret void
 }
 
+define void @masked_gather_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f32:
+; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[VALS:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].s, [[PG]]/z, [[VALS]].s, #0.0
+; VBITS_GE_2048-NEXT: ld1w { [[RES:z[0-9]+]].s }, [[MASK]]/z, [x2, [[PTRS]].s, sxtw #2]
+; VBITS_GE_2048-NEXT: st1w { [[RES]].s }, [[PG]], [x0]
+; VBITS_GE_2048-NEXT: ret
+  %cvals = load <32 x float>, <32 x float>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr float, float* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x float> %cvals, zeroinitializer
+  %vals = call <32 x float> @llvm.masked.gather.v32f32(<32 x float*> %ptrs, i32 8, <32 x i1> %mask, <32 x float> undef)
+  store <32 x float> %vals, <32 x float>* %a
+  ret void
+}
+
+define void @masked_gather_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
+; CHECK-LABEL: masked_gather_32b_scaled_sext_f64:
+; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].d, vl32
+; VBITS_GE_2048-NEXT: ld1d { [[VALS:z[0-9]+]].d }, [[PG0]]/z, [x0]
+; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG1]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].d, [[PG0]]/z, [[VALS]].d, #0.0
+; VBITS_GE_2048-NEXT: ld1d { [[RES:z[0-9]+]].d }, [[MASK]]/z, [x2, [[PTRS]].d, sxtw #3]
+; VBITS_GE_2048-NEXT: st1d { [[RES]].d }, [[PG0]], [x0]
+; VBITS_GE_2048-NEXT: ret
+  %cvals = load <32 x double>, <32 x double>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr double, double* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x double> %cvals, zeroinitializer
+  %vals = call <32 x double> @llvm.masked.gather.v32f64(<32 x double*> %ptrs, i32 8, <32 x i1> %mask, <32 x double> undef)
+  store <32 x double> %vals, <32 x double>* %a
+  ret void
+}
+
 define void @masked_gather_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; CHECK-LABEL: masked_gather_32b_scaled_zext:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll
@@ -839,9 +839,8 @@
 ; The above tests test the types, the below tests check that the addressing
 ; modes still function
 
-
-define void @masked_scatter_32b_scaled_sext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
-; CHECK-LABEL: masked_scatter_32b_scaled_sext:
+define void @masked_scatter_32b_scaled_sext_f16(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f16:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32
 ; VBITS_GE_2048-NEXT: ld1h { [[VALS:z[0-9]+]].h }, [[PG0]]/z, [x0]
 ; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
@@ -862,6 +861,41 @@
   ret void
 }
 
+define void @masked_scatter_32b_scaled_sext_f32(<32 x float>* %a, <32 x i32>* %b, float* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f32:
+; VBITS_GE_2048: ptrue [[PG:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[VALS:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].s, [[PG]]/z, [[VALS]].s, #0.0
+; VBITS_GE_2048-NEXT: st1w { [[VALS]].s }, [[MASK]], [x2, [[PTRS]].s, sxtw #2]
+; VBITS_GE_2048-NEXT: ret
+  %vals = load <32 x float>, <32 x float>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr float, float* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x float> %vals, zeroinitializer
+  call void @llvm.masked.scatter.v32f32(<32 x float> %vals, <32 x float*> %ptrs, i32 8, <32 x i1> %mask)
+  ret void
+}
+
+define void @masked_scatter_32b_scaled_sext_f64(<32 x double>* %a, <32 x i32>* %b, double* %base) #0 {
+; CHECK-LABEL: masked_scatter_32b_scaled_sext_f64:
+; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].d, vl32
+; VBITS_GE_2048-NEXT: ld1d { [[VALS:z[0-9]+]].d }, [[PG0]]/z, [x0]
+; VBITS_GE_2048-NEXT: ptrue [[PG1:p[0-9]+]].s, vl32
+; VBITS_GE_2048-NEXT: ld1w { [[PTRS:z[0-9]+]].s }, [[PG1]]/z, [x1]
+; VBITS_GE_2048-NEXT: fcmeq [[MASK:p[0-9]+]].d, [[PG0]]/z, [[VALS]].d, #0.0
+; VBITS_GE_2048-NEXT: st1d { [[VALS]].d }, [[MASK]], [x2, [[PTRS]].d, sxtw #3]
+; VBITS_GE_2048-NEXT: ret
+  %vals = load <32 x double>, <32 x double>* %a
+  %idxs = load <32 x i32>, <32 x i32>* %b
+  %ext = sext <32 x i32> %idxs to <32 x i64>
+  %ptrs = getelementptr double, double* %base, <32 x i64> %ext
+  %mask = fcmp oeq <32 x double> %vals, zeroinitializer
+  call void @llvm.masked.scatter.v32f64(<32 x double> %vals, <32 x double*> %ptrs, i32 8, <32 x i1> %mask)
+  ret void
+}
+
 define void @masked_scatter_32b_scaled_zext(<32 x half>* %a, <32 x i32>* %b, half* %base) #0 {
 ; CHECK-LABEL: masked_scatter_32b_scaled_zext:
 ; VBITS_GE_2048: ptrue [[PG0:p[0-9]+]].h, vl32