[AArch64] Treat bitcasts between f16 and bf16 as legal

AArch64ISelLowering.cpp (LowerBITCAST): after the existing early-out for
result types other than f16/bf16, the node is now returned unchanged when
the operand type is also f16 or bf16, so only i16 sources reach the
existing assert(ArgVT == MVT::i16) path. NOTE(review): this relies on the
custom-lowering convention that returning the original node leaves it for
instruction selection unchanged - confirm against LowerOperation's caller.

AArch64InstrInfo.td: two selection patterns map bitconvert bf16->f16 and
f16->bf16 onto the same FPR16 source register, so no instruction is
emitted. They sit before the `let Predicates = [IsLE]` block and are
therefore not restricted to little-endian targets.

test/CodeGen/AArch64/bf16.ll: two new functions check that each direction
of the bitcast compiles to a bare `ret` right after the function label
(presumably because half and bfloat arguments and returns share the same
FP register - confirm against the test's RUN line and calling convention).

NOTE(review): this copy of the patch has been flattened onto one line.
`git apply` and `patch` need the original layout restored (one diff
header, hunk header, context line or +/- line per line, including the
blank context lines implied by the hunk counts) before it will apply.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3746,6 +3746,10 @@ if (OpVT != MVT::f16 && OpVT != MVT::bf16) return SDValue(); + // Bitcasts between f16 and bf16 are legal. + if (ArgVT == MVT::f16 || ArgVT == MVT::bf16) + return Op; + assert(ArgVT == MVT::i16); SDLoc DL(Op); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -7550,6 +7550,9 @@ def : Pat<(i64 (bitconvert (v1f64 V64:$Vn))), (COPY_TO_REGCLASS V64:$Vn, GPR64)>; +def : Pat<(f16 (bitconvert (bf16 FPR16:$src))), (f16 FPR16:$src)>; +def : Pat<(bf16 (bitconvert (f16 FPR16:$src))), (bf16 FPR16:$src)>; + let Predicates = [IsLE] in { def : Pat<(v1i64 (bitconvert (v2i32 FPR64:$src))), (v1i64 FPR64:$src)>; def : Pat<(v1i64 (bitconvert (v4i16 FPR64:$src))), (v1i64 FPR64:$src)>; diff --git a/llvm/test/CodeGen/AArch64/bf16.ll b/llvm/test/CodeGen/AArch64/bf16.ll --- a/llvm/test/CodeGen/AArch64/bf16.ll +++ b/llvm/test/CodeGen/AArch64/bf16.ll @@ -82,3 +82,17 @@ ret { <8 x bfloat>, <8 x bfloat>* } %res } + +define bfloat @test_bitcast_halftobfloat(half %a) nounwind { +; CHECK-LABEL: test_bitcast_halftobfloat: +; CHECK-NEXT: ret + %r = bitcast half %a to bfloat + ret bfloat %r +} + +define half @test_bitcast_bfloattohalf(bfloat %a) nounwind { +; CHECK-LABEL: test_bitcast_bfloattohalf: +; CHECK-NEXT: ret + %r = bitcast bfloat %a to half + ret half %r +}