Index: lib/Target/ARM/ARMISelLowering.h =================================================================== --- lib/Target/ARM/ARMISelLowering.h +++ lib/Target/ARM/ARMISelLowering.h @@ -171,6 +171,13 @@ // Vector move f32 immediate: VMOVFPIMM, + // VMOV GPR <-> HPR + // Used for half-precision function arguments and return values that + // are passed as an int (SoftFP), which need to be moved from int to + // fp registers (or vice versa). + VMOVrh, + VMOVhr, + // Vector duplicate: VDUP, VDUPLANE, Index: lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- lib/Target/ARM/ARMISelLowering.cpp +++ lib/Target/ARM/ARMISelLowering.cpp @@ -524,9 +524,8 @@ if (Subtarget->hasFullFP16()) { addRegisterClass(MVT::f16, &ARM::HPRRegClass); - // Clean up bitcast of incoming arguments if hard float abi is enabled. - if (Subtarget->isTargetHardFloat()) - setOperationAction(ISD::BITCAST, MVT::i16, Custom); + setOperationAction(ISD::BITCAST, MVT::i16, Custom); + setOperationAction(ISD::BITCAST, MVT::f16, Custom); } for (MVT VT : MVT::vector_valuetypes()) { @@ -1273,6 +1272,8 @@ case ARMISD::VMOVRRD: return "ARMISD::VMOVRRD"; case ARMISD::VMOVDRR: return "ARMISD::VMOVDRR"; + case ARMISD::VMOVhr: return "ARMISD::VMOVhr"; + case ARMISD::VMOVrh: return "ARMISD::VMOVrh"; case ARMISD::EH_SJLJ_SETJMP: return "ARMISD::EH_SJLJ_SETJMP"; case ARMISD::EH_SJLJ_LONGJMP: return "ARMISD::EH_SJLJ_LONGJMP"; @@ -5061,38 +5062,91 @@ EVT SrcVT = Op.getValueType(); EVT DstVT = N->getValueType(0); - // Half-precision arguments can be passed in like this: - // - // t4: f32,ch = CopyFromReg t0, Register:f32 %1 - // t8: i32 = bitcast t4 - // t9: i16 = truncate t8 - // t10: f16 = bitcast t9 <~~~~ SDNode N - // - // but we want to avoid code generation for the bitcast, so transform this - // into: - // - // t18: f16 = CopyFromReg t0, Register:f32 %0 - // + + // Half-precision arguments: avoid stack stores/loads if (SrcVT == MVT::i16 && DstVT == MVT::f16) { - 
if (Op.getOpcode() != ISD::TRUNCATE) - return SDValue(); + if (Op.getOpcode() != ISD::TRUNCATE) + return SDValue(); + // Transform this: + // + // t4: f32,ch = CopyFromReg t0, Register:f32 %1 + // t8: i32 = bitcast t4 + // t9: i16 = truncate t8 <~~~~ Op + // t10: f16 = bitcast t9 <~~~~ SDNode N + // + // into an f16 copy from reg: + // + // t18: f16 = CopyFromReg t0, Register:f32 %0 + // SDValue Bitcast = Op.getOperand(0); - if (Bitcast.getOpcode() != ISD::BITCAST || - Bitcast.getValueType() != MVT::i32) - return SDValue(); + if (Bitcast.getOpcode() == ISD::BITCAST && + Bitcast.getValueType() == MVT::i32) { + + SDValue Copy = Bitcast.getOperand(0); + if (Copy.getOpcode() != ISD::CopyFromReg || + Copy.getValueType() != MVT::f32) + return SDValue(); + + SDValue Ops[] = { Copy->getOperand(0), Copy->getOperand(1) }; + return DAG.getNode(ISD::CopyFromReg, SDLoc(Copy), MVT::f16, Ops); + } - SDValue Copy = Bitcast.getOperand(0); - if (Copy.getOpcode() != ISD::CopyFromReg || - Copy.getValueType() != MVT::f32) + // And for FullFP16 we can have this: + // + // t0: ch = EntryToken + // ... 
+ // t5: i32,ch = CopyFromReg t0, Register:i32 %1 + // t9: i16 = truncate t5 <~~~~ Op + // t10: f16 = bitcast t9 <~~~~ SDNode N + // t11: f16 = fadd t8, t10 + // + // and transform this into: + // + // t5: i32,ch = CopyFromReg t0, Register:i32 %1 + // t18: f16 = ARMISD::VMOVhr t5 + // + SDValue Copy = Op.getOperand(0); + if (Copy.getOpcode() == ISD::CopyFromReg && + Copy.getValueType() == MVT::i32 && + Copy.getOperand(0).getOpcode() == ISD::EntryToken) { + return DAG.getNode(ARMISD::VMOVhr, SDLoc(Op), + MVT::f16, Op.getOperand(0)); + } + return SDValue(); + } + + // Half-precision return values: avoid stack stores/loads + if (SrcVT == MVT::f16 && DstVT == MVT::i16) { + // + // t11: f16 = fadd t8, t10 + // t12: i16 = bitcast t11 <~~~ SDNode N + // t13: i32 = zero_extend t12 + // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t13 + // t17: ch = ARMISD::RET_FLAG t16, Register:i32 %r0, t16:1 + // + // transform this into: + // + // t20: i32 = ARMISD::VMOVrh t11 + // t16: ch,glue = CopyToReg t0, Register:i32 %r0, t20 + // + auto ZeroExtend = N->use_begin(); + if (N->use_size() != 1 || ZeroExtend->getOpcode() != ISD::ZERO_EXTEND || + ZeroExtend->getValueType(0) != MVT::i32) return SDValue(); - SDValue Ops[] = { Copy->getOperand(0), Copy->getOperand(1) }; - return DAG.getNode(ISD::CopyFromReg, SDLoc(Copy), MVT::f16, Ops); + auto Copy = ZeroExtend->use_begin(); + if (Copy->getOpcode() == ISD::CopyToReg && + Copy->use_begin()->getOpcode() == ARMISD::RET_FLAG) { + SDValue Cvt = DAG.getNode(ARMISD::VMOVrh, SDLoc(Op), MVT::i32, Op); + DAG.ReplaceAllUsesWith(*ZeroExtend, &Cvt); + return Cvt; + } + return SDValue(); } - assert((SrcVT == MVT::i64 || DstVT == MVT::i64) && - "ExpandBITCAST called for non-i64 type"); + if (!(SrcVT == MVT::i64 || DstVT == MVT::i64)) + return SDValue(); // Turn i64->f64 into VMOVDRR. 
if (SrcVT == MVT::i64 && TLI.isTypeLegal(DstVT)) { Index: lib/Target/ARM/ARMInstrVFP.td =================================================================== --- lib/Target/ARM/ARMInstrVFP.td +++ lib/Target/ARM/ARMInstrVFP.td @@ -23,6 +23,11 @@ def arm_fmdrr : SDNode<"ARMISD::VMOVDRR", SDT_VMOVDRR>; def arm_fmrrd : SDNode<"ARMISD::VMOVRRD", SDT_VMOVRRD>; +def SDT_VMOVhr : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVT<1, i32>] >; +def SDT_VMOVrh : SDTypeProfile<1, 1, [SDTCisVT<0, i32>, SDTCisFP<1>] >; +def arm_vmovhr : SDNode<"ARMISD::VMOVhr", SDT_VMOVhr>; +def arm_vmovrh : SDNode<"ARMISD::VMOVrh", SDT_VMOVrh>; + //===----------------------------------------------------------------------===// // Operand Definitions. // @@ -750,6 +755,12 @@ let Inst{5} = Dm{4}; } +def : FullFP16Pat<(arm_vmovhr GPR:$a), + (f16 (COPY_TO_REGCLASS GPR:$a, HPR))>; + +def : FullFP16Pat<(arm_vmovrh HPR:$a), + (i32 (COPY_TO_REGCLASS HPR:$a, GPR))>; + def : FP16Pat<(fp_to_f16 SPR:$a), (i32 (COPY_TO_REGCLASS (VCVTBSH SPR:$a), GPR))>; Index: test/CodeGen/ARM/fp16-instructions.ll =================================================================== --- test/CodeGen/ARM/fp16-instructions.ll +++ test/CodeGen/ARM/fp16-instructions.ll @@ -43,14 +43,11 @@ ; CHECK-SOFTFP-FP16: vcvtb.f16.f32 [[S0]], [[S0]] ; CHECK-SOFTFP-FP16: vmov r0, s0 -; CHECK-SOFTFP-FULLFP16: strh r1, {{.*}} -; CHECK-SOFTFP-FULLFP16: strh r0, {{.*}} -; CHECK-SOFTFP-FULLFP16: vldr.16 [[S0:s[0-9]]], {{.*}} -; CHECK-SOFTFP-FULLFP16: vldr.16 [[S2:s[0-9]]], {{.*}} -; CHECK-SOFTFP-FULLFP16: vadd.f16 [[S0]], [[S2]], [[S0]] -; CHECK-SOFTFP-FULLFP16: vstr.16 [[S2:s[0-9]]], {{.*}} -; CHECK-SOFTFP-FULLFP16: ldrh r0, {{.*}} -; CHECK-SOFTFP-FULLFP16: mov pc, lr +; CHECK-SOFTFP-FULLFP16: vmov [[S0:s[0-9]]], r1 +; CHECK-SOFTFP-FULLFP16: vmov [[S2:s[0-9]]], r0 +; CHECK-SOFTFP-FULLFP16: vadd.f16 [[S0]], [[S2]], [[S0]] +; CHECK-SOFTFP-FULLFP16-NEXT: vmov r0, s0 +; CHECK-SOFTFP-FULLFP16-NEXT: mov pc, lr ; CHECK-HARDFP-VFP3: vmov r{{.}}, s0 ; 
CHECK-HARDFP-VFP3: vmov{{.*}}, s1 @@ -69,4 +66,3 @@ ; CHECK-HARDFP-FULLFP16-NEXT: mov pc, lr } -