diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -137,7 +137,7 @@ } } -static inline MVT getPromotedVTForPredicate(MVT VT) { +static inline EVT getPromotedVTForPredicate(EVT VT) { assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) && "Expected scalable predicate vector type!"); switch (VT.getVectorMinNumElements()) { @@ -1030,10 +1030,8 @@ // There are no legal MVT::nxv16f## based types. if (VT != MVT::nxv16i1) { - setOperationAction(ISD::SINT_TO_FP, VT, Promote); - AddPromotedToType(ISD::SINT_TO_FP, VT, getPromotedVTForPredicate(VT)); - setOperationAction(ISD::UINT_TO_FP, VT, Promote); - AddPromotedToType(ISD::UINT_TO_FP, VT, getPromotedVTForPredicate(VT)); + setOperationAction(ISD::SINT_TO_FP, VT, Custom); + setOperationAction(ISD::UINT_TO_FP, VT, Custom); } } @@ -3086,11 +3084,20 @@ SDLoc dl(Op); SDValue In = Op.getOperand(0); EVT InVT = In.getValueType(); + unsigned Opc = Op.getOpcode(); + bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP; if (VT.isScalableVector()) { - unsigned Opcode = Op.getOpcode() == ISD::UINT_TO_FP - ? AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU - : AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU; + if (InVT.getVectorElementType() == MVT::i1) { + // An SVE predicate can't be converted directly; extend it first. + unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + EVT CastVT = getPromotedVTForPredicate(InVT); + In = DAG.getNode(CastOpc, dl, CastVT, In); + return DAG.getNode(Opc, dl, VT, In); + } + + unsigned Opcode = IsSigned ? 
AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU + : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU; return LowerToPredicatedOp(Op, DAG, Opcode); } @@ -3100,16 +3107,15 @@ MVT CastVT = MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()), InVT.getVectorNumElements()); - In = DAG.getNode(Op.getOpcode(), dl, CastVT, In); + In = DAG.getNode(Opc, dl, CastVT, In); return DAG.getNode(ISD::FP_ROUND, dl, VT, In, DAG.getIntPtrConstant(0, dl)); } if (VTSize > InVTSize) { - unsigned CastOpc = - Op.getOpcode() == ISD::SINT_TO_FP ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; + unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND; EVT CastVT = VT.changeVectorElementTypeToInteger(); In = DAG.getNode(CastOpc, dl, CastVT, In); - return DAG.getNode(Op.getOpcode(), dl, VT, In); + return DAG.getNode(Opc, dl, VT, In); } return Op;