diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td --- a/llvm/include/llvm/IR/IntrinsicsAArch64.td +++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -1294,6 +1294,11 @@ ], [IntrInaccessibleMemOrArgMemOnly, ImmArg<3>]>; +class SVE_MatMul_Intrinsic + : Intrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMSubdivide4VectorType<0>, LLVMSubdivide4VectorType<0>], + [IntrNoMem]>; + // // Loads // @@ -2254,6 +2259,19 @@ def int_aarch64_sve_bext_x : AdvSIMD_2VectorArg_Intrinsic; def int_aarch64_sve_bgrp_x : AdvSIMD_2VectorArg_Intrinsic; + +// +// SVE ACLE: 7.3. INT8 matrix multiply extensions +// +def int_aarch64_sve_ummla : SVE_MatMul_Intrinsic; +def int_aarch64_sve_smmla : SVE_MatMul_Intrinsic; +def int_aarch64_sve_usmmla : SVE_MatMul_Intrinsic; + +// +// SVE ACLE: 7.4/5. FP64/FP32 matrix multiply extensions +// +def int_aarch64_sve_mmla : AdvSIMD_3VectorArg_Intrinsic; + } // diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1852,20 +1852,20 @@ } let Predicates = [HasSVE, HasMatMulInt8] in { - def SMMLA_ZZZ : sve_int_matmul<0b00, "smmla">; - def UMMLA_ZZZ : sve_int_matmul<0b11, "ummla">; - def USMMLA_ZZZ : sve_int_matmul<0b10, "usmmla">; + defm SMMLA_ZZZ : sve_int_matmul<0b00, "smmla", int_aarch64_sve_smmla>; + defm UMMLA_ZZZ : sve_int_matmul<0b11, "ummla", int_aarch64_sve_ummla>; + defm USMMLA_ZZZ : sve_int_matmul<0b10, "usmmla", int_aarch64_sve_usmmla>; def USDOT_ZZZ : sve_int_dot_mixed<"usdot">; def USDOT_ZZZI : sve_int_dot_mixed_indexed<0, "usdot">; def SUDOT_ZZZI : sve_int_dot_mixed_indexed<1, "sudot">; } let Predicates = [HasSVE, HasMatMulFP32] in { - def FMMLA_ZZZ_S : sve_fp_matrix_mla<0, "fmmla", ZPR32>; + defm FMMLA_ZZZ_S : sve_fp_matrix_mla<0, "fmmla", ZPR32, int_aarch64_sve_mmla, nxv4f32>; } let Predicates = [HasSVE, HasMatMulFP64] in { - def 
FMMLA_ZZZ_D : sve_fp_matrix_mla<1, "fmmla", ZPR64>; + defm FMMLA_ZZZ_D : sve_fp_matrix_mla<1, "fmmla", ZPR64, int_aarch64_sve_mmla, nxv2f64>; defm LD1RO_B_IMM : sve_mem_ldor_si<0b00, "ld1rob", Z_b, ZPR8>; defm LD1RO_H_IMM : sve_mem_ldor_si<0b01, "ld1roh", Z_h, ZPR16>; defm LD1RO_W_IMM : sve_mem_ldor_si<0b10, "ld1row", Z_s, ZPR32>; diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td --- a/llvm/lib/Target/AArch64/SVEInstrFormats.td +++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td @@ -7597,6 +7597,13 @@ let ElementSize = ZPR32.ElementSize; } +multiclass sve_int_matmul uns, string asm, SDPatternOperator op> { + def NAME : sve_int_matmul; + + def : Pat<(nxv4i32 (op (nxv4i32 ZPR32:$Zda), (nxv16i8 ZPR8:$Zn), (nxv16i8 ZPR8:$Zm))), + (!cast(NAME) ZPR32:$Zda, ZPR8:$Zn, ZPR8:$Zm)>; +} + //===----------------------------------------------------------------------===// // SVE Integer Dot Product Mixed Sign Group //===----------------------------------------------------------------------===// @@ -7665,6 +7672,13 @@ let ElementSize = zprty.ElementSize; } +multiclass sve_fp_matrix_mla { + def NAME : sve_fp_matrix_mla; + + def : Pat<(vt (op (vt zprty:$Zda), (vt zprty:$Zn), (vt zprty:$Zm))), + (!cast(NAME) zprty:$Zda, zprty:$Zn, zprty:$Zm)>; +} + //===----------------------------------------------------------------------===// // SVE Memory - Contiguous Load And Replicate 256-bit Group //===----------------------------------------------------------------------===// diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp32.ll @@ -0,0 +1,14 @@ +; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f32mm -asm-verbose=0 < %s -o - | FileCheck %s + + +define @fmmla_s( %r, %a, %b) nounwind { +entry: +; CHECK-LABEL: fmmla_s: +; CHECK-NEXT: fmmla z0.s, z1.s, z2.s +; CHECK-NEXT: ret 
+ %val = tail call @llvm.aarch64.sve.mmla.nxv4f32( %r, %a, %b) + ret %val +} + +declare @llvm.aarch64.sve.mmla.nxv4f32(,,) + diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-fp64.ll @@ -0,0 +1,14 @@ +; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+f64mm -asm-verbose=0 < %s -o - | FileCheck %s + + +define @fmmla_d( %r, %a, %b) nounwind { +entry: +; CHECK-LABEL: fmmla_d: +; CHECK-NEXT: fmmla z0.d, z1.d, z2.d +; CHECK-NEXT: ret + %val = tail call @llvm.aarch64.sve.mmla.nxv2f64( %r, %a, %b) + ret %val +} + +declare @llvm.aarch64.sve.mmla.nxv2f64(,,) + diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-matmul-int8.ll @@ -0,0 +1,33 @@ +; RUN: llc -mtriple=aarch64-none-linux-gnu -mattr=+sve,+i8mm -asm-verbose=0 < %s -o - | FileCheck %s + +define @smmla( %r, %a, %b) nounwind { +entry: +; CHECK-LABEL: smmla: +; CHECK-NEXT: smmla z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %val = tail call @llvm.aarch64.sve.smmla.nxv4i32( %r, %a, %b) + ret %val +} + +define @ummla( %r, %a, %b) nounwind { +entry: +; CHECK-LABEL: ummla: +; CHECK-NEXT: ummla z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %val = tail call @llvm.aarch64.sve.ummla.nxv4i32( %r, %a, %b) + ret %val +} + +define @usmmla( %r, %a, %b) nounwind { +entry: +; CHECK-LABEL: usmmla: +; CHECK-NEXT: usmmla z0.s, z1.b, z2.b +; CHECK-NEXT: ret + %val = tail call @llvm.aarch64.sve.usmmla.nxv4i32( %r, %a, %b) + ret %val +} + +declare @llvm.aarch64.sve.smmla.nxv4i32(, , ) +declare @llvm.aarch64.sve.ummla.nxv4i32(, , ) +declare @llvm.aarch64.sve.usmmla.nxv4i32(, , ) +