Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -24960,7 +24960,14 @@
   EVT VT = Op.getValueType();
   EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
 
-  SDValue Mask = convertFixedMaskToScalableVector(Load->getMask(), DAG);
+  SDValue Mask = Load->getMask();
+  // If this is an extending load and the mask type is not the same as the
+  // load's type, then we have to extend the mask type.
+  if (!ISD::isNormalLoad(Load) &&
+      VT.getVectorElementType().getSizeInBits() >
+          Mask.getValueType().getVectorElementType().getSizeInBits())
+    Mask = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Load->getMask());
+  Mask = convertFixedMaskToScalableVector(Mask, DAG);
 
   SDValue PassThru;
   bool IsPassThruZeroOrUndef = false;
Index: llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
+++ llvm/test/CodeGen/AArch64/sve-streaming-mode-fixed-length-masked-load.ll
@@ -335,6 +335,58 @@
   ret <4 x double> %load
 }
 
+define <3 x i32> @masked_load_zext_v3i32(ptr %load_ptr, <3 x i1> %pm) #0 {
+; CHECK-LABEL: masked_load_zext_v3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    adrp x8, .LCPI13_0
+; CHECK-NEXT:    strh w3, [sp, #12]
+; CHECK-NEXT:    strh w2, [sp, #10]
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    strh w1, [sp, #8]
+; CHECK-NEXT:    ldr d0, [x8, :lo12:.LCPI13_0]
+; CHECK-NEXT:    ldr d1, [sp, #8]
+; CHECK-NEXT:    and z0.d, z1.d, z0.d
+; CHECK-NEXT:    lsl z0.h, z0.h, #15
+; CHECK-NEXT:    asr z0.h, z0.h, #15
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    ld1h { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %load_value = tail call <3 x i16> @llvm.masked.load.v3i16.p0(ptr %load_ptr, i32 4, <3 x i1> %pm, <3 x i16> zeroinitializer)
+  %extend = zext <3 x i16> %load_value to <3 x i32>
+  ret <3 x i32> %extend;
+}
+
+define <3 x i32> @masked_load_sext_v3i32(ptr %load_ptr, <3 x i1> %pm) #0 {
+; CHECK-LABEL: masked_load_sext_v3i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    adrp x8, .LCPI14_0
+; CHECK-NEXT:    strh w3, [sp, #12]
+; CHECK-NEXT:    strh w2, [sp, #10]
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    strh w1, [sp, #8]
+; CHECK-NEXT:    ldr d0, [x8, :lo12:.LCPI14_0]
+; CHECK-NEXT:    ldr d1, [sp, #8]
+; CHECK-NEXT:    and z0.d, z1.d, z0.d
+; CHECK-NEXT:    lsl z0.h, z0.h, #15
+; CHECK-NEXT:    asr z0.h, z0.h, #15
+; CHECK-NEXT:    uunpklo z0.s, z0.h
+; CHECK-NEXT:    cmpne p0.s, p0/z, z0.s, #0
+; CHECK-NEXT:    ld1sh { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %load_value = tail call <3 x i16> @llvm.masked.load.v3i16.p0(ptr %load_ptr, i32 4, <3 x i1> %pm, <3 x i16> zeroinitializer)
+  %extend = sext <3 x i16> %load_value to <3 x i32>
+  ret <3 x i32> %extend;
+}
+
 declare <4 x i8> @llvm.masked.load.v4i8(ptr, i32, <4 x i1>, <4 x i8>)
 declare <8 x i8> @llvm.masked.load.v8i8(ptr, i32, <8 x i1>, <8 x i8>)
 declare <16 x i8> @llvm.masked.load.v16i8(ptr, i32, <16 x i1>, <16 x i8>)
@@ -351,3 +403,5 @@
 declare <2 x double> @llvm.masked.load.v2f64(ptr, i32, <2 x i1>, <2 x double>)
 declare <4 x double> @llvm.masked.load.v4f64(ptr, i32, <4 x i1>, <4 x double>)
+
+declare <3 x i16> @llvm.masked.load.v3i16.p0(ptr, i32, <3 x i1>, <3 x i16>)