Index: include/clang/Basic/BuiltinsX86.def =================================================================== --- include/clang/Basic/BuiltinsX86.def +++ include/clang/Basic/BuiltinsX86.def @@ -1831,6 +1831,8 @@ TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb512, "V64cV64cV64c", "ncV:512:", "avx512vbmi") TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb128, "V16cV16cV16c", "ncV:128:", "avx512vbmi,avx512vl") TARGET_BUILTIN(__builtin_ia32_vpmultishiftqb256, "V32cV32cV32c", "ncV:256:", "avx512vbmi,avx512vl") + +// bf16 intrinsics TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_128, "V8sV4fV4f", "ncV:128:", "avx512bf16,avx512vl") TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_256, "V16sV8fV8f", "ncV:256:", "avx512bf16,avx512vl") TARGET_BUILTIN(__builtin_ia32_cvtne2ps2bf16_512, "V32sV16fV16f", "ncV:512:", "avx512bf16") @@ -1840,6 +1842,7 @@ TARGET_BUILTIN(__builtin_ia32_dpbf16ps_128, "V4fV4fV4iV4i", "ncV:128:", "avx512bf16,avx512vl") TARGET_BUILTIN(__builtin_ia32_dpbf16ps_256, "V8fV8fV8iV8i", "ncV:256:", "avx512bf16,avx512vl") TARGET_BUILTIN(__builtin_ia32_dpbf16ps_512, "V16fV16fV16iV16i", "ncV:512:", "avx512bf16") +TARGET_BUILTIN(__builtin_ia32_cvtsbf162ss_32, "fUs", "nc", "avx512bf16") // generic select intrinsics TARGET_BUILTIN(__builtin_ia32_selectb_128, "V16cUsV16cV16c", "ncV:128:", "avx512bw,avx512vl") Index: lib/CodeGen/CGBuiltin.cpp =================================================================== --- lib/CodeGen/CGBuiltin.cpp +++ lib/CodeGen/CGBuiltin.cpp @@ -9787,6 +9787,18 @@ return EmitX86CpuIs(CPUStr); } +// Convert a BF16 to a float: zext to i32, shl by 16, bitcast to the result.
+static Value *EmitX86CvtBF16ToFloatExpr(CodeGenFunction &CGF, + const CallExpr *E, + ArrayRef<Value *> Ops) { + llvm::Type *Int32Ty = CGF.Builder.getInt32Ty(); + Value *ZeroExt = CGF.Builder.CreateZExt(Ops[0], Int32Ty); + Value *Shl = CGF.Builder.CreateShl(ZeroExt, 16); + llvm::Type *ResultType = CGF.ConvertType(E->getType()); + Value *BitCast = CGF.Builder.CreateBitCast(Shl, ResultType); + return BitCast; +} + Value *CodeGenFunction::EmitX86CpuIs(StringRef CPUStr) { llvm::Type *Int32Ty = Builder.getInt32Ty(); @@ -11891,6 +11903,8 @@ Intrinsic::ID IID = Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128; return Builder.CreateCall(CGM.getIntrinsic(IID), Ops); } + case X86::BI__builtin_ia32_cvtsbf162ss_32: + return EmitX86CvtBF16ToFloatExpr(*this, E, Ops); case X86::BI__builtin_ia32_cvtneps2bf16_256_mask: case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: { Index: lib/Headers/avx512bf16intrin.h =================================================================== --- lib/Headers/avx512bf16intrin.h +++ lib/Headers/avx512bf16intrin.h @@ -15,10 +15,50 @@ typedef short __m512bh __attribute__((__vector_size__(64), __aligned__(64))); typedef short __m256bh __attribute__((__vector_size__(32), __aligned__(32))); +typedef short __m128bh __attribute__((__vector_size__(16), __aligned__(16))); +typedef unsigned short __bfloat16; #define __DEFAULT_FN_ATTRS512 \ __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \ __min_vector_width__(512))) +#define __DEFAULT_FN_ATTRS256 \ + __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \ + __min_vector_width__(256))) +#define __DEFAULT_FN_ATTRS128 \ + __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"), \ + __min_vector_width__(128))) +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, __target__("avx512bf16"))) + +/// Convert One BF16 Data to One Single Data. +/// +/// \headerfile <x86intrin.h> +/// +/// This intrinsic does not correspond to a specific instruction.
+/// +/// \param __A +/// A bfloat data. +/// \returns A float data whose sign field and exponent field are unchanged, +/// and whose fraction field is extended to 23 bits. +static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bfloat16 __A) { + return __builtin_ia32_cvtsbf162ss_32(__A); +} + +/// Convert One Single Float Data to One BF16 Data. +/// +/// \headerfile <x86intrin.h> +/// +/// This intrinsic corresponds to the VCVTNEPS2BF16 instructions. +/// +/// \param __A +/// A float data. +/// \returns A bf16 data converted from __A with round-to-nearest-even; the +/// fraction field is reduced to 7 bits. +static __inline__ __bfloat16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) { + __v4sf __V = {__A, 0, 0, 0}; + __m128bh __R = _mm_cvtneps_pbh(__V); + return __R[0]; +} /// Convert Two Packed Single Data to One Packed BF16 Data. /// @@ -209,6 +249,100 @@ (__v16sf)_mm512_setzero_si512()); } +/// Convert Packed BF16 Data to Packed float Data. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __A +/// A 256-bit vector of [16 x bfloat]. +/// \returns A 512-bit vector of [16 x float] converted from __A. +static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) { + return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32((__m256i)__A), 16)); +} + +/// Convert Packed BF16 Data to Packed float Data. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __A +/// A 128-bit vector of [8 x bfloat]. +/// \returns A 256-bit vector of [8 x float] converted from __A. +static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) { + return _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepi16_epi32((__m128i)__A), 16)); +} + +/// Convert Packed BF16 Data to Packed float Data using zeroing mask. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __M +/// A 16-bit mask. Elements are zeroed out when the corresponding mask +/// bit is not set. +/// \param __A +/// A 256-bit vector of [16 x bfloat].
+/// \returns A 512-bit vector of [16 x float] converted from __A. +static __inline__ __m512 __DEFAULT_FN_ATTRS512 +_mm512_maskz_cvtpbh_ps(__mmask16 __M, __m256bh __A) { + return _mm512_castsi512_ps( + _mm512_slli_epi32(_mm512_maskz_cvtepi16_epi32(__M, (__m256i)__A), 16)); +} + +/// Convert Packed BF16 Data to Packed float Data using zeroing mask. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __M +/// An 8-bit mask. Elements are zeroed out when the corresponding mask +/// bit is not set. +/// \param __A +/// A 128-bit vector of [8 x bfloat]. +/// \returns A 256-bit vector of [8 x float] converted from __A. +static __inline__ __m256 __DEFAULT_FN_ATTRS256 +_mm256_maskz_cvtpbh_ps(__mmask8 __M, __m128bh __A) { + return _mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_maskz_cvtepi16_epi32(__M, (__m128i)__A), 16)); +} + +/// Convert Packed BF16 Data to Packed float Data using merging mask. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __S +/// A 512-bit vector of [16 x float]. Elements are copied from __S when +/// the corresponding mask bit is not set. +/// \param __M +/// A 16-bit mask. +/// \param __A +/// A 256-bit vector of [16 x bfloat]. +/// \returns A 512-bit vector of [16 x float] converted from __A. +static __inline__ __m512 __DEFAULT_FN_ATTRS512 +_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __M, __m256bh __A) { + return _mm512_castsi512_ps(_mm512_mask_slli_epi32( + (__m512i)__S, __M, _mm512_cvtepi16_epi32((__m256i)__A), 16)); +} + +/// Convert Packed BF16 Data to Packed float Data using merging mask. +/// +/// \headerfile <x86intrin.h> +/// +/// \param __S +/// A 256-bit vector of [8 x float]. Elements are copied from __S when +/// the corresponding mask bit is not set. +/// \param __M +/// An 8-bit mask. Elements are copied from __S when the corresponding +/// mask bit is not set. +/// \param __A +/// A 128-bit vector of [8 x bfloat].
+/// \returns A 256-bit vector of [8 x float] converted from __A. +static __inline__ __m256 __DEFAULT_FN_ATTRS256 +_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __M, __m128bh __A) { + return _mm256_castsi256_ps(_mm256_mask_slli_epi32( + (__m256i)__S, __M, _mm256_cvtepi16_epi32((__m128i)__A), 16)); +} + +#undef __DEFAULT_FN_ATTRS +#undef __DEFAULT_FN_ATTRS128 +#undef __DEFAULT_FN_ATTRS256 #undef __DEFAULT_FN_ATTRS512 #endif Index: test/CodeGen/avx512bf16-builtins.c =================================================================== --- test/CodeGen/avx512bf16-builtins.c +++ test/CodeGen/avx512bf16-builtins.c @@ -4,46 +4,64 @@ #include <immintrin.h> -__m512bh test_mm512_cvtne2ps2bf16(__m512 A, __m512 B) { - // CHECK-LABEL: @test_mm512_cvtne2ps2bf16 +float test_mm_cvtsbh_ss(__bfloat16 A) { + // CHECK-LABEL: @test_mm_cvtsbh_ss + // CHECK: zext i16 %{{.*}} to i32 + // CHECK: shl i32 %{{.*}}, 16 + // CHECK: bitcast i32 %{{.*}} to float + // CHECK: ret float %{{.*}} + return _mm_cvtsbh_ss(A); +} + +__bfloat16 test_mm_cvtness_sbh(float A) { + // CHECK-LABEL: @test_mm_cvtness_sbh + // CHECK: insertelement <4 x float> + // CHECK: @llvm.x86.avx512bf16.mask.cvtneps2bf16.128 + // CHECK: extractelement <8 x i16> + // CHECK: ret i16 %{{.*}} + return _mm_cvtness_sbh(A); +} + +__m512bh test_mm512_cvtne2ps_pbh(__m512 A, __m512 B) { + // CHECK-LABEL: @test_mm512_cvtne2ps_pbh // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512 // CHECK: ret <32 x i16> %{{.*}} return _mm512_cvtne2ps_pbh(A, B); } -__m512bh test_mm512_maskz_cvtne2ps2bf16(__m512 A, __m512 B, __mmask32 U) { - // CHECK-LABEL: @test_mm512_maskz_cvtne2ps2bf16 +__m512bh test_mm512_maskz_cvtne2ps_pbh(__m512 A, __m512 B, __mmask32 U) { + // CHECK-LABEL: @test_mm512_maskz_cvtne2ps_pbh // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512 // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}} // CHECK: ret <32 x i16> %{{.*}} return _mm512_maskz_cvtne2ps_pbh(U, A, B); } -__m512bh test_mm512_mask_cvtne2ps2bf16(__m512bh C, __mmask32 U, __m512 A, __m512 B) { - //
CHECK-LABEL: @test_mm512_mask_cvtne2ps2bf16 +__m512bh test_mm512_mask_cvtne2ps_pbh(__m512bh C, __mmask32 U, __m512 A, __m512 B) { + // CHECK-LABEL: @test_mm512_mask_cvtne2ps_pbh // CHECK: @llvm.x86.avx512bf16.cvtne2ps2bf16.512 // CHECK: select <32 x i1> %{{.*}}, <32 x i16> %{{.*}}, <32 x i16> %{{.*}} // CHECK: ret <32 x i16> %{{.*}} return _mm512_mask_cvtne2ps_pbh(C, U, A, B); } -__m256bh test_mm512_cvtneps2bf16(__m512 A) { - // CHECK-LABEL: @test_mm512_cvtneps2bf16 +__m256bh test_mm512_cvtneps_pbh(__m512 A) { + // CHECK-LABEL: @test_mm512_cvtneps_pbh // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512 // CHECK: ret <16 x i16> %{{.*}} return _mm512_cvtneps_pbh(A); } -__m256bh test_mm512_mask_cvtneps2bf16(__m256bh C, __mmask16 U, __m512 A) { - // CHECK-LABEL: @test_mm512_mask_cvtneps2bf16 +__m256bh test_mm512_mask_cvtneps_pbh(__m256bh C, __mmask16 U, __m512 A) { + // CHECK-LABEL: @test_mm512_mask_cvtneps_pbh // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512 // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}} // CHECK: ret <16 x i16> %{{.*}} return _mm512_mask_cvtneps_pbh(C, U, A); } -__m256bh test_mm512_maskz_cvtneps2bf16(__m512 A, __mmask16 U) { - // CHECK-LABEL: @test_mm512_maskz_cvtneps2bf16 +__m256bh test_mm512_maskz_cvtneps_pbh(__m512 A, __mmask16 U) { + // CHECK-LABEL: @test_mm512_maskz_cvtneps_pbh // CHECK: @llvm.x86.avx512bf16.cvtneps2bf16.512 // CHECK: select <16 x i1> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}} // CHECK: ret <16 x i16> %{{.*}} @@ -72,3 +90,61 @@ // CHECK: ret <16 x float> %{{.*}} return _mm512_mask_dpbf16_ps(D, U, A, B); } + +__m512 test_mm512_cvtpbh_ps(__m256bh A) { + // CHECK-LABEL: @test_mm512_cvtpbh_ps + // CHECK: sext <16 x i16> %{{.*}} to <16 x i32> + // CHECK: @llvm.x86.avx512.pslli.d.512 + // CHECK: bitcast <8 x i64> %{{.*}} to <16 x float> + // CHECK: ret <16 x float> %{{.*}} + return _mm512_cvtpbh_ps(A); +} + +__m256 test_mm256_cvtpbh_ps(__m128bh A) { + // CHECK-LABEL: @test_mm256_cvtpbh_ps + // CHECK: sext
<8 x i16> %{{.*}} to <8 x i32> + // CHECK: @llvm.x86.avx2.pslli.d + // CHECK: bitcast <4 x i64> %{{.*}} to <8 x float> + // CHECK: ret <8 x float> %{{.*}} + return _mm256_cvtpbh_ps(A); +} + +__m512 test_mm512_maskz_cvtpbh_ps(__mmask16 M, __m256bh A) { + // CHECK-LABEL: @test_mm512_maskz_cvtpbh_ps + // CHECK: sext <16 x i16> %{{.*}} to <16 x i32> + // CHECK: select <16 x i1> %{{.*}}, <16 x i32> %{{.*}}, <16 x i32> %{{.*}} + // CHECK: @llvm.x86.avx512.pslli.d.512 + // CHECK: bitcast <8 x i64> %{{.*}} to <16 x float> + // CHECK: ret <16 x float> %{{.*}} + return _mm512_maskz_cvtpbh_ps(M, A); +} + +__m256 test_mm256_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) { + // CHECK-LABEL: @test_mm256_maskz_cvtpbh_ps + // CHECK: sext <8 x i16> %{{.*}} to <8 x i32> + // CHECK: select <8 x i1> %{{.*}}, <8 x i32> %{{.*}}, <8 x i32> %{{.*}} + // CHECK: @llvm.x86.avx2.pslli.d + // CHECK: bitcast <4 x i64> %{{.*}} to <8 x float> + // CHECK: ret <8 x float> %{{.*}} + return _mm256_maskz_cvtpbh_ps(M, A); +} + +__m512 test_mm512_mask_cvtpbh_ps(__m512 S, __mmask16 M, __m256bh A) { + // CHECK-LABEL: @test_mm512_mask_cvtpbh_ps + // CHECK: sext <16 x i16> %{{.*}} to <16 x i32> + // CHECK: @llvm.x86.avx512.pslli.d.512 + // CHECK: select <16 x i1> %{{.*}}, <16 x i32> %{{.*}}, <16 x i32> %{{.*}} + // CHECK: bitcast <8 x i64> %{{.*}} to <16 x float> + // CHECK: ret <16 x float> %{{.*}} + return _mm512_mask_cvtpbh_ps(S, M, A); +} + +__m256 test_mm256_mask_cvtpbh_ps(__m256 S, __mmask8 M, __m128bh A) { + // CHECK-LABEL: @test_mm256_mask_cvtpbh_ps + // CHECK: sext <8 x i16> %{{.*}} to <8 x i32> + // CHECK: @llvm.x86.avx2.pslli.d + // CHECK: select <8 x i1> %{{.*}}, <8 x i32> %{{.*}}, <8 x i32> %{{.*}} + // CHECK: bitcast <4 x i64> %{{.*}} to <8 x float> + // CHECK: ret <8 x float> %{{.*}} + return _mm256_mask_cvtpbh_ps(S, M, A); +}