diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2790,9 +2790,9 @@
 //
 
 def int_aarch64_sve_psel
-    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
-                            [LLVMMatchType<0>,
-                             LLVMMatchType<0>, llvm_i32_ty],
+    : DefaultAttrsIntrinsic<[llvm_nxv16i1_ty],
+                            [llvm_nxv16i1_ty,
+                             llvm_anyvector_ty, llvm_i32_ty],
                             [IntrNoMem]>;
 
 //
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -1309,30 +1309,30 @@
                  (!cast<Instruction>(NAME # _D) PNRAny:$Pd, PNRAny:$Pn, PPR64:$Pm, MatrixIndexGPR32Op12_15:$Rv, sme_elm_idx0_1:$imm), 0>;
 
-  def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv16i1 PPRAny:$Pm),
+  def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv16i1 PPR8:$Pm),
                 MatrixIndexGPR32Op12_15:$idx)),
             (!cast<Instruction>(NAME # _B) $Pn, $Pm, $idx, 0)>;
-  def : Pat<(nxv8i1 (op (nxv8i1 PPRAny:$Pn), (nxv8i1 PPRAny:$Pm),
+  def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv8i1 PPR16:$Pm),
                 MatrixIndexGPR32Op12_15:$idx)),
             (!cast<Instruction>(NAME # _H) $Pn, $Pm, $idx, 0)>;
-  def : Pat<(nxv4i1 (op (nxv4i1 PPRAny:$Pn), (nxv4i1 PPRAny:$Pm),
+  def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv4i1 PPR32:$Pm),
                 MatrixIndexGPR32Op12_15:$idx)),
             (!cast<Instruction>(NAME # _S) $Pn, $Pm, $idx, 0)>;
-  def : Pat<(nxv2i1 (op (nxv2i1 PPRAny:$Pn), (nxv2i1 PPRAny:$Pm),
+  def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv2i1 PPR64:$Pm),
                 MatrixIndexGPR32Op12_15:$idx)),
             (!cast<Instruction>(NAME # _D) $Pn, $Pm, $idx, 0)>;
 
   let AddedComplexity = 1 in {
-    def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv16i1 PPRAny:$Pm),
+    def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv16i1 PPR8:$Pm),
                   (i32 (tileslice8 MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_15:$imm)))),
               (!cast<Instruction>(NAME # _B) $Pn, $Pm, $idx, $imm)>;
-    def : Pat<(nxv8i1 (op (nxv8i1 PPRAny:$Pn), (nxv8i1 PPRAny:$Pm),
+    def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv8i1 PPR16:$Pm),
                   (i32 (tileslice16 MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_7:$imm)))),
               (!cast<Instruction>(NAME # _H) $Pn, $Pm, $idx, $imm)>;
-    def : Pat<(nxv4i1 (op (nxv4i1 PPRAny:$Pn), (nxv4i1 PPRAny:$Pm),
+    def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv4i1 PPR32:$Pm),
                   (i32 (tileslice32 MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_3:$imm)))),
               (!cast<Instruction>(NAME # _S) $Pn, $Pm, $idx, $imm)>;
-    def : Pat<(nxv2i1 (op (nxv2i1 PPRAny:$Pn), (nxv2i1 PPRAny:$Pm),
+    def : Pat<(nxv16i1 (op (nxv16i1 PPRAny:$Pn), (nxv2i1 PPR64:$Pm),
                   (i32 (tileslice64 MatrixIndexGPR32Op12_15:$idx, sme_elm_idx0_1:$imm)))),
               (!cast<Instruction>(NAME # _D) $Pn, $Pm, $idx, $imm)>;
   }
 }
diff --git a/llvm/test/CodeGen/AArch64/sve2-intrinsics-psel.ll b/llvm/test/CodeGen/AArch64/sve2-intrinsics-psel.ll
--- a/llvm/test/CodeGen/AArch64/sve2-intrinsics-psel.ll
+++ b/llvm/test/CodeGen/AArch64/sve2-intrinsics-psel.ll
@@ -22,70 +22,70 @@
   ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 8 x i1> @psel_h(<vscale x 8 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_h(<vscale x 16 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_h:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.h[w12, 0]
 ; CHECK-NEXT:    ret
-  %res = call <vscale x 8 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 8 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx)
-  ret <vscale x 8 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx)
+  ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 8 x i1> @psel_h_imm(<vscale x 8 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_h_imm(<vscale x 16 x i1> %p1, <vscale x 8 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_h_imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.h[w12, 7]
 ; CHECK-NEXT:    ret
   %add = add i32 %idx, 7
-  %res = call <vscale x 8 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 8 x i1> %p1, <vscale x 8 x i1> %p2, i32 %add)
-  ret <vscale x 8 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %p1, <vscale x 8 x i1> %p2, i32 %add)
+  ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 4 x i1> @psel_s(<vscale x 4 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_s(<vscale x 16 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_s:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.s[w12, 0]
 ; CHECK-NEXT:    ret
-  %res = call <vscale x 4 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 4 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx)
-  ret <vscale x 4 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx)
+  ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 4 x i1> @psel_s_imm(<vscale x 4 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_s_imm(<vscale x 16 x i1> %p1, <vscale x 4 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_s_imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.s[w12, 3]
 ; CHECK-NEXT:    ret
   %add = add i32 %idx, 3
-  %res = call <vscale x 4 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 4 x i1> %p1, <vscale x 4 x i1> %p2, i32 %add)
-  ret <vscale x 4 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %p1, <vscale x 4 x i1> %p2, i32 %add)
+  ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 2 x i1> @psel_d(<vscale x 2 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_d(<vscale x 16 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_d:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.d[w12, 0]
 ; CHECK-NEXT:    ret
-  %res = call <vscale x 2 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 2 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx)
-  ret <vscale x 2 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx)
+  ret <vscale x 16 x i1> %res
 }
 
-define <vscale x 2 x i1> @psel_d_imm(<vscale x 2 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx) {
+define <vscale x 16 x i1> @psel_d_imm(<vscale x 16 x i1> %p1, <vscale x 2 x i1> %p2, i32 %idx) {
 ; CHECK-LABEL: psel_d_imm:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    mov w12, w0
 ; CHECK-NEXT:    psel p0, p0, p1.d[w12, 1]
 ; CHECK-NEXT:    ret
   %add = add i32 %idx, 1
-  %res = call <vscale x 2 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 2 x i1> %p1, <vscale x 2 x i1> %p2, i32 %add)
-  ret <vscale x 2 x i1> %res
+  %res = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %p1, <vscale x 2 x i1> %p2, i32 %add)
+  ret <vscale x 16 x i1> %res
 }
 
 declare <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1>, <vscale x 16 x i1>, i32)
-declare <vscale x 8 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 8 x i1>, <vscale x 8 x i1>, i32)
-declare <vscale x 4 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 4 x i1>, <vscale x 4 x i1>, i32)
-declare <vscale x 2 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 2 x i1>, <vscale x 2 x i1>, i32)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i32)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1>, <vscale x 4 x i1>, i32)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1>, <vscale x 2 x i1>, i32)
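
Usage note (illustrative, not part of the patch): after this change, every
llvm.aarch64.sve.psel.* overload returns <vscale x 16 x i1> and takes an
<vscale x 16 x i1> first predicate; only the second operand keeps its
element-sized type. A minimal IR sketch of how a caller still holding an
<vscale x 8 x i1> predicate might adapt, reinterpreting through the existing
convert.to.svbool/convert.from.svbool intrinsics (the wrapper function
@psel_h_wrapper is hypothetical):

define <vscale x 8 x i1> @psel_h_wrapper(<vscale x 8 x i1> %pn, <vscale x 8 x i1> %pm, i32 %idx) {
  ; Widen the first predicate to the generic svbool type (nxv16i1).
  %pn.wide = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %pn)
  ; psel now takes an nxv16i1 first operand and returns nxv16i1.
  %sel = call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %pn.wide, <vscale x 8 x i1> %pm, i32 %idx)
  ; Narrow the result back to nxv8i1 for the original interface.
  %res = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %sel)
  ret <vscale x 8 x i1> %res
}

declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
declare <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1>, <vscale x 8 x i1>, i32)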