diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -1294,6 +1294,11 @@
                 ],
                 [IntrInaccessibleMemOrArgMemOnly, ImmArg<3>]>;
 
+class SVE_MatMul_Intrinsic
+    : Intrinsic<[llvm_anyvector_ty],
+                [LLVMMatchType<0>, LLVMSubdivide4VectorType<0>, LLVMSubdivide4VectorType<0>],
+                [IntrNoMem]>;
+
 //
 // Loads
 //
@@ -2254,6 +2259,19 @@
 def int_aarch64_sve_bext_x : AdvSIMD_2VectorArg_Intrinsic;
 def int_aarch64_sve_bgrp_x : AdvSIMD_2VectorArg_Intrinsic;
+
+//
+// SVE ACLE: 7.3. INT8 matrix multiply extensions
+//
+def int_aarch64_sve_ummla : SVE_MatMul_Intrinsic;
+def int_aarch64_sve_smmla : SVE_MatMul_Intrinsic;
+def int_aarch64_sve_usmmla : SVE_MatMul_Intrinsic;
+
+//
+// SVE ACLE: 7.4/5. FP64/FP32 matrix multiply extensions
+//
+def int_aarch64_sve_mmla : AdvSIMD_3VectorArg_Intrinsic;
+
 }
 
 //
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1858,10 +1858,21 @@
   def USDOT_ZZZ  : sve_int_dot_mixed<"usdot">;
   def USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot">;
   def SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot">;
+
+  def : Pat<(nxv4i32 (int_aarch64_sve_smmla (nxv4i32 ZPR32:$_Zda_), (nxv16i8 ZPR8:$Zn), (nxv16i8 ZPR8:$Zm))),
+            (SMMLA_ZZZ ZPR32:$_Zda_, ZPR8:$Zn, ZPR8:$Zm)>;
+  def : Pat<(nxv4i32 (int_aarch64_sve_ummla (nxv4i32 ZPR32:$_Zda_), (nxv16i8 ZPR8:$Zn), (nxv16i8 ZPR8:$Zm))),
+            (UMMLA_ZZZ ZPR32:$_Zda_, ZPR8:$Zn, ZPR8:$Zm)>;
+  def : Pat<(nxv4i32 (int_aarch64_sve_usmmla (nxv4i32 ZPR32:$_Zda_), (nxv16i8 ZPR8:$Zn), (nxv16i8 ZPR8:$Zm))),
+            (USMMLA_ZZZ ZPR32:$_Zda_, ZPR8:$Zn, ZPR8:$Zm)>;
+
 }
 
 let Predicates = [HasSVE, HasMatMulFP32] in {
   def FMMLA_ZZZ_S : sve_fp_matrix_mla<0, "fmmla", ZPR32>;
+
+  def : Pat<(nxv4f32 (int_aarch64_sve_mmla (nxv4f32 ZPR32:$_Zda_), (nxv4f32 ZPR32:$Zn), (nxv4f32 ZPR32:$Zm))),
+            (FMMLA_ZZZ_S ZPR32:$_Zda_, ZPR32:$Zn, ZPR32:$Zm)>;
 }
 
 let Predicates = [HasSVE, HasMatMulFP64] in {
@@ -1880,6 +1891,10 @@
   def UZP2_ZZZ_128 : sve_int_perm_bin_perm_128_zz<0b01, 1, "uzp2">;
   def TRN1_ZZZ_128 : sve_int_perm_bin_perm_128_zz<0b11, 0, "trn1">;
   def TRN2_ZZZ_128 : sve_int_perm_bin_perm_128_zz<0b11, 1, "trn2">;
+
+  def : Pat<(nxv2f64 (int_aarch64_sve_mmla (nxv2f64 ZPR64:$_Zda_), (nxv2f64 ZPR64:$Zn), (nxv2f64 ZPR64:$Zm))),
+            (FMMLA_ZZZ_D ZPR64:$_Zda_, ZPR64:$Zn, ZPR64:$Zm)>;
+
 }
 
 let Predicates = [HasSVE2] in {
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll
@@ -0,0 +1,14 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f32mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+
+define <vscale x 4 x float> @fmmla_s(<vscale x 4 x float> %r, <vscale x 4 x float> %a, <vscale x 4 x float> %b) nounwind {
+entry:
+; CHECK-LABEL: fmmla_s:
+; CHECK-NEXT:  fmmla z0.s, z1.s, z2.s
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x float> @llvm.aarch64.sve.mmla.nxv4f32(<vscale x 4 x float> %r, <vscale x 4 x float> %a, <vscale x 4 x float> %b)
+  ret <vscale x 4 x float> %val
+}
+
+declare <vscale x 4 x float> @llvm.aarch64.sve.mmla.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>)
+
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll
@@ -0,0 +1,14 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f64mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+
+define <vscale x 2 x double> @fmmla_d(<vscale x 2 x double> %r, <vscale x 2 x double> %a, <vscale x 2 x double> %b) nounwind {
+entry:
+; CHECK-LABEL: fmmla_d:
+; CHECK-NEXT:  fmmla z0.d, z1.d, z2.d
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 2 x double> @llvm.aarch64.sve.mmla.nxv2f64(<vscale x 2 x double> %r, <vscale x 2 x double> %a, <vscale x 2 x double> %b)
+  ret <vscale x 2 x double> %val
+}
+
+declare <vscale x 2 x double> @llvm.aarch64.sve.mmla.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)
+
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll
@@ -0,0 +1,33 @@
+; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+i8mm -asm-verbose=0 < %s -o - | FileCheck %s
+
+define <vscale x 4 x i32> @smmla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: smmla:
+; CHECK-NEXT:  smmla z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+define <vscale x 4 x i32> @ummla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: ummla:
+; CHECK-NEXT:  ummla z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+define <vscale x 4 x i32> @usmmla(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) nounwind {
+entry:
+; CHECK-LABEL: usmmla:
+; CHECK-NEXT:  usmmla z0.s, z1.b, z2.b
+; CHECK-NEXT:  ret
+  %val = tail call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 x i32> %r, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
+  ret <vscale x 4 x i32> %val
+}
+
+declare <vscale x 4 x i32> @llvm.aarch64.sve.smmla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.ummla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+declare <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4 x i32>, <vscale x 16 x i8>, <vscale x 16 x i8>)
+
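
Addendum (not part of the patch above): the "SVE ACLE" section numbers cited in the IntrinsicsAArch64.td comments refer to the Arm C Language Extensions for SVE, which are the user-facing entry point for these IR intrinsics. Below is a minimal C sketch of how the new intrinsics would typically be reached from source, assuming a Clang whose arm_sve.h exposes the matrix multiply builtins; the svmmla/svusmmla names follow the ACLE spec, but treat the exact compile flags and signatures here as illustrative rather than authoritative.

// Sketch only: ACLE matrix multiply builtins, expected to lower to
// llvm.aarch64.sve.{s,u,us}mmla.nxv4i32 and llvm.aarch64.sve.mmla.nxv{4f32,2f64},
// i.e. the patterns added in AArch64SVEInstrInfo.td above.
// Assumed build line:
//   clang -O2 --target=aarch64-none-linux-gnu \
//         -march=armv8.6-a+sve+i8mm+f32mm+f64mm -S mmla_sketch.c
#include <arm_sve.h>

// INT8: per 128-bit segment, multiplies a 2x8 block of a by an 8x2 block
// of b, accumulating into a 2x2 block of 32-bit lanes in acc.
svint32_t do_smmla(svint32_t acc, svint8_t a, svint8_t b) {
  return svmmla_s32(acc, a, b);   // expected to select SMMLA_ZZZ
}

svuint32_t do_ummla(svuint32_t acc, svuint8_t a, svuint8_t b) {
  return svmmla_u32(acc, a, b);   // expected to select UMMLA_ZZZ
}

svint32_t do_usmmla(svint32_t acc, svuint8_t a, svint8_t b) {
  return svusmmla_s32(acc, a, b); // expected to select USMMLA_ZZZ
}

// FP32/FP64: 2x2 floating-point matrix multiply-accumulate per segment.
svfloat32_t do_fmmla_s(svfloat32_t acc, svfloat32_t a, svfloat32_t b) {
  return svmmla_f32(acc, a, b);   // expected to select FMMLA_ZZZ_S
}

svfloat64_t do_fmmla_d(svfloat64_t acc, svfloat64_t a, svfloat64_t b) {
  return svmmla_f64(acc, a, b);   // expected to select FMMLA_ZZZ_D
}

Each wrapper should compile to a single MMLA-family instruction, mirroring what the .ll tests in this patch check at the IR level.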