Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -365,13 +365,25 @@ bool VectorReduction : 1; bool AllowContract : 1; + bool RoundDynamic : 1; + bool RoundTonearest : 1; + bool RoundDownward : 1; + bool RoundUpward : 1; + bool RoundTowardZero : 1; + + bool ExceptIgnore : 1; + bool ExceptMayTrap : 1; + bool ExceptStrict : 1; + public: /// Default constructor turns off all optimization flags. SDNodeFlags() : AnyDefined(false), NoUnsignedWrap(false), NoSignedWrap(false), Exact(false), UnsafeAlgebra(false), NoNaNs(false), NoInfs(false), NoSignedZeros(false), AllowReciprocal(false), VectorReduction(false), - AllowContract(false) {} + AllowContract(false), RoundDynamic(false), RoundTonearest(false), + RoundDownward(false), RoundUpward(false), RoundTowardZero(false), + ExceptIgnore(false), ExceptMayTrap(false), ExceptStrict(false) {} /// Sets the state of the flags to the defined state. void setDefined() { AnyDefined = true; } @@ -420,6 +432,46 @@ AllowContract = b; } + void setRoundDynamic(bool b) { + setDefined(); + RoundDynamic = b; + } + + void setRoundTonearest(bool b) { + setDefined(); + RoundTonearest = b; + } + + void setRoundDownward(bool b) { + setDefined(); + RoundDownward = b; + } + + void setRoundUpward(bool b) { + setDefined(); + RoundUpward = b; + } + + void setRoundTowardZero(bool b) { + setDefined(); + RoundTowardZero = b; + } + + void setExceptIgnore(bool b) { + setDefined(); + ExceptIgnore = b; + } + + void setExceptMayTrap(bool b) { + setDefined(); + ExceptMayTrap = b; + } + + void setExceptStrict(bool b) { + setDefined(); + ExceptStrict = b; + } + // These are accessors for each flag. 
bool hasNoUnsignedWrap() const { return NoUnsignedWrap; } bool hasNoSignedWrap() const { return NoSignedWrap; } @@ -431,6 +483,14 @@ bool hasAllowReciprocal() const { return AllowReciprocal; } bool hasVectorReduction() const { return VectorReduction; } bool hasAllowContract() const { return AllowContract; } + bool hasRoundDynamic() const { return RoundDynamic; } + bool hasRoundTonearest() const { return RoundTonearest; } + bool hasRoundDownward() const { return RoundDownward; } + bool hasRoundUpward() const { return RoundUpward; } + bool hasRoundTowardZero() const { return RoundTowardZero; } + bool hasExceptIgnore() const { return ExceptIgnore; } + bool hasExceptMayTrap() const { return ExceptMayTrap; } + bool hasExceptStrict() const { return ExceptStrict; } /// Clear any flags in this flag set that aren't also set in Flags. /// If the given Flags are undefined then don't do anything. @@ -447,6 +507,14 @@ AllowReciprocal &= Flags.AllowReciprocal; VectorReduction &= Flags.VectorReduction; AllowContract &= Flags.AllowContract; + RoundDynamic &= Flags.RoundDynamic; + RoundTonearest &= Flags.RoundTonearest; + RoundDownward &= Flags.RoundDownward; + RoundUpward &= Flags.RoundUpward; + RoundTowardZero &= Flags.RoundTowardZero; + ExceptIgnore &= Flags.ExceptIgnore; + ExceptMayTrap &= Flags.ExceptMayTrap; + ExceptStrict &= Flags.ExceptStrict; } }; Index: lib/CodeGen/SelectionDAG/LegalizeDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -940,7 +940,7 @@ case ISD::STRICT_FSQRT: EqOpc = ISD::FSQRT; break; case ISD::STRICT_FPOW: EqOpc = ISD::FPOW; break; case ISD::STRICT_FPOWI: EqOpc = ISD::FPOWI; break; - case ISD::STRICT_FMA: EqOpc = ISD::FMA; break; + case ISD::STRICT_FMA: EqOpc = ISD::STRICT_FMA; break; case ISD::STRICT_FSIN: EqOpc = ISD::FSIN; break; case ISD::STRICT_FCOS: EqOpc = ISD::FCOS; break; case ISD::STRICT_FEXP: EqOpc = ISD::FEXP; break; @@ 
-954,6 +954,9 @@ auto Action = TLI.getOperationAction(EqOpc, VT); + if (Action == TargetLowering::Custom) + return Action; + // We don't currently handle Custom or Promote for strict FP pseudo-ops. // For now, we just expand for those cases. if (Action != TargetLowering::Legal) Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5467,7 +5467,26 @@ case Intrinsic::experimental_constrained_fmul: case Intrinsic::experimental_constrained_fdiv: case Intrinsic::experimental_constrained_frem: - case Intrinsic::experimental_constrained_fma: + case Intrinsic::experimental_constrained_fma: { + const ConstrainedFPIntrinsic &FPI = cast<ConstrainedFPIntrinsic>(I); + // The case labels above also fall into this block; only lower FMA here. + if (FPI.getIntrinsicID() == Intrinsic::experimental_constrained_fma) { + SDNodeFlags SDFlags; + if (FPI.getRoundingMode() == llvm::ConstrainedFPIntrinsic::rmDynamic) + SDFlags.setRoundDynamic(true); + + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + SDValue Res = DAG.getNode(ISD::STRICT_FMA, sdl, VT, + getValue(I.getArgOperand(0)), + getValue(I.getArgOperand(1)), + getValue(I.getArgOperand(2))); + + Res.getNode()->setFlags(SDFlags); + setValue(&I, Res); + return nullptr; + } + // Every other intrinsic continues to the case labels below, as before. + } case Intrinsic::experimental_constrained_sqrt: case Intrinsic::experimental_constrained_pow: case Intrinsic::experimental_constrained_powi: Index: lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp =================================================================== --- lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp +++ lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp @@ -780,8 +780,12 @@ SelectVOP3Mods(N->getOperand(3), Ops[5], Ops[4]); Ops[8] = N->getOperand(0); Ops[9] = N->getOperand(4); - - CurDAG->SelectNodeTo(N, AMDGPU::V_FMA_F32, N->getVTList(), Ops); + assert((N->getValueType(0) == MVT::f32 || N->getValueType(0) == MVT::f64) && + "Incorrent Value Type!"); + unsigned TargetOpc = N->getValueType(0) == MVT::f32 ?
+ AMDGPU::V_FMA_F32 : + AMDGPU::V_FMA_F64; + CurDAG->SelectNodeTo(N, TargetOpc, N->getVTList(), Ops); } void AMDGPUDAGToDAGISel::SelectFMUL_W_CHAIN(SDNode *N) { Index: lib/Target/AMDGPU/SIISelLowering.h =================================================================== --- lib/Target/AMDGPU/SIISelLowering.h +++ lib/Target/AMDGPU/SIISelLowering.h @@ -54,6 +54,7 @@ SDValue LowerFDIV32(SDValue Op, SelectionDAG &DAG) const; SDValue LowerFDIV64(SDValue Op, SelectionDAG &DAG) const; SDValue LowerFDIV(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerConstrainedFMA(SDValue Op, SelectionDAG &DAG) const; SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG, bool Signed) const; SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerTrig(SDValue Op, SelectionDAG &DAG) const; Index: lib/Target/AMDGPU/SIISelLowering.cpp =================================================================== --- lib/Target/AMDGPU/SIISelLowering.cpp +++ lib/Target/AMDGPU/SIISelLowering.cpp @@ -316,6 +316,12 @@ setOperationAction(ISD::FDIV, MVT::f32, Custom); setOperationAction(ISD::FDIV, MVT::f64, Custom); + //setOperationAction(ISD::FMA, MVT::f32, Custom); + //setOperationAction(ISD::FMA, MVT::f64, Custom); + + setOperationAction(ISD::STRICT_FMA, MVT::f32, Custom); + setOperationAction(ISD::STRICT_FMA, MVT::f64, Custom); + if (Subtarget->has16BitInsts()) { setOperationAction(ISD::Constant, MVT::i16, Legal); @@ -3168,6 +3174,7 @@ return LowerTrig(Op, DAG); case ISD::SELECT: return LowerSELECT(Op, DAG); case ISD::FDIV: return LowerFDIV(Op, DAG); + case ISD::STRICT_FMA: return LowerConstrainedFMA(Op, DAG); case ISD::ATOMIC_CMP_SWAP: return LowerATOMIC_CMP_SWAP(Op, DAG); case ISD::STORE: return LowerSTORE(Op, DAG); case ISD::GlobalAddress: { @@ -4657,7 +4664,7 @@ return DAG.getNode(Opcode, SL, VT, A, B, C); } - assert(GlueChain->getNumValues() == 3); + assert(GlueChain->getNumValues() == 3 || GlueChain->getNumValues() == 2); SDVTList VTList = DAG.getVTList(VT, MVT::Other, MVT::Glue); 
switch (Opcode) { @@ -4667,8 +4674,12 @@ break; } - return DAG.getNode(Opcode, SL, VTList, GlueChain.getValue(1), A, B, C, - GlueChain.getValue(2)); + if (GlueChain->getNumValues() == 3) + return DAG.getNode(Opcode, SL, VTList, GlueChain.getValue(1), A, B, C, + GlueChain.getValue(2)); + // Otherwise two results (see the assert): value 0 is chain, value 1 is glue. + return DAG.getNode(Opcode, SL, VTList, GlueChain.getValue(0), A, B, C, + GlueChain.getValue(1)); } SDValue SITargetLowering::LowerFDIV16(SDValue Op, SelectionDAG &DAG) const { @@ -4890,6 +4901,50 @@ llvm_unreachable("Unexpected type for fdiv"); } +/// Custom-lower STRICT_FMA as a glued SETREG, FMA, SETREG node sequence. +SDValue SITargetLowering::LowerConstrainedFMA(SDValue Op, SelectionDAG &DAG) const { + SDLoc SL(Op); + + // Retrieve FP Rounding Mode. + bool RoundMode = Op->getFlags().hasRoundDynamic(); + // TODO: Based on retrieved FP RoundMode to set up register modes. + (void)RoundMode; + const unsigned Denorm32Reg = AMDGPU::Hwreg::ID_MODE | + (2 << AMDGPU::Hwreg::OFFSET_SHIFT_) | + (1 << AMDGPU::Hwreg::WIDTH_M1_SHIFT_); + + const SDValue BitField = DAG.getTargetConstant(Denorm32Reg, SL, MVT::i16); + + SDVTList BindParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); + const SDValue EnableDenormValue = DAG.getConstant(FP_DENORM_FLUSH_NONE, + SL, MVT::i32); + + SDValue EnableDenorm = DAG.getNode(AMDGPUISD::SETREG, SL, BindParamVTs, + DAG.getEntryNode(), + EnableDenormValue, BitField); + + // STRICT_FMA is Custom for f32 and f64, so use the node's own value type. + SDValue FMA = getFPTernOp(DAG, ISD::FMA, SL, Op.getValueType(), + Op.getOperand(0), Op.getOperand(1), + Op.getOperand(2), EnableDenorm); + + // NOTE(review): same value as the enable above; mode is not restored. + const SDValue DisableDenormValue = DAG.getConstant(FP_DENORM_FLUSH_NONE, + SL, MVT::i32); + + SDValue DisableDenorm = DAG.getNode(AMDGPUISD::SETREG, SL, BindParamVTs, + FMA.getValue(1), + DisableDenormValue, + BitField, + FMA.getValue(2)); + + SDValue OutputChain = DAG.getNode(ISD::TokenFactor, SL, MVT::Other, + DisableDenorm, DAG.getRoot()); + DAG.setRoot(OutputChain); + + return FMA; +} + SDValue SITargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
SDLoc DL(Op); StoreSDNode *Store = cast<StoreSDNode>(Op); Index: test/CodeGen/AMDGPU/constrained_fp.ll =================================================================== --- /dev/null +++ test/CodeGen/AMDGPU/constrained_fp.ll @@ -6,8 +6,7 @@ ; FUNC: s_setreg_b32 ; FUNC: v_fma_f64 ; FUNC: s_setreg_b32 -define amdgpu_kernel void @fma_f64(double addrspace(1)* %out, double addrspace(1)* %in1, - double addrspace(1)* %in2, double addrspace(1)* %in3) { +define amdgpu_kernel void @fma_f64(double addrspace(1)* %out, double addrspace(1)* %in1, double addrspace(1)* %in2, double addrspace(1)* %in3) { %r0 = load double, double addrspace(1)* %in1 %r1 = load double, double addrspace(1)* %in2 %r2 = load double, double addrspace(1)* %in3