diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def --- a/clang/include/clang/Basic/BuiltinsX86_64.def +++ b/clang/include/clang/Basic/BuiltinsX86_64.def @@ -116,6 +116,7 @@ TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16") +TARGET_BUILTIN(__builtin_ia32_tdpfp16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-fp16") // AMX TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile") diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h --- a/clang/lib/Headers/amxintrin.h +++ b/clang/lib/Headers/amxintrin.h @@ -22,6 +22,8 @@ __attribute__((__always_inline__, __nodebug__, __target__("amx-int8"))) #define __DEFAULT_FN_ATTRS_BF16 \ __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16"))) +#define __DEFAULT_FN_ATTRS_FP16 \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-fp16"))) /// Load tile configuration from a 64-byte memory location specified by /// "mem_addr". The tile configuration includes the tile type palette, the @@ -290,6 +292,13 @@ return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2); } +/// This is internal intrinsic. C/C++ user should avoid calling it directly. +static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP16 +_tile_dpfp16ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024i dst, _tile1024i src1, _tile1024i src2) { + return __builtin_ia32_tdpfp16ps_internal(m, n, k, dst, src1, src2); +} + /// This struct pack the shape and tile data together for user. We suggest /// initializing the struct as early as possible, because compiler depends /// on the shape information to do configure. 
The constant value is preferred @@ -484,9 +493,32 @@ src0.tile, src1.tile); } +/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles src0 and +/// src1, accumulating the intermediate single-precision (32-bit) floating-point +/// elements with elements in "dst", and store the 32-bit result back to tile +/// "dst". +/// +/// \headerfile <immintrin.h> +/// +/// This intrinsic corresponds to the TDPFP16PS instruction. +/// +/// \param dst +/// The destination tile. Max size is 1024 Bytes. +/// \param src0 +/// The 1st source tile. Max size is 1024 Bytes. +/// \param src1 +/// The 2nd source tile. Max size is 1024 Bytes. +__DEFAULT_FN_ATTRS_FP16 +static __inline__ void __tile_dpfp16ps(__tile1024i *dst, __tile1024i src0, + __tile1024i src1) { + dst->tile = _tile_dpfp16ps_internal(src0.row, src1.col, src0.col, dst->tile, + src0.tile, src1.tile); +} + #undef __DEFAULT_FN_ATTRS_TILE #undef __DEFAULT_FN_ATTRS_INT8 #undef __DEFAULT_FN_ATTRS_BF16 +#undef __DEFAULT_FN_ATTRS_FP16 #endif /* __x86_64__ */ #endif /* __AMXINTRIN_H */ diff --git a/clang/test/CodeGen/X86/amx_api.c b/clang/test/CodeGen/X86/amx_api.c --- a/clang/test/CodeGen/X86/amx_api.c +++ b/clang/test/CodeGen/X86/amx_api.c @@ -1,5 +1,5 @@ // RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \ -// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK +// RUN: -target-feature +amx-bf16 -target-feature +amx-fp16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK #include <immintrin.h> @@ -102,3 +102,11 @@ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) __tile_dpbf16ps(&a, b, c); } + +void test_tile_dpfp16ps(__tile1024i a, __tile1024i b, __tile1024i c) { + //CHECK-LABEL: @test_tile_dpfp16ps + //CHECK-DAG: call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> {{%.*}}) + //CHECK-DAG: call x86_amx @llvm.x86.tdpfp16ps.internal 
+ //CHECK-DAG: call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx {{%.*}}) + __tile_dpfp16ps(&a, b, c); +} diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -5396,6 +5396,12 @@ [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, llvm_x86amx_ty, llvm_x86amx_ty, llvm_x86amx_ty], []>; + def int_x86_tdpfp16ps_internal : + ClangBuiltin<"__builtin_ia32_tdpfp16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; def int_x86_cast_vector_to_tile: DefaultAttrsIntrinsic<[llvm_x86amx_ty], [llvm_anyvector_ty], [IntrNoMem]>; def int_x86_cast_tile_to_vector: diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -566,7 +566,8 @@ case X86::PTDPBSUDV: case X86::PTDPBUSDV: case X86::PTDPBUUDV: - case X86::PTDPBF16PSV: { + case X86::PTDPBF16PSV: + case X86::PTDPFP16PSV: { MI.untieRegOperand(4); for (unsigned i = 3; i > 0; --i) MI.removeOperand(i); @@ -577,6 +578,7 @@ case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break; case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; + case X86::PTDPFP16PSV: Opc = X86::TDPFP16PS; break; default: llvm_unreachable("Impossible Opcode!"); } MI.setDesc(TII->get(Opc)); diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -195,6 +195,18 @@ "tdpfp16ps\t{$src3, $src2, $src1|$src1, $src2, $src3}", []>, VEX_4V, T8XD; } + + // Pseudo instruction for RA. 
+ let isPseudo = true, Constraints = "$src4 = $dst" in { + def PTDPFP16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_tdpfp16ps_internal GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6))]>; + } + let usesCustomInserter = 1 in { def PTDPFP16PS : PseudoI<(outs), (ins u8imm:$src1, u8imm:$src2, u8imm:$src3), diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -133,7 +133,8 @@ case Intrinsic::x86_tdpbsud_internal: case Intrinsic::x86_tdpbusd_internal: case Intrinsic::x86_tdpbuud_internal: - case Intrinsic::x86_tdpbf16ps_internal: { + case Intrinsic::x86_tdpbf16ps_internal: + case Intrinsic::x86_tdpfp16ps_internal: { switch (OpNo) { case 3: Row = II->getArgOperand(0); diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -962,6 +962,7 @@ case X86::PTDPBUUDV: case X86::PTILEZEROV: case X86::PTDPBF16PSV: + case X86::PTDPFP16PSV: MachineOperand &MO1 = MI->getOperand(1); MachineOperand &MO2 = MI->getOperand(2); ShapeT Shape(&MO1, &MO2, MRI); diff --git a/llvm/test/CodeGen/X86/AMX/amx-fp16.ll b/llvm/test/CodeGen/X86/AMX/amx-fp16.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/amx-fp16.ll @@ -0,0 +1,41 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-int8,+amx-fp16,+avx512f -verify-machineinstrs | FileCheck %s + +define void @test_amx(ptr %pointer, ptr %base, i64 %stride) { +; CHECK-LABEL: test_amx: +; CHECK: # %bb.0: +; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0 +; CHECK-NEXT: vmovups %zmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp) +; 
CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, %ax +; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm0 +; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm1 +; CHECK-NEXT: tilezero %tmm2 +; CHECK-NEXT: tdpfp16ps %tmm1, %tmm0, %tmm2 +; CHECK-NEXT: tileloaddt1 (%rsi,%rdx), %tmm0 +; CHECK-NEXT: tilestored %tmm2, (%rdi,%rdx) +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + %a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, ptr %base, i64 %stride) + %b = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, ptr %base, i64 %stride) + %c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) + %d = call x86_amx @llvm.x86.tdpfp16ps.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b) + %e = call x86_amx @llvm.x86.tileloaddt164.internal(i16 8, i16 8, ptr %base, i64 %stride) + call void @llvm.x86.tilestored64.internal(i16 8, i16 8, ptr %pointer, i64 %stride, x86_amx %d) + + ret void +} + +declare x86_amx @llvm.x86.tilezero.internal(i16, i16) +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, ptr, i64) +declare x86_amx @llvm.x86.tileloaddt164.internal(i16, i16, ptr, i64) +declare x86_amx @llvm.x86.tdpfp16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare void @llvm.x86.tilestored64.internal(i16, i16, ptr, i64, x86_amx)