diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1042,6 +1042,48 @@
     return OpActions[(unsigned)VT.getSimpleVT().SimpleTy][Op];
   }
 
+  /// Some cast operations may be natively supported by the target but only for
+  /// specific \p SrcVT. This method allows for checking both the \p SrcVT and
+  /// \p DestVT for a given operation.
+  ///
+  /// The base action is looked up on \p SrcVT for FP_TO_FP16, SINT_TO_FP and
+  /// UINT_TO_FP, and on \p DestVT for FP_ROUND, FP_EXTEND, FP_TO_UINT and
+  /// FP_TO_SINT. If that action is Legal or Custom but
+  /// isSupportedCastOperation() rejects the \p SrcVT / \p DestVT pair, LibCall
+  /// is returned instead. Any other \p Op hits llvm_unreachable.
+  LegalizeAction getCastOperationAction(unsigned Op, EVT SrcVT,
+                                        EVT DestVT) const {
+    LegalizeAction Action = Legal;
+    switch (Op) {
+    default:
+      llvm_unreachable("Unexpected cast operation.");
+    case ISD::FP_TO_FP16:
+    case ISD::SINT_TO_FP:
+    case ISD::UINT_TO_FP:
+      Action = getOperationAction(Op, SrcVT);
+      break;
+    case ISD::FP_ROUND:
+    case ISD::FP_EXTEND:
+    case ISD::FP_TO_UINT:
+    case ISD::FP_TO_SINT:
+      Action = getOperationAction(Op, DestVT);
+      break;
+    }
+    // Only Legal/Custom actions are subject to the target's per-type-pair veto.
+    if (Action != Legal && Action != Custom)
+      return Action;
+    return isSupportedCastOperation(Op, SrcVT, DestVT) ? Action : LibCall;
+  }
+
+  /// Custom method defined by each target to indicate if a cast operation
+  /// that casts from \p SrcVT to \p DestVT is supported natively by the
+  /// target. If not, getCastOperationAction() reports the operation as
+  /// LibCall instead of Legal/Custom. The default accepts every pair.
+  virtual bool isSupportedCastOperation(unsigned Op, EVT SrcVT,
+                                        EVT DestVT) const {
+    return true;
+  }
+
   /// Custom method defined by each target to indicate if an operation which
   /// may require a scale is supported natively by the target.
   /// If not, the operation is illegal.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -998,8 +998,19 @@
     Action = TLI.getOperationAction(Node->getOpcode(), MVT::Other);
     break;
   case ISD::FP_TO_FP16:
+  case ISD::FP_ROUND:
+  case ISD::FP_EXTEND:
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
+  case ISD::FP_TO_UINT:
+  case ISD::FP_TO_SINT: {
+    // Casts are legalized on the (source type, result type) pair so a target
+    // can reject individual combinations via isSupportedCastOperation().
+    const EVT FromVT = Node->getOperand(0).getValueType();
+    const EVT ToVT = Node->getValueType(0);
+    Action = TLI.getCastOperationAction(Node->getOpcode(), FromVT, ToVT);
+    break;
+  }
   case ISD::EXTRACT_VECTOR_ELT:
   case ISD::LROUND:
   case ISD::LLROUND: