diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1148,8 +1148,13 @@
   // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
   // to transition between unpacked and packed types of the same element type,
   // with BITCAST used otherwise.
+  // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
+  // Returns a safe bitcast between two scalable vector predicates, where
+  // any newly created lanes from a widening bitcast are defined as zero.
+  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1082,6 +1082,14 @@
     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
   }
 
+  // FIXME: Move lowering for more nodes here if those are common between
+  // SVE and SME.
+  if (Subtarget->hasSVE() || Subtarget->hasSME()) {
+    for (auto VT :
+         {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1})
+      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+  }
+
   if (Subtarget->hasSVE()) {
     for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -1162,7 +1170,6 @@
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SETCC, VT, Custom);
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::TRUNCATE, VT, Custom);
      setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -4333,27 +4340,47 @@
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
+                                                      SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  EVT OutVT = Op.getValueType();
-  SDValue InOp = Op.getOperand(1);
-  EVT InVT = InOp.getValueType();
+  EVT InVT = Op.getValueType();
+
+  assert(InVT.getVectorElementType() == MVT::i1 &&
+         VT.getVectorElementType() == MVT::i1 &&
+         "Expected a predicate-to-predicate bitcast");
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
-  // i.e. <vscale x 16 x i1> -> <vscale x 16 x i1>
-  if (InVT == OutVT)
-    return InOp;
+  // e.g. <vscale x 16 x i1> -> <vscale x 16 x i1>
+  if (InVT == VT)
+    return Op;
+
+  SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
 
-  SDValue Reinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+  // We only have to zero the lanes if new lanes are being defined, e.g. when
+  // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
+  // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
+  // we can return here.
+  if (InVT.bitsGT(VT))
+    return Reinterpret;
 
-  // If the argument converted to an svbool is a ptrue or a comparison, the
-  // lanes introduced by the widening are zero by construction.
-  switch (InOp.getOpcode()) {
+  // Check if the other lanes are already known to be zeroed by
+  // construction.
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return Reinterpret;
+    break;
   case AArch64ISD::SETCC_MERGE_ZERO:
     return Reinterpret;
   case ISD::INTRINSIC_WO_CHAIN:
-    switch (InOp.getConstantOperandVal(0)) {
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      break;
     case Intrinsic::aarch64_sve_ptrue:
     case Intrinsic::aarch64_sve_cmpeq_wide:
     case Intrinsic::aarch64_sve_cmpne_wide:
@@ -4369,15 +4396,10 @@
     }
   }
 
-  // Splat vectors of one will generate ptrue instructions
-  if (ISD::isConstantSplatVectorAllOnes(InOp.getNode()))
-    return Reinterpret;
-
-  // Otherwise, zero the newly introduced lanes.
-  SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
-  SDValue MaskReinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
-  return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+  // Zero the newly introduced lanes.
+  SDValue Mask = DAG.getConstant(1, DL, InVT);
+  Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
+  return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
@@ -4546,10 +4568,9 @@
   case Intrinsic::aarch64_sve_dupq_lane:
     return LowerDUPQLane(Op, DAG);
   case Intrinsic::aarch64_sve_convert_from_svbool:
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
-                       Op.getOperand(1));
+    return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_convert_to_svbool:
-    return lowerConvertToSVBool(Op, DAG);
+    return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_fneg:
     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -21464,22 +21485,17 @@
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  (void)TLI;
 
-  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
-         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable vector types!");
-  assert((VT.getVectorElementType() == MVT::i1) ==
-         (InVT.getVectorElementType() == MVT::i1) &&
-         "Cannot cast between data and predicate scalable vector types!");
+  assert(VT.getVectorElementType() != MVT::i1 &&
+         InVT.getVectorElementType() != MVT::i1 &&
+         "For predicate bitcasts, use getSVEPredicateBitCast");
 
   if (InVT == VT)
     return Op;
 
-  if (VT.getVectorElementType() == MVT::i1)
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
-
   EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
   EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
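
Note: the snippet below is an illustrative IR sketch, not taken from this patch's tests; the function names to_svbool/from_svbool are made up. It shows the two predicate conversions now routed through getSVEPredicateBitCast: the widening direction must define the lanes introduced by the cast as zero unless the operand is already known to have zeroed them (a ptrue, a SETCC_MERGE_ZERO compare, or a constant all-ones splat), while the narrowing direction defines no new lanes and returns the REINTERPRET_CAST unmasked.

; Widening: <vscale x 4 x i1> -> <vscale x 16 x i1>. If %p is not known to have
; its extra lanes zeroed, the lowering ANDs the REINTERPRET_CAST with a mask
; built from a constant-1 splat of <vscale x 4 x i1> (previously a ptrue).
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)

define <vscale x 16 x i1> @to_svbool(<vscale x 4 x i1> %p) {
  %r = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %p)
  ret <vscale x 16 x i1> %r
}

; Narrowing: <vscale x 16 x i1> -> <vscale x 4 x i1>. No new lanes are defined,
; so getSVEPredicateBitCast returns the REINTERPRET_CAST directly.
declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)

define <vscale x 4 x i1> @from_svbool(<vscale x 16 x i1> %p) {
  %r = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %p)
  ret <vscale x 4 x i1> %r
}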