diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -776,17 +776,17 @@ FP_TO_UINT, /// FP_TO_[US]INT_SAT - Convert floating point value in operand 0 to a - /// signed or unsigned integer type with the bit width given in operand 1 with - /// the following semantics: + /// signed or unsigned scalar integer type given in operand 1 with the + /// following semantics: /// /// * If the value is NaN, zero is returned. /// * If the value is larger/smaller than the largest/smallest integer, /// the largest/smallest integer is returned (saturation). /// * Otherwise the result of rounding the value towards zero is returned. /// - /// The width given in operand 1 must be equal to, or smaller than, the scalar - /// result type width. It may end up being smaller than the result witdh as a - /// result of integer type legalization. + /// The scalar width of the type given in operand 1 must be equal to, or + /// smaller than, the scalar result type width. It may end up being smaller + /// than the result width as a result of integer type legalization. 
FP_TO_SINT_SAT, FP_TO_UINT_SAT, diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -165,7 +165,7 @@ SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1> ]>; def SDTFPToIntSatOp : SDTypeProfile<1, 2, [ // fp_to_[su]int_sat - SDTCisInt<0>, SDTCisFP<1>, SDTCisInt<2>, SDTCisSameNumEltsAs<0, 1> + SDTCisInt<0>, SDTCisFP<1>, SDTCisSameNumEltsAs<0, 1>, SDTCisVT<2, OtherVT> ]>; def SDTExtInreg : SDTypeProfile<1, 2, [ // sext_inreg SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisVT<2, OtherVT>, diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -1547,9 +1547,6 @@ case ISD::FPOWI: Res = PromoteIntOp_FPOWI(N); break; - case ISD::FP_TO_SINT_SAT: - case ISD::FP_TO_UINT_SAT: PromoteIntOp_FP_TO_XINT_SAT(N); break; - case ISD::VECREDUCE_ADD: case ISD::VECREDUCE_MUL: case ISD::VECREDUCE_AND: @@ -1970,12 +1967,6 @@ DAG.UpdateNodeOperands(N, N->getOperand(0), N->getOperand(1), Op2), 0); } -SDValue DAGTypeLegalizer::PromoteIntOp_FP_TO_XINT_SAT(SDNode *N) { - SDValue Op1 = ZExtPromotedInteger(N->getOperand(1)); - return SDValue( - DAG.UpdateNodeOperands(N, N->getOperand(0), Op1), 0); -} - SDValue DAGTypeLegalizer::PromoteIntOp_FRAMERETURNADDR(SDNode *N) { // Promote the RETURNADDR/FRAMEADDR argument to a supported integer width. 
SDValue Op = ZExtPromotedInteger(N->getOperand(0)); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -391,7 +391,6 @@ SDValue PromoteIntOp_FRAMERETURNADDR(SDNode *N); SDValue PromoteIntOp_PREFETCH(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_FIX(SDNode *N); - SDValue PromoteIntOp_FP_TO_XINT_SAT(SDNode *N); SDValue PromoteIntOp_FPOWI(SDNode *N); SDValue PromoteIntOp_VECREDUCE(SDNode *N); SDValue PromoteIntOp_SET_ROUNDING(SDNode *N); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -5642,6 +5642,23 @@ } break; } + case ISD::FP_TO_SINT_SAT: + case ISD::FP_TO_UINT_SAT: { + // N2 is a VTSDNode naming the scalar integer type to saturate to. + assert(VT.isInteger() && cast<VTSDNode>(N2)->getVT().isInteger() && + N1.getValueType().isFloatingPoint() && "Invalid FP_TO_*INT_SAT"); + assert(N1.getValueType().isVector() == VT.isVector() && + "FP_TO_*INT_SAT type should be vector iff the operand type is " + "vector!"); + assert((!VT.isVector() || VT.getVectorElementCount() == + N1.getValueType().getVectorElementCount()) && + "Vector element counts must match in FP_TO_*INT_SAT"); + assert(!cast<VTSDNode>(N2)->getVT().isVector() && + "Type to saturate to must be a scalar."); + assert(cast<VTSDNode>(N2)->getVT().bitsLE(VT.getScalarType()) && + "Saturation type must not be wider than the result type!"); + break; + } case ISD::EXTRACT_VECTOR_ELT: assert(VT.getSizeInBits() >= N1.getValueType().getScalarSizeInBits() && "The result of EXTRACT_VECTOR_ELT must be at least as wide as the \ diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -6315,17 +6315,17 @@ getValue(I.getArgOperand(0))))); return; case 
Intrinsic::fptosi_sat: { - EVT Type = TLI.getValueType(DAG.getDataLayout(), I.getType()); - SDValue SatW = DAG.getConstant(Type.getScalarSizeInBits(), sdl, MVT::i32); - setValue(&I, DAG.getNode(ISD::FP_TO_SINT_SAT, sdl, Type, - getValue(I.getArgOperand(0)), SatW)); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + setValue(&I, DAG.getNode(ISD::FP_TO_SINT_SAT, sdl, VT, + getValue(I.getArgOperand(0)), + DAG.getValueType(VT.getScalarType()))); return; } case Intrinsic::fptoui_sat: { - EVT Type = TLI.getValueType(DAG.getDataLayout(), I.getType()); - SDValue SatW = DAG.getConstant(Type.getScalarSizeInBits(), sdl, MVT::i32); - setValue(&I, DAG.getNode(ISD::FP_TO_UINT_SAT, sdl, Type, - getValue(I.getArgOperand(0)), SatW)); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + setValue(&I, DAG.getNode(ISD::FP_TO_UINT_SAT, sdl, VT, + getValue(I.getArgOperand(0)), + DAG.getValueType(VT.getScalarType()))); return; } case Intrinsic::set_rounding: diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -8548,7 +8548,8 @@ EVT SrcVT = Src.getValueType(); EVT DstVT = Node->getValueType(0); - unsigned SatWidth = Node->getConstantOperandVal(1); + EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT(); + unsigned SatWidth = SatVT.getScalarSizeInBits(); unsigned DstWidth = DstVT.getScalarSizeInBits(); assert(SatWidth <= DstWidth && "Expected saturation width smaller than result width"); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -1973,12 +1973,13 @@ SelectionDAG &DAG) const { SDLoc DL(Op); EVT ResT = Op.getValueType(); - uint64_t Width = Op.getConstantOperandVal(1); + EVT SatVT = 
cast<VTSDNode>(Op.getOperand(1))->getVT(); - if ((ResT == MVT::i32 || ResT == MVT::i64) && (Width == 32 || Width == 64)) + if ((ResT == MVT::i32 || ResT == MVT::i64) && + (SatVT == MVT::i32 || SatVT == MVT::i64)) return Op; - if (ResT == MVT::v4i32 && Width == 32) + if (ResT == MVT::v4i32 && SatVT == MVT::i32) return Op; return SDValue(); @@ -2143,7 +2144,7 @@ auto FPToIntOp = FPToInt.getOpcode(); if (FPToIntOp != ISD::FP_TO_SINT_SAT && FPToIntOp != ISD::FP_TO_UINT_SAT) return SDValue(); - if (FPToInt.getConstantOperandVal(1) != 32) + if (cast<VTSDNode>(FPToInt.getOperand(1))->getVT() != MVT::i32) return SDValue(); auto Source = FPToInt.getOperand(0); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrConv.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrConv.td --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrConv.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrConv.td @@ -97,14 +97,14 @@ Requires<[HasNontrappingFPToInt]>; // Support the explicitly saturating operations as well. -def : Pat<(fp_to_sint_sat F32:$src, (i32 32)), (I32_TRUNC_S_SAT_F32 F32:$src)>; -def : Pat<(fp_to_uint_sat F32:$src, (i32 32)), (I32_TRUNC_U_SAT_F32 F32:$src)>; -def : Pat<(fp_to_sint_sat F64:$src, (i32 32)), (I32_TRUNC_S_SAT_F64 F64:$src)>; -def : Pat<(fp_to_uint_sat F64:$src, (i32 32)), (I32_TRUNC_U_SAT_F64 F64:$src)>; -def : Pat<(fp_to_sint_sat F32:$src, (i32 64)), (I64_TRUNC_S_SAT_F32 F32:$src)>; -def : Pat<(fp_to_uint_sat F32:$src, (i32 64)), (I64_TRUNC_U_SAT_F32 F32:$src)>; -def : Pat<(fp_to_sint_sat F64:$src, (i32 64)), (I64_TRUNC_S_SAT_F64 F64:$src)>; -def : Pat<(fp_to_uint_sat F64:$src, (i32 64)), (I64_TRUNC_U_SAT_F64 F64:$src)>; +def : Pat<(fp_to_sint_sat F32:$src, i32), (I32_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(fp_to_uint_sat F32:$src, i32), (I32_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(fp_to_sint_sat F64:$src, i32), (I32_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(fp_to_uint_sat F64:$src, i32), (I32_TRUNC_U_SAT_F64 F64:$src)>; +def : Pat<(fp_to_sint_sat F32:$src, i64), 
(I64_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(fp_to_uint_sat F32:$src, i64), (I64_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(fp_to_sint_sat F64:$src, i64), (I64_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(fp_to_uint_sat F64:$src, i64), (I64_TRUNC_U_SAT_F64 F64:$src)>; // Conversion from floating point to integer pseudo-instructions which don't // trap on overflow or invalid. diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1092,8 +1092,8 @@ defm "" : SIMDConvert; // Support the saturating variety as well. -def trunc_s_sat32 : PatFrag<(ops node:$x), (fp_to_sint_sat $x, (i32 32))>; -def trunc_u_sat32 : PatFrag<(ops node:$x), (fp_to_uint_sat $x, (i32 32))>; +def trunc_s_sat32 : PatFrag<(ops node:$x), (fp_to_sint_sat $x, i32)>; +def trunc_u_sat32 : PatFrag<(ops node:$x), (fp_to_uint_sat $x, i32)>; def : Pat<(v4i32 (trunc_s_sat32 (v4f32 V128:$src))), (fp_to_sint_I32x4 $src)>; def : Pat<(v4i32 (trunc_u_sat32 (v4f32 V128:$src))), (fp_to_uint_I32x4 $src)>; diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -21534,7 +21534,8 @@ if (!isScalarFPTypeInSSEReg(SrcVT)) return SDValue(); - unsigned SatWidth = Node->getConstantOperandVal(1); + EVT SatVT = cast<VTSDNode>(Node->getOperand(1))->getVT(); + unsigned SatWidth = SatVT.getScalarSizeInBits(); unsigned DstWidth = DstVT.getScalarSizeInBits(); unsigned TmpWidth = TmpVT.getScalarSizeInBits(); assert(SatWidth <= DstWidth && SatWidth <= TmpWidth &&