Index: llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
===================================================================
--- llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -1450,6 +1450,9 @@
 
 bool IRTranslator::translateCast(unsigned Opcode, const User &U,
                                  MachineIRBuilder &MIRBuilder) {
+  if (U.getType()->getScalarType()->isBFloatTy() ||
+      U.getOperand(0)->getType()->getScalarType()->isBFloatTy())
+    return false;
   Register Op = getOrCreateVReg(*U.getOperand(0));
   Register Res = getOrCreateVReg(U);
   MIRBuilder.buildInstr(Opcode, {Res}, {Op});
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -800,10 +800,15 @@
     setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
   }
 
+  // Converting f64 -> bf16 in one step would double-round, so a correct
+  // conversion needs a libcall. For now, reject the operation outright
+  // unless fast-math permits the double rounding.
+  setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
+
   // AArch64 does not have floating-point extending loads, i1 sign-extending
   // load, floating-point truncating stores, or v2i32->v2i16 truncating store.
   for (MVT VT : MVT::fp_valuetypes()) {
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
+    setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
     setLoadExtAction(ISD::EXTLOAD, VT, MVT::f80, Expand);
@@ -812,12 +817,15 @@
     setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);
 
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
   setTruncStoreAction(MVT::f128, MVT::f80, Expand);
   setTruncStoreAction(MVT::f128, MVT::f64, Expand);
   setTruncStoreAction(MVT::f128, MVT::f32, Expand);
   setTruncStoreAction(MVT::f128, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f128, MVT::bf16, Expand);
 
   setOperationAction(ISD::BITCAST, MVT::i16, Custom);
   setOperationAction(ISD::BITCAST, MVT::f16, Custom);
@@ -3133,6 +3141,12 @@
   if (useSVEForFixedLengthVectorVT(SrcVT))
     return SDValue();
 
+  if (Op.getValueType() == MVT::bf16 && SrcVT != MVT::f32 &&
+      !DAG.getTarget().Options.UnsafeFPMath) {
+    report_fatal_error(
+        "No way to correctly truncate anything but float to bfloat");
+  }
+
   // It's legal except when f128 is involved
   return Op;
 }
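
The double rounding mentioned in the FP_ROUND lowering above is easy to demonstrate on the host. The C++ sketch below is illustrative only, not part of the patch; fp32_to_bf16 is a hypothetical helper implementing the usual round-to-nearest-even bias trick. It builds an f64 whose correctly rounded bf16 result differs from the result of rounding through f32 first, which is exactly what the FCVTSDr-then-BFCVT pattern in the next file does when fast-math allows it.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Hypothetical helper: truncate an f32 to bf16 with round-to-nearest-even
    // via the usual bias trick (NaN payload handling omitted for brevity).
    static uint16_t fp32_to_bf16(float f) {
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof bits);
      bits += 0x7FFFu + ((bits >> 16) & 1); // RNE bias
      return static_cast<uint16_t>(bits >> 16);
    }

    int main() {
      // 1.00390625 = 1 + 2^-8 is exactly representable as an f32 (0x3F808000)
      // and is the exact midpoint between the adjacent bf16 values 1.0 (0x3F80)
      // and 1.0078125 (0x3F81).
      const double mid = 1.00390625;

      // d sits one double ulp above that midpoint, so a single correctly
      // rounded f64->bf16 conversion must round up to 0x3F81.
      const double d = std::nextafter(mid, 2.0);

      // Rounding twice goes wrong: d is far less than half an f32 ulp above
      // mid, so f64->f32 lands exactly on mid, and the f32->bf16 step then
      // sees an exact tie and rounds to even, giving 0x3F80, one ulp low.
      const uint16_t twoStep = fp32_to_bf16(static_cast<float>(d));
      std::printf("double-rounded: 0x%04X, correct: 0x3F81\n",
                  static_cast<unsigned>(twoStep));
      return 0;
    }

A correct non-fast-math f64 -> bf16 lowering has to round once; the standard fixes are a soft-float libcall or an f64 -> f32 step performed in round-to-odd mode before the final BFCVT, neither of which this patch attempts.
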
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -873,6 +873,9 @@
           (BF16DOTlanev4bf16 (v2f32 V64:$Rd), (v4bf16 V64:$Rn),
                              (SUBREG_TO_REG (i32 0), V64:$Rm, dsub),
                              VectorIndexS:$idx)>;
+
+def : Pat<(bf16 (fpround f32:$src)), (BFCVT $src)>;
+def : Pat<(v4bf16 (fpround v4f32:$src)), (EXTRACT_SUBREG (BFCVTN $src), dsub)>;
 }
 
 // ARMv8.6A AArch64 matrix multiplication
@@ -4166,6 +4169,19 @@
 def : Pat<(v2f32 (AArch64rev64 V64:$Rn)), (REV64v2i32 V64:$Rn)>;
 def : Pat<(v4f32 (AArch64rev64 V128:$Rn)), (REV64v4i32 V128:$Rn)>;
 
+def : Pat<(f32 (fpextend bf16:$src)),
+          (EXTRACT_SUBREG (SHLLv4i16 (SUBREG_TO_REG (i64 0), $src, hsub)),
+                          ssub)>;
+
+def : Pat<(v4f32 (fpextend v4bf16:$src)), (SHLLv4i16 $src)>;
+
+def : Pat<(f64 (fpextend bf16:$src)),
+          (FCVTDSr (EXTRACT_SUBREG
+                       (SHLLv4i16 (SUBREG_TO_REG (i64 0), $src, hsub)),
+                       ssub))>;
+
+def : Pat<(bf16 (fpround f64:$src)), (BFCVT (FCVTSDr $src))>;
+
 // Patterns for vector long shift (by element width). These need to match all
 // three of zext, sext and anyext so it's easier to pull the patterns out of the
 // definition.
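
These fpextend patterns lean on the fact that a bf16 value is just the top 16 bits of the f32 with the same value: widening is an exact left shift by 16, which SHLL's shift-by-element-width form provides per lane, so no rounding is involved. A scalar C++ sketch of the same bit manipulation (illustrative only; bf16_to_fp32 is a hypothetical helper, not part of the patch):

    #include <cstdint>
    #include <cstring>

    // bf16 -> f32 is exact: the bf16 bits become the high half of the f32,
    // so widening is a plain left shift by 16. This is what the
    // SHLLv4i16-based patterns above do, one lane at a time.
    static float bf16_to_fp32(uint16_t h) {
      uint32_t bits = static_cast<uint32_t>(h) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof f);
      return f;
    }

    int main() {
      // 0x3F80 is bf16 1.0; shifted into the f32 high half it is 1.0f.
      return bf16_to_fp32(0x3F80) == 1.0f ? 0 : 1;
    }

The f64 pattern is the same shift followed by a plain FCVTDSr; both steps are exact, so extension never hits the double-rounding problem that truncation has.
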
Index: llvm/test/CodeGen/AArch64/bf16-conversions.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/bf16-conversions.ll
@@ -0,0 +1,167 @@
+; RUN: llc -mtriple=arm64-apple-macosx %s -o - -mattr=+bf16 | FileCheck %s
+; RUN: llc -mtriple=arm64-apple-macosx %s -o - -mattr=+bf16 -global-isel -global-isel-abort=0 | FileCheck %s
+; RUN: llc -mtriple=arm64-apple-macosx %s -o - -mattr=+bf16 -fast-isel | FileCheck %s
+
+define <4 x float> @extendvec_bf16_f32(<4 x bfloat> %in) {
+; CHECK-LABEL: extendvec_bf16_f32:
+; CHECK: shll.4s v0, v0, #16
+
+  %res = fpext <4 x bfloat> %in to <4 x float>
+  ret <4 x float> %res
+}
+
+define float @extend_bf16_f32(bfloat %in) {
+; CHECK-LABEL: extend_bf16_f32:
+; CHECK: shll.4s v0, v0, #16
+
+  %res = fpext bfloat %in to float
+  ret float %res
+}
+
+; Scalarized
+define <4 x double> @extendvec_bf16_f64(<4 x bfloat> %in) {
+; CHECK-LABEL: extendvec_bf16_f64:
+; CHECK: shll.4s v[[TMP:[0-9]+]], {{.*}}, #16
+; CHECK: fcvt {{d.*}}, s[[TMP]]
+; CHECK: fcvt
+; CHECK: fcvt
+; CHECK: fcvt
+
+  %res = fpext <4 x bfloat> %in to <4 x double>
+  ret <4 x double> %res
+}
+
+define double @extend_bf16_f64(bfloat %in) {
+; CHECK-LABEL: extend_bf16_f64:
+; CHECK: shll.4s v[[TMP:[0-9]+]], v0, #16
+; CHECK: fcvt d0, s[[TMP]]
+  %res = fpext bfloat %in to double
+  ret double %res
+}
+
+define <4 x bfloat> @truncvec_f32_bf16(<4 x float> %in) {
+; CHECK-LABEL: truncvec_f32_bf16:
+; CHECK: bfcvtn.4h v0, v0
+  %res = fptrunc <4 x float> %in to <4 x bfloat>
+  ret <4 x bfloat> %res
+}
+
+define bfloat @trunc_f32_bf16(float %in) {
+; CHECK-LABEL: trunc_f32_bf16:
+; CHECK: bfcvt h0, s0
+  %res = fptrunc float %in to bfloat
+  ret bfloat %res
+}
+
+; Scalarized
+define <4 x bfloat> @truncvec_f64_bf16(<4 x double> %in) "unsafe-fp-math"="true" {
+; CHECK-LABEL: truncvec_f64_bf16:
+; CHECK: fcvt [[TMP:s[0-9]+]], {{d.*}}
+; CHECK: bfcvt {{h.*}}, [[TMP]]
+; CHECK: bfcvt
+; CHECK: bfcvt
+; CHECK: bfcvt
+
+  %res = fptrunc <4 x double> %in to <4 x bfloat>
+  ret <4 x bfloat> %res
+}
+
+define bfloat @trunc_f64_bf16(double %in) "unsafe-fp-math"="true" {
+; CHECK-LABEL: trunc_f64_bf16:
+; CHECK: fcvt [[TMP:s[0-9]+]], d0
+; CHECK: bfcvt h0, [[TMP]]
+
+  %res = fptrunc double %in to bfloat
+  ret bfloat %res
+}
+
+define float @extload_bf16_f32(bfloat* %ptr) {
+; CHECK-LABEL: extload_bf16_f32:
+; CHECK: ldr h[[TMP:[0-9]+]], [x0]
+; CHECK: shll.4s v0, v[[TMP]], #16
+
+  %tmp = load bfloat, bfloat* %ptr
+  %res = fpext bfloat %tmp to float
+  ret float %res
+}
+
+define double @extload_bf16_f64(bfloat* %ptr) {
+; CHECK-LABEL: extload_bf16_f64:
+; CHECK: ldr h[[TMP:[0-9]+]], [x0]
+; CHECK: shll.4s v[[TMP1:[0-9]+]], v[[TMP]], #16
+; CHECK: fcvt d0, s[[TMP1]]
+
+  %tmp = load bfloat, bfloat* %ptr
+  %res = fpext bfloat %tmp to double
+  ret double %res
+}
+
+define <4 x float> @extloadvec_bf16_f32(<4 x bfloat>* %ptr) {
+; CHECK-LABEL: extloadvec_bf16_f32:
+; CHECK: ldr d[[TMP:[0-9]+]], [x0]
+; CHECK: shll.4s v0, v[[TMP]], #16
+
+  %tmp = load <4 x bfloat>, <4 x bfloat>* %ptr
+  %res = fpext <4 x bfloat> %tmp to <4 x float>
+  ret <4 x float> %res
+}
+
+; Scalarized
+define <4 x double> @extloadvec_bf16_f64(<4 x bfloat>* %ptr) {
+; CHECK-LABEL: extloadvec_bf16_f64:
+; CHECK: ldr d[[TMP:[0-9]+]], [x0]
+; CHECK: shll.4s v[[TMP1:[0-9]+]], v[[TMP]], #16
+; CHECK: fcvt {{d.*}}, s[[TMP1]]
+; CHECK: fcvt
+; CHECK: fcvt
+; CHECK: fcvt
+
+  %tmp = load <4 x bfloat>, <4 x bfloat>* %ptr
+  %res = fpext <4 x bfloat> %tmp to <4 x double>
+  ret <4 x double> %res
+}
+
+define void @truncstore_f32_bf16(float %in, bfloat* %ptr) {
+; CHECK-LABEL: truncstore_f32_bf16:
+; CHECK: bfcvt [[TMP:h[0-9]+]], s0
+; CHECK: str [[TMP]], [x0]
+
+  %val = fptrunc float %in to bfloat
+  store bfloat %val, bfloat* %ptr
+  ret void
+}
+
+define void @truncstore_f64_bf16(double %in, bfloat* %ptr) "unsafe-fp-math"="true" {
+; CHECK-LABEL: truncstore_f64_bf16:
+; CHECK: fcvt [[TMP:s[0-9]+]], d0
+; CHECK: bfcvt [[TMP1:h[0-9]+]], [[TMP]]
+; CHECK: str [[TMP1]], [x0]
+
+  %val = fptrunc double %in to bfloat
+  store bfloat %val, bfloat* %ptr
+  ret void
+}
+
+define void @truncstorevec_f32_bf16(<4 x float> %in, <4 x bfloat>* %ptr) {
+; CHECK-LABEL: truncstorevec_f32_bf16:
+; CHECK: bfcvtn.4h v[[TMP:[0-9]+]], v0
+; CHECK: str d[[TMP]], [x0]
+  %val = fptrunc <4 x float> %in to <4 x bfloat>
+  store <4 x bfloat> %val, <4 x bfloat>* %ptr
+  ret void
+}
+
+; Scalarized
+define void @truncstorevec_f64_bf16(<4 x double> %in, <4 x bfloat>* %ptr) "unsafe-fp-math"="true" {
+; CHECK-LABEL: truncstorevec_f64_bf16:
+; CHECK: fcvt [[TMP:s[0-9]+]], d0
+; CHECK: bfcvt [[TMP1:h[0-9]+]], [[TMP]]
+; CHECK: bfcvt
+; CHECK: bfcvt
+; CHECK: bfcvt
+
+  %val = fptrunc <4 x double> %in to <4 x bfloat>
+  store <4 x bfloat> %val, <4 x bfloat>* %ptr
+  ret void
+}
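
One property the extend/trunc tests above rely on implicitly: since bf16 -> f32 extension is exact, truncating the extension of any bf16 value must give back the original bits. A host-side C++ sketch (illustrative only; the two helpers are hypothetical stand-ins for the SHLL/BFCVT semantics at the default round-to-nearest-even setting) checks this exhaustively:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    static float bf16_to_fp32(uint16_t h) { // exact widening, as SHLL #16
      uint32_t bits = static_cast<uint32_t>(h) << 16;
      float f;
      std::memcpy(&f, &bits, sizeof f);
      return f;
    }

    static uint16_t fp32_to_bf16(float f) { // RNE truncation, as BFCVT
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof bits);
      bits += 0x7FFFu + ((bits >> 16) & 1);
      return static_cast<uint16_t>(bits >> 16);
    }

    int main() {
      for (uint32_t b = 0; b < 0x10000u; ++b) {
        const uint16_t h = static_cast<uint16_t>(b);
        // Skip NaNs: signaling NaNs may be quieted while passing through an
        // FP register on some ABIs, perturbing the bit pattern in transit.
        if (((h >> 7) & 0xFFu) == 0xFFu && (h & 0x7Fu) != 0)
          continue;
        assert(fp32_to_bf16(bf16_to_fp32(h)) == h);
      }
      return 0;
    }
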