diff --git a/llvm/lib/Target/ARM/ARMCallingConv.cpp b/llvm/lib/Target/ARM/ARMCallingConv.cpp
--- a/llvm/lib/Target/ARM/ARMCallingConv.cpp
+++ b/llvm/lib/Target/ARM/ARMCallingConv.cpp
@@ -209,14 +209,17 @@
     break;
   }
   case MVT::f16:
+  case MVT::bf16:
   case MVT::f32:
     RegList = SRegList;
     break;
   case MVT::v4f16:
+  case MVT::v4bf16:
   case MVT::f64:
     RegList = DRegList;
     break;
   case MVT::v8f16:
+  case MVT::v8bf16:
   case MVT::v2f64:
     RegList = QRegList;
     break;
diff --git a/llvm/lib/Target/ARM/ARMCallingConv.td b/llvm/lib/Target/ARM/ARMCallingConv.td
--- a/llvm/lib/Target/ARM/ARMCallingConv.td
+++ b/llvm/lib/Target/ARM/ARMCallingConv.td
@@ -30,8 +30,8 @@
   CCIfSwiftError<CCIfType<[i32], CCAssignToReg<[R8]>>>,

   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   // f64 and v2f64 are passed in adjacent GPRs, possibly split onto the stack
   CCIfType<[f64, v2f64], CCCustom<"CC_ARM_APCS_Custom_f64">>,
@@ -56,8 +56,8 @@
   CCIfSwiftError<CCIfType<[i32], CCAssignToReg<[R8]>>>,

   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   CCIfType<[f64, v2f64], CCCustom<"RetCC_ARM_APCS_Custom_f64">>,
@@ -71,8 +71,8 @@
 let Entry = 1 in
 def FastCC_ARM_APCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   CCIfType<[v2f64], CCAssignToReg<[Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
@@ -91,8 +91,8 @@
 let Entry = 1 in
 def RetFastCC_ARM_APCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   CCIfType<[v2f64], CCAssignToReg<[Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
@@ -108,8 +108,8 @@
 let Entry = 1 in
 def CC_ARM_APCS_GHC : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   CCIfType<[v2f64], CCAssignToReg<[Q4, Q5]>>,
   CCIfType<[f64], CCAssignToReg<[D8, D9, D10, D11]>>,
@@ -139,7 +139,7 @@
   CCIfType<[i32], CCIfAlign<"8",
            CCAssignToStackWithShadow<4, 8, [R0, R1, R2, R3]>>>,
   CCIfType<[i32], CCAssignToStackWithShadow<4, 4, [R0, R1, R2, R3]>>,
-  CCIfType<[f16, f32], CCAssignToStackWithShadow<4, 4, [Q0, Q1, Q2, Q3]>>,
+  CCIfType<[f16, bf16, f32], CCAssignToStackWithShadow<4, 4, [Q0, Q1, Q2, Q3]>>,
   CCIfType<[f64], CCAssignToStackWithShadow<8, 8, [Q0, Q1, Q2, Q3]>>,
   CCIfType<[v2f64], CCIfAlign<"16",
            CCAssignToStackWithShadow<16, 16, [Q0, Q1, Q2, Q3]>>>,
@@ -165,8 +165,8 @@
   CCIfNest<CCAssignToReg<[R12]>>,

   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -176,15 +176,15 @@
   CCIfType<[f64, v2f64], CCCustom<"CC_ARM_AAPCS_Custom_f64">>,

   CCIfType<[f32], CCBitConvertToType<i32>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
   CCDelegateTo<CC_ARM_AAPCS_Common>
 ]>;

 let Entry = 1 in
 def RetCC_ARM_AAPCS : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -194,7 +194,7 @@
   CCIfType<[f64, v2f64], CCCustom<"RetCC_ARM_AAPCS_Custom_f64">>,

   CCIfType<[f32], CCBitConvertToType<i32>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_Custom_f16">>,
   CCDelegateTo<RetCC_ARM_AAPCS_Common>
 ]>;
@@ -210,8 +210,8 @@
   CCIfByVal<CCPassByVal<4, 4>>,

   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -226,15 +226,15 @@
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
   CCIfType<[f32], CCAssignToReg<[S0, S1, S2, S3, S4, S5, S6, S7, S8, S9,
                                  S10, S11, S12, S13, S14, S15]>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
   CCDelegateTo<CC_ARM_AAPCS_Common>
 ]>;

 let Entry = 1 in
 def RetCC_ARM_AAPCS_VFP : CallingConv<[
   // Handle all vector types as either f64 or v2f64.
-  CCIfType<[v1i64, v2i32, v4i16, v4f16, v8i8, v2f32], CCBitConvertToType<f64>>,
-  CCIfType<[v2i64, v4i32, v8i16, v8f16, v16i8, v4f32], CCBitConvertToType<v2f64>>,
+  CCIfType<[v1i64, v2i32, v4i16, v4f16, v4bf16, v8i8, v2f32], CCBitConvertToType<f64>>,
+  CCIfType<[v2i64, v4i32, v8i16, v8f16, v8bf16, v16i8, v4f32], CCBitConvertToType<v2f64>>,

   // Pass SwiftSelf in a callee saved register.
   CCIfSwiftSelf<CCIfType<[i32], CCAssignToReg<[R10]>>>,
@@ -246,7 +246,7 @@
   CCIfType<[f64], CCAssignToReg<[D0, D1, D2, D3, D4, D5, D6, D7]>>,
   CCIfType<[f32], CCAssignToReg<[S0, S1, S2, S3, S4, S5, S6, S7, S8, S9,
                                  S10, S11, S12, S13, S14, S15]>>,
-  CCIfType<[f16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
+  CCIfType<[f16, bf16], CCCustom<"CC_ARM_AAPCS_VFP_Custom_f16">>,
   CCDelegateTo<RetCC_ARM_AAPCS_Common>
 ]>;
diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -723,6 +723,9 @@
     setOperationAction(ISD::FMAXNUM, MVT::f16, Legal);
   }

+  if (Subtarget->hasBF16())
+    addRegisterClass(MVT::bf16, &ARM::HPRRegClass);
+
   for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
     for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
       setTruncStoreAction(VT, InnerVT, Expand);
@@ -770,6 +773,11 @@
       addQRTypeForNEON(MVT::v8f16);
       addDRTypeForNEON(MVT::v4f16);
     }
+
+    if (Subtarget->hasBF16()) {
+      addQRTypeForNEON(MVT::v8bf16);
+      addDRTypeForNEON(MVT::v4bf16);
+    }
   }

   if (Subtarget->hasMVEIntegerOps() || Subtarget->hasNEON()) {
@@ -2077,9 +2085,12 @@
     // f16 arguments have their size extended to 4 bytes and passed as if they
     // had been copied to the LSBs of a 32-bit register.
     // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-    if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
-      assert(Subtarget->hasFullFP16() &&
+    if (VA.needsCustom() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
+      assert((VA.getValVT() == MVT::f16) == Subtarget->hasFullFP16() &&
             "Lowering f16 type return without full fp16 support");
+      assert((VA.getValVT() == MVT::bf16) == Subtarget->hasBF16() &&
+             "Lowering bf16 type return without bfloat support");
       Val = DAG.getNode(ISD::BITCAST, dl,
                         MVT::getIntegerVT(VA.getLocVT().getSizeInBits()), Val);
       Val = DAG.getNode(ARMISD::VMOVhr, dl, VA.getValVT(), Val);
@@ -2256,9 +2267,12 @@
     // f16 arguments have their size extended to 4 bytes and passed as if they
     // had been copied to the LSBs of a 32-bit register.
     // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-    if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
-      assert(Subtarget->hasFullFP16() &&
-             "Lowering f16 type argument without full fp16 support");
+    if (VA.needsCustom() &&
+        (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
+      assert((VA.getValVT() == MVT::f16) == Subtarget->hasFullFP16() &&
+             "Lowering f16 type return without full fp16 support");
+      assert((VA.getValVT() == MVT::bf16) == Subtarget->hasBF16() &&
+             "Lowering bf16 type return without bfloat support");
       Arg = DAG.getNode(ARMISD::VMOVrh, dl,
                         MVT::getIntegerVT(VA.getLocVT().getSizeInBits()), Arg);
       Arg = DAG.getNode(ISD::BITCAST, dl, VA.getLocVT(), Arg);
@@ -2971,8 +2985,8 @@
     // Guarantee that all emitted copies are
     // stuck together, avoiding something bad.
     Flag = Chain.getValue(1);
-    RetOps.push_back(DAG.getRegister(VA.getLocReg(),
-                                     ReturnF16 ? MVT::f16 : VA.getLocVT()));
+    RetOps.push_back(DAG.getRegister(VA.getLocReg(), ReturnF16 ?
+                                     Arg.getValueType() : VA.getLocVT()));
   }
   const ARMBaseRegisterInfo *TRI = Subtarget->getRegisterInfo();
   const MCPhysReg *I =
@@ -4105,7 +4119,8 @@
     unsigned NumParts, MVT PartVT, Optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.hasValue();
   EVT ValueVT = Val.getValueType();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     unsigned ValueBits = ValueVT.getSizeInBits();
     unsigned PartBits = PartVT.getSizeInBits();
     Val = DAG.getNode(ISD::BITCAST, DL, MVT::getIntegerVT(ValueBits), Val);
@@ -4121,7 +4136,8 @@
     SelectionDAG &DAG, const SDLoc &DL, const SDValue *Parts,
    unsigned NumParts, MVT PartVT, EVT ValueVT, Optional<CallingConv::ID> CC) const {
   bool IsABIRegCopy = CC.hasValue();
-  if (IsABIRegCopy && ValueVT == MVT::f16 && PartVT == MVT::f32) {
+  if (IsABIRegCopy && (ValueVT == MVT::f16 || ValueVT == MVT::bf16) &&
+      PartVT == MVT::f32) {
     unsigned ValueBits = ValueVT.getSizeInBits();
     unsigned PartBits = PartVT.getSizeInBits();
     SDValue Val = Parts[0];
@@ -4233,13 +4249,15 @@
       const TargetRegisterClass *RC;

-      if (RegVT == MVT::f16)
+      if (RegVT == MVT::f16 || RegVT == MVT::bf16)
         RC = &ARM::HPRRegClass;
       else if (RegVT == MVT::f32)
         RC = &ARM::SPRRegClass;
-      else if (RegVT == MVT::f64 || RegVT == MVT::v4f16)
+      else if (RegVT == MVT::f64 || RegVT == MVT::v4f16 ||
+               RegVT == MVT::v4bf16)
         RC = &ARM::DPRRegClass;
-      else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16)
+      else if (RegVT == MVT::v2f64 || RegVT == MVT::v8f16 ||
+               RegVT == MVT::v8bf16)
         RC = &ARM::QPRRegClass;
       else if (RegVT == MVT::i32)
         RC = AFI->isThumb1OnlyFunction() ? &ARM::tGPRRegClass
@@ -4282,9 +4300,12 @@
       // f16 arguments have their size extended to 4 bytes and passed as if they
       // had been copied to the LSBs of a 32-bit register.
      // For that, it's passed extended to i32 (soft ABI) or to f32 (hard ABI)
-      if (VA.needsCustom() && VA.getValVT() == MVT::f16) {
-        assert(Subtarget->hasFullFP16() &&
-               "Lowering f16 type argument without full fp16 support");
+      if (VA.needsCustom() &&
+          (VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
+        assert((VA.getValVT() == MVT::f16) == Subtarget->hasFullFP16() &&
+               "Lowering f16 type return without full fp16 support");
+        assert((VA.getValVT() == MVT::bf16) == Subtarget->hasBF16() &&
+               "Lowering bf16 type return without bfloat support");
         ArgValue = DAG.getNode(ISD::BITCAST, dl,
                                MVT::getIntegerVT(VA.getLocVT().getSizeInBits()),
                                ArgValue);
@@ -5070,6 +5091,8 @@
     return !Subtarget->hasFP64();
   if (VT == MVT::f16)
     return !Subtarget->hasFullFP16();
+  if (VT == MVT::bf16)
+    return !Subtarget->hasBF16();
   return false;
 }

@@ -5880,18 +5903,20 @@
   EVT SrcVT = Op.getValueType();
   EVT DstVT = N->getValueType(0);

-  if (SrcVT == MVT::i16 && DstVT == MVT::f16) {
-    if (!Subtarget->hasFullFP16())
+  if (SrcVT == MVT::i16 && (DstVT == MVT::f16 || DstVT == MVT::bf16)) {
+    if ((DstVT == MVT::f16 && !Subtarget->hasFullFP16()) ||
+        (DstVT == MVT::bf16 && !Subtarget->hasBF16()))
       return SDValue();
-    // f16 bitcast i16 -> VMOVhr
-    return DAG.getNode(ARMISD::VMOVhr, SDLoc(N), MVT::f16,
+    // (b)f16 bitcast i16 -> VMOVhr
+    return DAG.getNode(ARMISD::VMOVhr, SDLoc(N), DstVT,
                        DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), MVT::i32, Op));
   }

-  if (SrcVT == MVT::f16 && DstVT == MVT::i16) {
-    if (!Subtarget->hasFullFP16())
+  if ((SrcVT == MVT::f16 || SrcVT == MVT::bf16) && DstVT == MVT::i16) {
+    if ((SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) ||
+        (SrcVT == MVT::bf16 && !Subtarget->hasBF16()))
       return SDValue();
-    // i16 bitcast f16 -> VMOVrh
+    // i16 bitcast (b)f16 -> VMOVrh
     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i16,
                        DAG.getNode(ARMISD::VMOVrh, SDLoc(N), MVT::i32, Op));
   }
@@ -13159,7 +13184,7 @@
       Copy->getOpcode() == ISD::CopyFromReg) {
     SDValue Ops[] = {Copy->getOperand(0), Copy->getOperand(1)};
     SDValue NewCopy =
-        DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), MVT::f16, Ops);
+        DCI.DAG.getNode(ISD::CopyFromReg, SDLoc(N), N->getValueType(0), Ops);
     return NewCopy;
   }
 }
@@ -13168,7 +13193,7 @@
   if (LoadSDNode *LN0 = dyn_cast<LoadSDNode>(Op0)) {
     if (LN0->hasOneUse() && LN0->isUnindexed() &&
         LN0->getMemoryVT() == MVT::i16) {
-      SDValue Load = DCI.DAG.getLoad(MVT::f16, SDLoc(N), LN0->getChain(),
+      SDValue Load = DCI.DAG.getLoad(N->getValueType(0), SDLoc(N), LN0->getChain(),
                                      LN0->getBasePtr(), LN0->getMemOperand());
       DCI.DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Load.getValue(0));
       DCI.DAG.ReplaceAllUsesOfValueWith(Op0.getValue(1), Load.getValue(1));
diff --git a/llvm/lib/Target/ARM/ARMInstrNEON.td b/llvm/lib/Target/ARM/ARMInstrNEON.td
--- a/llvm/lib/Target/ARM/ARMInstrNEON.td
+++ b/llvm/lib/Target/ARM/ARMInstrNEON.td
@@ -7395,6 +7395,9 @@
 def : Pat<(v4i16 (bitconvert (v4f16 DPR:$src))), (v4i16 DPR:$src)>;
 def : Pat<(v4f16 (bitconvert (v4i16 DPR:$src))), (v4f16 DPR:$src)>;

+def : Pat<(v4i16 (bitconvert (v4bf16 DPR:$src))), (v4i16 DPR:$src)>;
+def : Pat<(v4bf16 (bitconvert (v4i16 DPR:$src))), (v4bf16 DPR:$src)>;
+
 // 128 bit conversions
 def : Pat<(v2f64 (bitconvert (v2i64 QPR:$src))), (v2f64 QPR:$src)>;
 def : Pat<(v2i64 (bitconvert (v2f64 QPR:$src))), (v2i64 QPR:$src)>;
@@ -7404,6 +7407,9 @@
 def : Pat<(v8i16 (bitconvert (v8f16 QPR:$src))), (v8i16 QPR:$src)>;
 def : Pat<(v8f16 (bitconvert (v8i16 QPR:$src))), (v8f16 QPR:$src)>;
+
+def : Pat<(v8i16 (bitconvert (v8bf16 QPR:$src))), (v8i16 QPR:$src)>;
+def : Pat<(v8bf16 (bitconvert (v8i16 QPR:$src))), (v8bf16 QPR:$src)>;
 }

 let Predicates = [IsLE,HasNEON] in {
@@ -7411,24 +7417,28 @@
   def : Pat<(f64   (bitconvert (v2f32 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v2i32 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4f16 DPR:$src))), (f64   DPR:$src)>;
+  def : Pat<(f64   (bitconvert (v4bf16 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4i16 DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(f64   (bitconvert (v8i8  DPR:$src))), (f64   DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2f32 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2i32 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4f16 DPR:$src))), (v1i64 DPR:$src)>;
+  def : Pat<(v1i64 (bitconvert (v4bf16 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4i16 DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v8i8  DPR:$src))), (v1i64 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (f64   DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v1i64 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4f16 DPR:$src))), (v2f32 DPR:$src)>;
+  def : Pat<(v2f32 (bitconvert (v4bf16 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4i16 DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v8i8  DPR:$src))), (v2f32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (f64   DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v1i64 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4f16 DPR:$src))), (v2i32 DPR:$src)>;
+  def : Pat<(v2i32 (bitconvert (v4bf16 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4i16 DPR:$src))), (v2i32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v8i8  DPR:$src))), (v2i32 DPR:$src)>;
@@ -7438,6 +7448,12 @@
   def : Pat<(v4f16 (bitconvert (v2i32 DPR:$src))), (v4f16 DPR:$src)>;
   def : Pat<(v4f16 (bitconvert (v8i8  DPR:$src))), (v4f16 DPR:$src)>;

+  def : Pat<(v4bf16 (bitconvert (f64   DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v1i64 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2f32 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2i32 DPR:$src))), (v4bf16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v8i8  DPR:$src))), (v4bf16 DPR:$src)>;
+
   def : Pat<(v4i16 (bitconvert (f64   DPR:$src))), (v4i16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v1i64 DPR:$src))), (v4i16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v2f32 DPR:$src))), (v4i16 DPR:$src)>;
@@ -7449,30 +7465,35 @@
   def : Pat<(v8i8  (bitconvert (v2f32 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v2i32 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4f16 DPR:$src))), (v8i8  DPR:$src)>;
+  def : Pat<(v8i8  (bitconvert (v4bf16 DPR:$src))), (v8i8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4i16 DPR:$src))), (v8i8  DPR:$src)>;

   // 128 bit conversions
   def : Pat<(v2f64 (bitconvert (v4f32 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v4i32 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8f16 QPR:$src))), (v2f64 QPR:$src)>;
+  def : Pat<(v2f64 (bitconvert (v8bf16 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8i16 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v16i8 QPR:$src))), (v2f64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4f32 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4i32 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8f16 QPR:$src))), (v2i64 QPR:$src)>;
+  def : Pat<(v2i64 (bitconvert (v8bf16 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8i16 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v16i8 QPR:$src))), (v2i64 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2f64 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2i64 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8f16 QPR:$src))), (v4f32 QPR:$src)>;
+  def : Pat<(v4f32 (bitconvert (v8bf16 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8i16 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v16i8 QPR:$src))), (v4f32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2f64 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2i64 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8f16 QPR:$src))), (v4i32 QPR:$src)>;
+  def : Pat<(v4i32 (bitconvert (v8bf16 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8i16 QPR:$src))), (v4i32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v16i8 QPR:$src))), (v4i32 QPR:$src)>;
@@ -7482,6 +7503,12 @@
   def : Pat<(v8f16 (bitconvert (v4i32 QPR:$src))), (v8f16 QPR:$src)>;
   def : Pat<(v8f16 (bitconvert (v16i8 QPR:$src))), (v8f16 QPR:$src)>;

+  def : Pat<(v8bf16 (bitconvert (v2f64 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v2i64 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4f32 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4i32 QPR:$src))), (v8bf16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v16i8 QPR:$src))), (v8bf16 QPR:$src)>;
+
   def : Pat<(v8i16 (bitconvert (v2f64 QPR:$src))), (v8i16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v2i64 QPR:$src))), (v8i16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v4f32 QPR:$src))), (v8i16 QPR:$src)>;
@@ -7493,6 +7520,7 @@
   def : Pat<(v16i8 (bitconvert (v4f32 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v4i32 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8f16 QPR:$src))), (v16i8 QPR:$src)>;
+  def : Pat<(v16i8 (bitconvert (v8bf16 QPR:$src))), (v16i8 QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8i16 QPR:$src))), (v16i8 QPR:$src)>;
 }
@@ -7501,24 +7529,28 @@
   def : Pat<(f64   (bitconvert (v2f32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v2i32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4f16 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(f64   (bitconvert (v4bf16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v4i16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(f64   (bitconvert (v8i8  DPR:$src))), (VREV64d8  DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2f32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v2i32 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4f16 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v1i64 (bitconvert (v4bf16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v4i16 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v1i64 (bitconvert (v8i8  DPR:$src))), (VREV64d8  DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (f64   DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v1i64 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4f16 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v2f32 (bitconvert (v4bf16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v4i16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2f32 (bitconvert (v8i8  DPR:$src))), (VREV32d8  DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (f64   DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v1i64 DPR:$src))), (VREV64d32 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4f16 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v2i32 (bitconvert (v4bf16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v4i16 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v2i32 (bitconvert (v8i8  DPR:$src))), (VREV32d8  DPR:$src)>;
@@ -7528,6 +7560,12 @@
   def : Pat<(v4f16 (bitconvert (v2i32 DPR:$src))), (VREV32d16 DPR:$src)>;
   def : Pat<(v4f16 (bitconvert (v8i8  DPR:$src))), (VREV16d8  DPR:$src)>;

+  def : Pat<(v4bf16 (bitconvert (f64   DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v1i64 DPR:$src))), (VREV64d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2f32 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v2i32 DPR:$src))), (VREV32d16 DPR:$src)>;
+  def : Pat<(v4bf16 (bitconvert (v8i8  DPR:$src))), (VREV16d8  DPR:$src)>;
+
   def : Pat<(v4i16 (bitconvert (f64   DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v1i64 DPR:$src))), (VREV64d16 DPR:$src)>;
   def : Pat<(v4i16 (bitconvert (v2f32 DPR:$src))), (VREV32d16 DPR:$src)>;
@@ -7539,30 +7577,35 @@
   def : Pat<(v8i8  (bitconvert (v2f32 DPR:$src))), (VREV32d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v2i32 DPR:$src))), (VREV32d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4f16 DPR:$src))), (VREV16d8  DPR:$src)>;
+  def : Pat<(v8i8  (bitconvert (v4bf16 DPR:$src))), (VREV16d8  DPR:$src)>;
   def : Pat<(v8i8  (bitconvert (v4i16 DPR:$src))), (VREV16d8  DPR:$src)>;

   // 128 bit conversions
   def : Pat<(v2f64 (bitconvert (v4f32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v4i32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8f16 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v2f64 (bitconvert (v8bf16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v8i16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2f64 (bitconvert (v16i8 QPR:$src))), (VREV64q8  QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4f32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v4i32 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8f16 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v2i64 (bitconvert (v8bf16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v8i16 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v2i64 (bitconvert (v16i8 QPR:$src))), (VREV64q8  QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2f64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v2i64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8f16 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v4f32 (bitconvert (v8bf16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v8i16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4f32 (bitconvert (v16i8 QPR:$src))), (VREV32q8  QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2f64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v2i64 QPR:$src))), (VREV64q32 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8f16 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v4i32 (bitconvert (v8bf16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v8i16 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v4i32 (bitconvert (v16i8 QPR:$src))), (VREV32q8  QPR:$src)>;
@@ -7572,6 +7615,12 @@
   def : Pat<(v8f16 (bitconvert (v4i32 QPR:$src))), (VREV32q16 QPR:$src)>;
   def : Pat<(v8f16 (bitconvert (v16i8 QPR:$src))), (VREV16q8  QPR:$src)>;

+  def : Pat<(v8bf16 (bitconvert (v2f64 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v2i64 QPR:$src))), (VREV64q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4f32 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v4i32 QPR:$src))), (VREV32q16 QPR:$src)>;
+  def : Pat<(v8bf16 (bitconvert (v16i8 QPR:$src))), (VREV16q8  QPR:$src)>;
+
   def : Pat<(v8i16 (bitconvert (v2f64 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v2i64 QPR:$src))), (VREV64q16 QPR:$src)>;
   def : Pat<(v8i16 (bitconvert (v4f32 QPR:$src))), (VREV32q16 QPR:$src)>;
@@ -7583,6 +7632,7 @@
   def : Pat<(v16i8 (bitconvert (v4f32 QPR:$src))), (VREV32q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v4i32 QPR:$src))), (VREV32q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8f16 QPR:$src))), (VREV16q8  QPR:$src)>;
+  def : Pat<(v16i8 (bitconvert (v8bf16 QPR:$src))), (VREV16q8  QPR:$src)>;
   def : Pat<(v16i8 (bitconvert (v8i16 QPR:$src))), (VREV16q8  QPR:$src)>;
 }
@@ -7593,12 +7643,12 @@
   // input and output types are the same, the bitconvert gets elided
   // and we end up generating a nonsense match of nothing.

-  foreach VT = [ v16i8, v8i16, v8f16, v4i32, v4f32, v2i64, v2f64 ] in
-    foreach VT2 = [ v16i8, v8i16, v8f16, v4i32, v4f32, v2i64, v2f64 ] in
+  foreach VT = [ v16i8, v8i16, v8f16, v8bf16, v4i32, v4f32, v2i64, v2f64 ] in
+    foreach VT2 = [ v16i8, v8i16, v8f16, v8bf16, v4i32, v4f32, v2i64, v2f64 ] in
       def : Pat<(VT (ARMVectorRegCastImpl (VT2 QPR:$src))), (VT QPR:$src)>;

-  foreach VT = [ v8i8, v4i16, v4f16, v2i32, v2f32, v1i64, f64 ] in
-    foreach VT2 = [ v8i8, v4i16, v4f16, v2i32, v2f32, v1i64, f64 ] in
+  foreach VT = [ v8i8, v4i16, v4f16, v4bf16, v2i32, v2f32, v1i64, f64 ] in
+    foreach VT2 = [ v8i8, v4i16, v4f16, v4bf16, v2i32, v2f32, v1i64, f64 ] in
       def : Pat<(VT (ARMVectorRegCastImpl (VT2 DPR:$src))), (VT DPR:$src)>;
 }
diff --git a/llvm/lib/Target/ARM/ARMInstrVFP.td b/llvm/lib/Target/ARM/ARMInstrVFP.td
--- a/llvm/lib/Target/ARM/ARMInstrVFP.td
+++ b/llvm/lib/Target/ARM/ARMInstrVFP.td
@@ -158,11 +158,14 @@
 let isUnpredicable = 1 in
 def VLDRH : AHI5<0b1101, 0b01, (outs HPR:$Sd), (ins addrmode5fp16:$addr),
                  IIC_fpLoad16, "vldr", ".16\t$Sd, $addr",
-                 [(set HPR:$Sd, (alignedload16 addrmode5fp16:$addr))]>,
+                 []>,
             Requires<[HasFPRegs16]>;

 } // End of 'let canFoldAsLoad = 1, isReMaterializable = 1 in'

+def : Pat<(f16 (alignedload16 addrmode5fp16:$addr)), (VLDRH addrmode5fp16:$addr)>;
+def : Pat<(bf16 (alignedload16 addrmode5fp16:$addr)), (VLDRH addrmode5fp16:$addr)>;
+
 def VSTRD : ADI5<0b1101, 0b00, (outs), (ins DPR:$Dd, addrmode5:$addr),
                  IIC_fpStore64, "vstr", "\t$Dd, $addr",
                  [(alignedstore32 (f64 DPR:$Dd), addrmode5:$addr)]>,
@@ -180,9 +183,12 @@
 let isUnpredicable = 1 in
 def VSTRH : AHI5<0b1101, 0b00, (outs), (ins HPR:$Sd, addrmode5fp16:$addr),
                  IIC_fpStore16, "vstr", ".16\t$Sd, $addr",
-                 [(alignedstore16 HPR:$Sd, addrmode5fp16:$addr)]>,
+                 []>,
            Requires<[HasFPRegs16]>;

+def : Pat<(alignedstore16 (f16 HPR:$Sd), addrmode5fp16:$addr), (VSTRH (f16 HPR:$Sd), addrmode5fp16:$addr)>;
+def : Pat<(alignedstore16 (bf16 HPR:$Sd), addrmode5fp16:$addr), (VSTRH (bf16 HPR:$Sd), addrmode5fp16:$addr)>;
+
 //===----------------------------------------------------------------------===//
 // Load / store multiple Instructions.
 //
@@ -1250,7 +1256,7 @@
 def VMOVRH : AVConv2I<0b11100001, 0b1001, (outs rGPR:$Rt), (ins HPR:$Sn),
                       IIC_fpMOVSI, "vmov", ".f16\t$Rt, $Sn",
-                      [(set rGPR:$Rt, (arm_vmovrh HPR:$Sn))]>,
+                      []>,
                       Requires<[HasFPRegs16]>,
                       Sched<[WriteFPMOV]> {
   // Instruction operands.
@@ -1272,7 +1278,7 @@
 def VMOVHR : AVConv4I<0b11100000, 0b1001, (outs HPR:$Sn), (ins rGPR:$Rt),
                       IIC_fpMOVIS, "vmov", ".f16\t$Sn, $Rt",
-                      [(set HPR:$Sn, (arm_vmovhr rGPR:$Rt))]>,
+                      []>,
                       Requires<[HasFPRegs16]>,
                       Sched<[WriteFPMOV]> {
   // Instruction operands.
@@ -1290,6 +1296,11 @@
   let isUnpredicable = 1;
 }

+def : Pat<(arm_vmovrh (f16 HPR:$Sn)), (VMOVRH (f16 HPR:$Sn))>;
+def : Pat<(arm_vmovrh (bf16 HPR:$Sn)), (VMOVRH (bf16 HPR:$Sn))>;
+def : Pat<(f16 (arm_vmovhr rGPR:$Rt)), (VMOVHR rGPR:$Rt)>;
+def : Pat<(bf16 (arm_vmovhr rGPR:$Rt)), (VMOVHR rGPR:$Rt)>;
+
 // FMRDH: SPR -> GPR
 // FMRDL: SPR -> GPR
 // FMRRS: SPR -> GPR
diff --git a/llvm/lib/Target/ARM/ARMRegisterInfo.td b/llvm/lib/Target/ARM/ARMRegisterInfo.td
--- a/llvm/lib/Target/ARM/ARMRegisterInfo.td
+++ b/llvm/lib/Target/ARM/ARMRegisterInfo.td
@@ -390,7 +390,7 @@
   let DiagnosticString = "operand must be a register in range [s0, s31]";
 }

-def HPR : RegisterClass<"ARM", [f16], 32, (sequence "S%u", 0, 31)> {
+def HPR : RegisterClass<"ARM", [f16, bf16], 32, (sequence "S%u", 0, 31)> {
   let AltOrders = [(add (decimate HPR, 2), SPR),
                    (add (decimate HPR, 4),
                         (decimate HPR, 2),
@@ -412,7 +412,7 @@
 // class.
 // ARM requires only word alignment for double. It's more performant if it
 // is double-word alignment though.
-def DPR : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                         (sequence "D%u", 0, 31)> {
   // Allocate non-VFP2 registers D16-D31 first, and prefer even registers on
   // Darwin platforms.
@@ -433,20 +433,20 @@
 // Subset of DPR that are accessible with VFP2 (and so that also have
 // 32-bit SPR subregs).
-def DPR_VFP2 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR_VFP2 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                              (trunc DPR, 16)> {
   let DiagnosticString = "operand must be a register in range [d0, d15]";
 }

 // Subset of DPR which can be used as a source of NEON scalars for 16-bit
 // operations
-def DPR_8 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16], 64,
+def DPR_8 : RegisterClass<"ARM", [f64, v8i8, v4i16, v2i32, v1i64, v2f32, v4f16, v4bf16], 64,
                           (trunc DPR, 8)> {
   let DiagnosticString = "operand must be a register in range [d0, d7]";
 }

 // Generic 128-bit vector register class.
-def QPR : RegisterClass<"ARM", [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16], 128,
+def QPR : RegisterClass<"ARM", [v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16, v8bf16], 128,
                         (sequence "Q%u", 0, 15)> {
   // Allocate non-VFP2 aliases Q8-Q15 first.
   let AltOrders = [(rotl QPR, 8), (trunc QPR, 8)];
diff --git a/llvm/lib/Target/ARM/ARMSubtarget.h b/llvm/lib/Target/ARM/ARMSubtarget.h
--- a/llvm/lib/Target/ARM/ARMSubtarget.h
+++ b/llvm/lib/Target/ARM/ARMSubtarget.h
@@ -702,6 +702,7 @@
   bool hasD32() const { return HasD32; }
   bool hasFullFP16() const { return HasFullFP16; }
   bool hasFP16FML() const { return HasFP16FML; }
+  bool hasBF16() const { return HasBF16; }

   bool hasFuseAES() const { return HasFuseAES; }
   bool hasFuseLiterals() const { return HasFuseLiterals; }
diff --git a/llvm/test/CodeGen/ARM/bfloat.ll b/llvm/test/CodeGen/ARM/bfloat.ll
new file
--- /dev/null
+++ b/llvm/test/CodeGen/ARM/bfloat.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -float-abi hard -mattr=+bf16 < %s | FileCheck %s
+
+target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "armv8.6a-arm-none-eabi"
+
+define arm_aapcs_vfpcc bfloat @load_scalar_bf(bfloat* %addr) {
+; CHECK-LABEL: load_scalar_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vldr.16 s0, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = load bfloat, bfloat* %addr, align 2
+  ret bfloat %0
+}
+
+define arm_aapcs_vfpcc void @store_scalar_bf(bfloat %v, bfloat* %addr) {
+; CHECK-LABEL: store_scalar_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vstr.16 s0, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  store bfloat %v, bfloat* %addr, align 2
+  ret void
+}
+
+define arm_aapcs_vfpcc <4 x bfloat> @load_vector4_bf(<4 x bfloat>* %addr) {
+; CHECK-LABEL: load_vector4_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vldr d0, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = load <4 x bfloat>, <4 x bfloat>* %addr, align 8
+  ret <4 x bfloat> %0
+}
+
+define arm_aapcs_vfpcc void @store_vector4_bf(<4 x bfloat> %v, <4 x bfloat>* %addr) {
+; CHECK-LABEL: store_vector4_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vstr d0, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  store <4 x bfloat> %v, <4 x bfloat>* %addr, align 8
+  ret void
+}
+
+define arm_aapcs_vfpcc <8 x bfloat> @load_vector8_bf(<8 x bfloat>* %addr) {
+; CHECK-LABEL: load_vector8_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vld1.64 {d0, d1}, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  %0 = load <8 x bfloat>, <8 x bfloat>* %addr, align 8
+  ret <8 x bfloat> %0
+}
+
+define arm_aapcs_vfpcc void @store_vector8_bf(<8 x bfloat> %v, <8 x bfloat>* %addr) {
+; CHECK-LABEL: store_vector8_bf:
+; CHECK:       @ %bb.0: @ %entry
+; CHECK-NEXT:    vst1.64 {d0, d1}, [r0]
+; CHECK-NEXT:    bx lr
+entry:
+  store <8 x bfloat> %v, <8 x bfloat>* %addr, align 8
+  ret void
+}
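
Reviewer note (illustration only, not part of the patch): besides loads, stores, and argument passing, the ARMISelLowering.cpp changes above extend the existing f16 i16<->f16 bitcast combines to bf16, routing them through ARMISD::VMOVhr / ARMISD::VMOVrh and the new VMOVHR/VMOVRH patterns in ARMInstrVFP.td. A minimal IR sketch of that path is below; the function names are made up for illustration, it assumes the same armv8.6-A +bf16 configuration as bfloat.ll, and no particular instruction sequence is asserted.

; Illustrative sketch only (not from the patch); same triple as bfloat.ll above.
target triple = "armv8.6a-arm-none-eabi"

define arm_aapcs_vfpcc bfloat @bitcast_i16_to_bf16(i16 %w) {
entry:
  ; expected to go through ARMISD::VMOVhr per the (b)f16 bitcast handling
  %b = bitcast i16 %w to bfloat
  ret bfloat %b
}

define arm_aapcs_vfpcc i16 @bitcast_bf16_to_i16(bfloat %b) {
entry:
  ; expected to go through ARMISD::VMOVrh
  %w = bitcast bfloat %b to i16
  ret i16 %w
}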