diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def --- a/clang/include/clang/Basic/BuiltinsX86_64.def +++ b/clang/include/clang/Basic/BuiltinsX86_64.def @@ -108,6 +108,7 @@ TARGET_BUILTIN(__builtin_ia32_tdpbuud_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8") TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile") +TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256fUsUsUsV256fV512sV512s", "n", "amx-bf16") // AMX TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile") diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h --- a/clang/lib/Headers/amxintrin.h +++ b/clang/lib/Headers/amxintrin.h @@ -15,8 +15,13 @@ #define __AMXINTRIN_H #ifdef __x86_64__ +/* Define the default attributes for the functions in this file. */ #define __DEFAULT_FN_ATTRS_TILE \ __attribute__((__always_inline__, __nodebug__, __target__("amx-tile"))) +#define __DEFAULT_FN_ATTRS_INT8 \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-int8"))) +#define __DEFAULT_FN_ATTRS_BF16 \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16"))) /// Load tile configuration from a 64-byte memory location specified by /// "mem_addr". 
The tile configuration includes the tile type palette, the @@ -221,10 +226,10 @@ #define _tile_dpbf16ps(dst, src0, src1) \ __builtin_ia32_tdpbf16ps((dst), (src0), (src1)) -#define __DEFAULT_FN_ATTRS_INT8 \ - __attribute__((__always_inline__, __nodebug__, __target__("amx-int8"))) - typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64))); +typedef short _tile1024bh __attribute__((__vector_size__(1024), __aligned__(64))); +typedef float _tile1024 __attribute__((__vector_size__(1024), __aligned__(64))); + static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base, __SIZE_TYPE__ stride) { @@ -263,12 +268,30 @@ (__SIZE_TYPE__)(stride), tile); } +static __inline__ _tile1024 __DEFAULT_FN_ATTRS_BF16 +_tile_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, + _tile1024 dst, _tile1024bh src1, _tile1024bh src2) { + return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2); +} + typedef struct __tile1024i_str { const unsigned short row; const unsigned short col; _tile1024i tile; } __tile1024i; +typedef struct __tile1024bf16_str { + const unsigned short row; + const unsigned short col; + _tile1024bh tile; +} __tile1024bh; + +typedef struct __tile1024_str { + const unsigned short row; + const unsigned short col; + _tile1024 tile; +} __tile1024; + __DEFAULT_FN_ATTRS_TILE static void __tile_loadd(__tile1024i *dst, const void *base, __SIZE_TYPE__ stride) { @@ -313,5 +336,16 @@ dst->tile = __builtin_ia32_tilezero_internal(dst->row, dst->col); } +__DEFAULT_FN_ATTRS_BF16 +static void __tile_tdpbf16ps(__tile1024 *dst, __tile1024bh src1, + __tile1024bh src2) { + dst->tile = _tile_tdpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile, + src1.tile, src2.tile); +} + +#undef __DEFAULT_FN_ATTRS_TILE +#undef __DEFAULT_FN_ATTRS_INT8 +#undef __DEFAULT_FN_ATTRS_BF16 + #endif /* __x86_64__ */ #endif /* __AMXINTRIN_H */ diff --git 
a/clang/test/CodeGen/X86/amx_api.c 
b/clang/test/CodeGen/X86/amx_api.c --- a/clang/test/CodeGen/X86/amx_api.c +++ b/clang/test/CodeGen/X86/amx_api.c @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \ +// RUN: %clang_cc1 %s -flax-vector-conversions=none -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \ // RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK #include @@ -80,3 +80,10 @@ //CHECK-NEXT bitcast x86_amx {{%.*}} to <256 x i32> __tile_zero(&c); } + +void test_tile_tdpbf16ps(__tile1024 a, __tile1024bh b, __tile1024bh c) { + //CHECK-LABEL: @test_tile_tdpbf16ps + //CHECK: call x86_amx @llvm.x86.tdpbf16ps.internal + //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x float> + __tile_tdpbf16ps(&a, b, c); +} diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -5079,6 +5079,12 @@ GCCBuiltin<"__builtin_ia32_tilezero_internal">, Intrinsic<[llvm_x86amx_ty], [llvm_i16_ty, llvm_i16_ty], []>; + def int_x86_tdpbf16ps_internal : + GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; } //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -470,16 +470,18 @@ case X86::PTDPBSSDV: case X86::PTDPBSUDV: case X86::PTDPBUSDV: - case X86::PTDPBUUDV: { + case X86::PTDPBUUDV: + case X86::PTDPBF16PSV: { MI.untieRegOperand(4); for (unsigned i = 3; i > 0; --i) MI.RemoveOperand(i); unsigned Opc; switch (Opcode) { - case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break; - case 
X86::PTDPBSUDV: Opc = X86::TDPBSUD; break; - case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break; - case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; + case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break; + case X86::PTDPBSUDV: Opc = X86::TDPBSUD; break; + case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break; + case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break; + case X86::PTDPBF16PSV: Opc = X86::TDPBF16PS; break; default: llvm_unreachable("Impossible Opcode!"); } MI.setDesc(TII->get(Opc)); diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -4626,7 +4626,7 @@ case Intrinsic::x86_tdpbsud_internal: case Intrinsic::x86_tdpbusd_internal: case Intrinsic::x86_tdpbuud_internal: { - if (!Subtarget->hasAMXTILE()) + if (!Subtarget->hasAMXINT8()) break; SDValue Chain = Node->getOperand(0); unsigned Opc; diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -138,6 +138,16 @@ "tdpbf16ps\t{$src3, $src2, $dst|$dst, $src2, $src3}", []>, VEX_4V, T8XS; + // Pseudo instruction for RA. + let Constraints = "$src4 = $dst" in + def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), + [(set TILE: $dst, + (int_x86_tdpbf16ps_internal GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6))]>; + let usesCustomInserter = 1 in { // Pseudo instructions, using immediates instead of tile registers. 
// To be translated to the actual instructions in X86ISelLowering.cpp diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -70,7 +70,8 @@ case Intrinsic::x86_tdpbssd_internal: case Intrinsic::x86_tdpbsud_internal: case Intrinsic::x86_tdpbusd_internal: - case Intrinsic::x86_tdpbuud_internal: { + case Intrinsic::x86_tdpbuud_internal: + case Intrinsic::x86_tdpbf16ps_internal: { switch (OpNo) { case 3: Row = II->getArgOperand(0); diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -159,6 +159,7 @@ case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: MachineOperand &MO1 = const_cast(MI.getOperand(1)); MachineOperand &MO2 = const_cast(MI.getOperand(2)); ShapeT Shape(&MO1, &MO2, MRI); @@ -256,6 +257,7 @@ case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: return true; } } diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -888,6 +888,7 @@ case X86::PTDPBUSDV: case X86::PTDPBUUDV: case X86::PTILEZEROV: + case X86::PTDPBF16PSV: MachineOperand &MO1 = MI->getOperand(1); MachineOperand &MO2 = MI->getOperand(2); ShapeT Shape(&MO1, &MO2, MRI); diff --git a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll --- a/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll @@ -1,11 +1,14 @@ ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-tile -mattr=+avx512f -verify-machineinstrs | FileCheck %s +; RUN: llc < %s 
-mtriple=x86_64-unknown-unknown -mattr=+amx-tile,+amx-int8,+amx-bf16 -verify-machineinstrs | FileCheck %s define void @test_amx(i8* %pointer, i8* %base, i64 %stride) { ; CHECK-LABEL: test_amx: ; CHECK: # %bb.0: -; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 -; CHECK-NEXT: vmovdqu64 %zmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: xorps %xmm0, %xmm0 +; CHECK-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $8, -{{[0-9]+}}(%rsp) ; CHECK-NEXT: movw $8, -{{[0-9]+}}(%rsp) @@ -22,9 +25,9 @@ ; CHECK-NEXT: tdpbsud %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tdpbusd %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tdpbuud %tmm2, %tmm1, %tmm0 +; CHECK-NEXT: tdpbf16ps %tmm2, %tmm1, %tmm0 ; CHECK-NEXT: tilestored %tmm0, (%rdi,%rdx) ; CHECK-NEXT: tilerelease -; CHECK-NEXT: vzeroupper ; CHECK-NEXT: retq %c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8) %a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* %base, i64 %stride) @@ -33,7 +36,8 @@ %d1 = call x86_amx @llvm.x86.tdpbsud.internal(i16 8, i16 8, i16 8, x86_amx %d0, x86_amx %a, x86_amx %b) %d2 = call x86_amx @llvm.x86.tdpbusd.internal(i16 8, i16 8, i16 8, x86_amx %d1, x86_amx %a, x86_amx %b) %d3 = call x86_amx @llvm.x86.tdpbuud.internal(i16 8, i16 8, i16 8, x86_amx %d2, x86_amx %a, x86_amx %b) - call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d3) + %d4 = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 8, i16 8, i16 8, x86_amx %d3, x86_amx %a, x86_amx %b) + call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d4) ret void } @@ -44,4 +48,5 @@ declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, 
x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)