Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12367,6 +12367,9 @@
   const TargetOptions &Options = DAG.getTarget().Options;
   const SDNodeFlags Flags = N->getFlags();
 
+  if (VT.isScalableVector())
+    return SDValue();
+
   if (SDValue R = DAG.simplifyFPBinop(N->getOpcode(), N0, N1, Flags))
     return R;
@@ -18898,6 +18901,12 @@
   EVT NVT = N->getValueType(0);
   SDValue V = N->getOperand(0);
 
+  // Don't combine if we're extracting a fixed-width vector from
+  // a scalable vector.
+  if (V.getValueType().isScalableVector() &&
+      !NVT.isScalableVector())
+    return SDValue();
+
   // Extract from UNDEF is UNDEF.
   if (V.isUndef())
     return DAG.getUNDEF(NVT);
@@ -20033,6 +20042,11 @@
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
 
+  // Don't combine fixed->scalable vectors.
+  if (VT.isScalableVector() &&
+      !N1->getValueType(0).isScalableVector())
+    return SDValue();
+
   // If inserting an UNDEF, just return the original vector.
   if (N1.isUndef())
     return N0;
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2521,6 +2521,9 @@
   if (Depth >= MaxRecursionDepth)
     return Known;  // Limit search depth.
 
+  if (Op.getValueType().isScalableVector())
+    return Known; // Unknown number of elts so assume we don't know anything.
+
   KnownBits Known2;
   unsigned NumElts = DemandedElts.getBitWidth();
   assert((!Op.getValueType().isVector() ||
@@ -4718,7 +4721,7 @@
   case ISD::BITCAST:
     // Basic sanity checking.
     assert(VT.getSizeInBits() == Operand.getValueSizeInBits() &&
-           "Cannot BITCAST between types of different sizes!");
+           "Cannot BITCAST between types of different sizes!");
     if (VT == Operand.getValueType()) return Operand;  // noop conversion.
     if (OpOpcode == ISD::BITCAST)  // bitconv(bitconv(x)) -> bitconv(x)
       return getNode(ISD::BITCAST, DL, VT, Operand.getOperand(0));
@@ -5441,7 +5444,7 @@
   assert(VT.getSimpleVT() <= N1.getSimpleValueType() &&
          "Extract subvector must be from larger vector to smaller vector!");
 
-  if (N2C) {
+  if (N2C && !N1.getValueType().isScalableVector()) {
     assert((VT.getVectorNumElements() + N2C->getZExtValue()
             <= N1.getValueType().getVectorNumElements()) &&
            "Extract subvector overflow!");
@@ -5647,7 +5650,9 @@
          "Dest and insert subvector source types must match!");
   assert(N2.getSimpleValueType() <= N1.getSimpleValueType() &&
          "Insert subvector must be from smaller vector to larger vector!");
-  if (isa<ConstantSDNode>(Index)) {
+  if (isa<ConstantSDNode>(Index) &&
+      !N1.getValueType().isScalableVector() &&
+      !N2.getValueType().isScalableVector()) {
     assert((N2.getValueType().getVectorNumElements() +
             cast<ConstantSDNode>(Index)->getZExtValue()
             <= VT.getVectorNumElements())
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -3602,11 +3602,14 @@
              NodeToMatch->getValueType(i) != MVT::Other &&
              NodeToMatch->getValueType(i) != MVT::Glue &&
              "Invalid number of results to complete!");
+      // FIXME: What do we do about this assert for scalable<->non-scalable
+      // replacements, without giving up completely?
       assert((NodeToMatch->getValueType(i) == Res.getValueType() ||
               NodeToMatch->getValueType(i) == MVT::iPTR ||
               Res.getValueType() == MVT::iPTR ||
               NodeToMatch->getValueType(i).getSizeInBits() ==
-                  Res.getValueSizeInBits()) &&
+                  Res.getValueSizeInBits() ||
+              NodeToMatch->getValueType(i).isScalableVector()) &&
              "invalid replacement");
       ReplaceUses(SDValue(NodeToMatch, i), Res);
     }
Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -941,6 +941,8 @@
   }
   case ISD::INSERT_SUBVECTOR: {
     SDValue Base = Op.getOperand(0);
+    if (Base.getValueType().isScalableVector())
+      break;
     SDValue Sub = Op.getOperand(1);
     EVT SubVT = Sub.getValueType();
     unsigned NumSubElts = SubVT.getVectorNumElements();
@@ -997,6 +999,8 @@
   case ISD::EXTRACT_SUBVECTOR: {
     // If index isn't constant, assume we need all the source vector elements.
     SDValue Src = Op.getOperand(0);
+    if (Src.getValueType().isScalableVector())
+      break;
     ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     APInt SrcElts = APInt::getAllOnesValue(NumSrcElts);
@@ -2420,6 +2424,8 @@
     if (!isa<ConstantSDNode>(Op.getOperand(2)))
       break;
     SDValue Base = Op.getOperand(0);
+    if (Base.getValueType().isScalableVector())
+      break;
     SDValue Sub = Op.getOperand(1);
     EVT SubVT = Sub.getValueType();
     unsigned NumSubElts = SubVT.getVectorNumElements();
@@ -2452,6 +2458,8 @@
   }
   case ISD::EXTRACT_SUBVECTOR: {
     SDValue Src = Op.getOperand(0);
+    if (Src.getValueType().isScalableVector())
+      break;
     ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(Op.getOperand(1));
     unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
     if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) {
Index: llvm/lib/Target/AArch64/AArch64CallingConvention.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -75,10 +75,10 @@
   CCIfConsecutiveRegs<CC_AArch64_Custom_Block>,
 
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2f32, nxv4f32, v16f32, nxv2f64],
            CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2f32, nxv4f32, v16f32, nxv2f64],
            CCPassIndirect<i64>>,
 
   CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
@@ -155,7 +155,7 @@
            CCAssignToReg<[Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7]>>,
 
   CCIfType<[nxv16i8, nxv8i16, nxv4i32, nxv2i64, nxv2f16, nxv4f16, nxv8f16,
-            nxv2f32, nxv4f32, nxv2f64],
+            nxv2f32, nxv4f32, v16f32, nxv2f64],
            CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
 
   CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -4549,6 +4549,58 @@
     }
     break;
   }
+  case ISD::EXTRACT_SUBVECTOR: {
+    // Bail and use the default Select() if the index is not constant zero.
+    auto *LaneNode = dyn_cast<ConstantSDNode>(Node->getOperand(1));
+    if (!LaneNode || (LaneNode->getZExtValue() != 0))
+      break;
+
+    EVT ContainerVT = Node->getOperand(0).getValueType();
+
+    // Bail when normal isel can do the job.
+    if (VT.isScalableVector() || !ContainerVT.isScalableVector())
+      break;
+    // WARNING: Breaks for reasons unknown!
+    //if ((VT.getSizeInBits() == 64) || (VT.getSizeInBits() == 128))
+    //  break;
+
+    auto RC = (VT.getVectorElementType() == MVT::i1)
+      ? CurDAG->getTargetConstant(AArch64::PPRRegClassID, SDLoc(Node), MVT::i64)
+      : CurDAG->getTargetConstant(AArch64::ZPRRegClassID, SDLoc(Node), MVT::i64);
+    auto COPY = CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, SDLoc(Node),
+                                       VT, Node->getOperand(0), RC);
+    COPY->dump(CurDAG);
+    ReplaceNode(Node, COPY);
+    return;
+  }
+  case ISD::INSERT_SUBVECTOR: {
+    // Bail and use the default Select() if the index is not constant zero.
+    auto *LaneNode = dyn_cast<ConstantSDNode>(Node->getOperand(2));
+    if (!LaneNode || (LaneNode->getZExtValue() != 0))
+      break;
+
+    EVT InsertVT = Node->getOperand(1).getValueType();
+
+    // Bail when normal isel can do the job.
+    if (!VT.isScalableVector() || InsertVT.isScalableVector())
+      break;
+    // WARNING: Equivalent patterns to EXTRACT_SUBVECTOR are not available.
+    //if ((InsertVT.getSizeInBits() == 64) || (InsertVT.getSizeInBits() == 128))
+    //  break;
+
+    // Bail when inserting into real data. (HACK)
+    if (!Node->getOperand(0).isUndef())
+      break;
+
+    auto RC = (VT.getVectorElementType() == MVT::i1)
+      ? CurDAG->getTargetConstant(AArch64::PPRRegClassID, SDLoc(Node), MVT::i64)
+      : CurDAG->getTargetConstant(AArch64::ZPRRegClassID, SDLoc(Node), MVT::i64);
+    auto COPY = CurDAG->getMachineNode(TargetOpcode::COPY_TO_REGCLASS, SDLoc(Node),
+                                       VT, Node->getOperand(1), RC);
+
+    ReplaceNode(Node, COPY);
+    return;
+  }
   }
 
   // Select the default instruction
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -112,6 +112,9 @@
     cl::desc("Allow AArch64 Local Dynamic TLS code generation"),
     cl::init(false));
 
+static cl::opt<unsigned>
+SVEVectorBits("aarch64-isel-sve-vector-bits", cl::ReallyHidden, cl::init(0));
+
 static cl::opt<bool>
 EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
                          cl::desc("Enable AArch64 logical imm instruction "
@@ -183,6 +186,17 @@
     addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
     addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
 
+    // Fixed-width vector MVTs are legal if they fit within an SVE register.
+    if (SVEVectorBits != 0) {
+      assert(SVEVectorBits == 512 &&
+             "Unsupported SVE fixed-width vector size.");
+      for (MVT VT : MVT::vector_valuetypes())
+        if (!VT.isScalableVector())
+          if (VT.getSizeInBits() <= SVEVectorBits)
+            if (VT.getVectorElementType() != MVT::i1)
+              addRegisterClass(VT, &AArch64::ZPRRegClass);
+    }
+
     for (auto VT : { MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64 }) {
       setOperationAction(ISD::SADDSAT, VT, Legal);
       setOperationAction(ISD::UADDSAT, VT, Legal);
@@ -212,6 +226,30 @@
       setCondCodeAction(ISD::SETUEQ, VT, Expand);
       setCondCodeAction(ISD::SETUNE, VT, Expand);
     }
+
+    // TODO: addTypeForFixedWidthSVE?
+    // No operations are legal for fixed-width types when targeting SVE.
+    // We custom lower everything in order to map the legal fixed-width types to
+    // SVE's width agnostic types.
+    if (Subtarget->hasSVE() && (SVEVectorBits != 0)) {
+      // Legalisation based on floating point vector types.
+      for (MVT VT : { MVT::v16f32 }) {
+        if (isTypeLegal(VT) && !VT.isScalableVector()) {
+          setOperationAction(ISD::LOAD, VT, Custom);
+          setOperationAction(ISD::STORE, VT, Custom);
+
+          setOperationAction(ISD::MLOAD, VT, Expand);
+          setOperationAction(ISD::MSTORE, VT, Expand);
+
+          setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Legal);
+          setOperationAction(ISD::INSERT_SUBVECTOR, VT, Legal);
+
+          setOperationAction(ISD::FADD, VT, Custom);
+        }
+      }
+    }
   }
 
   // Compute derived properties from the register classes
@@ -3253,11 +3291,155 @@
   return SDValue();
 }
+static bool isFixedWidthVectorType(EVT VT) {
+  if (VT.is256BitVector() ||
+      VT.is512BitVector() ||
+      VT.is1024BitVector() ||
+      VT.is2048BitVector())
+    return true;
+
+  return false;
+}
+
+static EVT getSVEEquivalentType(EVT VT) {
+  switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
+  default: llvm_unreachable("unimplemented operand");
+  case MVT::i8:  return EVT(MVT::nxv16i8);
+  case MVT::i16: return EVT(MVT::nxv8i16);
+  case MVT::i32: return EVT(MVT::nxv4i32);
+  case MVT::i64: return EVT(MVT::nxv2i64);
+  case MVT::f16: return EVT(MVT::nxv8f16);
+  case MVT::f32: return EVT(MVT::nxv4f32);
+  case MVT::f64: return EVT(MVT::nxv2f64);
+  }
+}
+
+SDValue getSVEPredicate(EVT VT, const SDLoc &DL, SelectionDAG &DAG) {
+  assert(VT.isVector() && !VT.isScalableVector());
+
+  uint64_t PgPattern;
+  switch (VT.getVectorNumElements()) {
+  default: llvm_unreachable("unimplemented operand");
+  case 32: PgPattern = AArch64SVEPredPattern::vl32; break;
+  case 16: PgPattern = AArch64SVEPredPattern::vl16; break;
+  case 8:  PgPattern = AArch64SVEPredPattern::vl8;  break;
+  case 4:  PgPattern = AArch64SVEPredPattern::vl4;  break;
+  case 2:  PgPattern = AArch64SVEPredPattern::vl2;  break;
+  case 1:  PgPattern = AArch64SVEPredPattern::vl1;  break;
+  }
+
+  if (VT.getSizeInBits() == SVEVectorBits)
+    PgPattern = AArch64SVEPredPattern::all;
+
+  MVT MaskVT;
+  switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
+  default: llvm_unreachable("unimplemented operand");
+  case MVT::i8:  MaskVT = MVT::nxv16i1; break;
+  case MVT::i16: MaskVT = MVT::nxv8i1;  break;
+  case MVT::i32: MaskVT = MVT::nxv4i1;  break;
+  case MVT::i64: MaskVT = MVT::nxv2i1;  break;
+  case MVT::f16: MaskVT = MVT::nxv8i1;  break;
+  case MVT::f32: MaskVT = MVT::nxv4i1;  break;
+  case MVT::f64: MaskVT = MVT::nxv2i1;  break;
+  }
+
+  return DAG.getNode(AArch64ISD::PTRUE, DL, MaskVT,
+                     DAG.getConstant(PgPattern, DL, MVT::i64));
+}
+
+static EVT getSVEScalableContainerType(EVT VT, SelectionDAG &DAG) {
+  assert(VT.isVector() && !VT.isScalableVector());
+  assert(DAG.getTargetLoweringInfo().isTypeLegal(VT));
+
+  switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
+  default: llvm_unreachable("unimplemented operand");
+  case MVT::i8:  return EVT(MVT::nxv16i8);
+  case MVT::i16: return EVT(MVT::nxv8i16);
+  case MVT::i32: return EVT(MVT::nxv4i32);
+  case MVT::i64: return EVT(MVT::nxv2i64);
+  case MVT::f16: return EVT(MVT::nxv8f16);
+  case MVT::f32: return EVT(MVT::nxv4f32);
+  case MVT::f64: return EVT(MVT::nxv2f64);
+  }
+}
+
+static SDValue LowerFixedWidthOperation(SDValue Op,
+                                        SelectionDAG &DAG) {
+  SDLoc DL(Op);
+
+  switch (Op.getOpcode()) {
+  case ISD::FADD: {
+    EVT ContainerVT = getSVEScalableContainerType(Op.getValueType(), DAG);
+    auto Op0 = Op->getOperand(0);
+    auto Op1 = Op->getOperand(1);
+    SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+    Op0 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
+                      DAG.getUNDEF(ContainerVT), Op0, Zero);
+    Op1 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
+                      DAG.getUNDEF(ContainerVT), Op1, Zero);
+    auto Res = DAG.getNode(ISD::FADD, DL, ContainerVT, Op0, Op1);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Op.getValueType(), Res, Zero);
+  }
+  case ISD::LOAD: {
+    auto *Load = cast<LoadSDNode>(Op);
+    EVT ContainerVT = getSVEScalableContainerType(Op.getValueType(), DAG);
+
+    // NOTE: Not sure how important this is.
+    EVT MemVT = Load->getMemoryVT();
+    auto NewLoad = DAG.getMaskedLoad(ContainerVT, DL,
+                                     Load->getChain(),
+                                     Load->getBasePtr(),
+                                     DAG.getUNDEF(ContainerVT),
+                                     getSVEPredicate(Op.getValueType(), DL, DAG),
+                                     DAG.getUNDEF(ContainerVT),
+                                     MemVT,
+                                     Load->getMemOperand(),
+                                     ISD::UNINDEXED,
+                                     Load->getExtensionType());
+
+    SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+    auto Result = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL,
+                              Op.getValueType(), NewLoad, Zero);
+
+    SDValue MergedValues[2] = { Result, Load->getChain() };
+    return DAG.getMergeValues(MergedValues, DL);
+  }
+  case ISD::STORE: {
+    auto *Store = cast<StoreSDNode>(Op);
+    EVT ContainerVT =
+        getSVEScalableContainerType(Store->getValue().getValueType(), DAG);
+
+    // TODO: As with the loads, I don't know how important this is.
+    auto MemVT = Store->getMemoryVT();//.changeVectorElementTypeToInteger().getVectorElementType();
+
+    SDValue ZeroIdx = DAG.getConstant(0, DL, MVT::i64);
+    auto NewValue = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT,
+                                DAG.getUNDEF(ContainerVT), Store->getValue(),
+                                ZeroIdx);
+    auto NewStore = DAG.getMaskedStore(Store->getChain(), DL,
+                                       NewValue,
+                                       Store->getBasePtr(),
+                                       DAG.getUNDEF(ContainerVT),
+                                       getSVEPredicate(Store->getValue().getValueType(), DL, DAG),
+                                       MemVT,
+                                       Store->getMemOperand(),
+                                       ISD::UNINDEXED,
+                                       Store->isTruncatingStore());
+
+    return NewStore;
+  }
+  }
+
+  return SDValue();
+}
+
 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
                                               SelectionDAG &DAG) const {
   LLVM_DEBUG(dbgs() << "Custom lowering: ");
   LLVM_DEBUG(Op.dump());
 
+  if (isFixedWidthVectorType(Op->getValueType(0)))
+    if (auto Fixed = LowerFixedWidthOperation(Op, DAG))
+      return Fixed;
+
   switch (Op.getOpcode()) {
   default:
     llvm_unreachable("unimplemented operand");
@@ -3528,6 +3710,8 @@
         RC = &AArch64::PPRRegClass;
       else if (RegVT.isScalableVector())
         RC = &AArch64::ZPRRegClass;
+      else if (RegVT.is512BitVector())
+        RC = &AArch64::ZPRRegClass;
       else
         llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");
Index: llvm/lib/Target/AArch64/AArch64RegisterInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -869,7 +869,7 @@
 class ZPRClass<int lastreg> : RegisterClass<"AArch64",
                                             [nxv16i8, nxv8i16, nxv4i32, nxv2i64,
                                              nxv2f16, nxv4f16, nxv8f16,
-                                             nxv2f32, nxv4f32,
+                                             nxv2f32, nxv4f32, v16f32,
                                              nxv2f64],
                                             128, (sequence "Z%u", 0, lastreg)> {
   let Size = 128;
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -347,17 +347,17 @@
   // Duplicate FP scalar into all vector elements
   def : Pat<(nxv8f16 (AArch64dup (f16 FPR16:$src))),
-            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+            (DUP_ZZI_H (INSERT_SUBREG (nxv8f16 (IMPLICIT_DEF)), FPR16:$src, hsub), 0)>;
   def : Pat<(nxv4f16 (AArch64dup (f16 FPR16:$src))),
-            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+            (DUP_ZZI_H (INSERT_SUBREG (nxv4f16 (IMPLICIT_DEF)), FPR16:$src, hsub), 0)>;
   def : Pat<(nxv2f16 (AArch64dup (f16 FPR16:$src))),
-            (DUP_ZZI_H (INSERT_SUBREG (IMPLICIT_DEF), FPR16:$src, hsub), 0)>;
+            (DUP_ZZI_H (INSERT_SUBREG (nxv2f16 (IMPLICIT_DEF)), FPR16:$src, hsub), 0)>;
   def : Pat<(nxv4f32 (AArch64dup (f32 FPR32:$src))),
-            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+            (DUP_ZZI_S (INSERT_SUBREG (nxv4f32 (IMPLICIT_DEF)), FPR32:$src, ssub), 0)>;
   def : Pat<(nxv2f32 (AArch64dup (f32 FPR32:$src))),
-            (DUP_ZZI_S (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$src, ssub), 0)>;
+            (DUP_ZZI_S (INSERT_SUBREG (nxv2f32 (IMPLICIT_DEF)), FPR32:$src, ssub), 0)>;
   def : Pat<(nxv2f64 (AArch64dup (f64 FPR64:$src))),
-            (DUP_ZZI_D (INSERT_SUBREG (IMPLICIT_DEF), FPR64:$src, dsub), 0)>;
+            (DUP_ZZI_D (INSERT_SUBREG (nxv2f64 (IMPLICIT_DEF)), FPR64:$src, dsub), 0)>;
 
   // Duplicate +0.0 into all vector elements
   def : Pat<(nxv8f16 (AArch64dup (f16 fpimm0))), (DUP_ZI_H 0, 0)>;
@@ -1303,12 +1303,12 @@
   def : Pat<(AArch64ptest (nxv2i1 PPR:$pg), (nxv2i1 PPR:$src)),
             (PTEST_PP PPR:$pg, PPR:$src)>;
 
-  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32), (SXTW_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>;
-  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i16), (SXTH_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>;
-  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i8),  (SXTB_ZPmZ_D (IMPLICIT_DEF), (PTRUE_D 31), ZPR:$Zs)>;
-  def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i16), (SXTH_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>;
-  def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i8),  (SXTB_ZPmZ_S (IMPLICIT_DEF), (PTRUE_S 31), ZPR:$Zs)>;
-  def : Pat<(sext_inreg (nxv8i16 ZPR:$Zs), nxv8i8),  (SXTB_ZPmZ_H (IMPLICIT_DEF), (PTRUE_H 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i32), (SXTW_ZPmZ_D (nxv2i64 (IMPLICIT_DEF)), (PTRUE_D 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i16), (SXTH_ZPmZ_D (nxv2i64 (IMPLICIT_DEF)), (PTRUE_D 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv2i64 ZPR:$Zs), nxv2i8),  (SXTB_ZPmZ_D (nxv2i64 (IMPLICIT_DEF)), (PTRUE_D 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i16), (SXTH_ZPmZ_S (nxv4i32 (IMPLICIT_DEF)), (PTRUE_S 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv4i32 ZPR:$Zs), nxv4i8),  (SXTB_ZPmZ_S (nxv4i32 (IMPLICIT_DEF)), (PTRUE_S 31), ZPR:$Zs)>;
+  def : Pat<(sext_inreg (nxv8i16 ZPR:$Zs), nxv8i8),  (SXTB_ZPmZ_H (nxv8i16 (IMPLICIT_DEF)), (PTRUE_H 31), ZPR:$Zs)>;
 
   // General case that we ideally never want to match.
   def : Pat<(vscale GPR64:$scale), (MADDXrrr (UBFMXri (RDVLI_XI 1), 4, 63), $scale, XZR)>;
@@ -1390,6 +1390,8 @@
   def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
   def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
 
+  def : Pat<(nxv4f32 (bitconvert (v16f32 ZPR:$src))), (nxv4f32 ZPR:$src)>;
+
   def : Pat<(nxv16i1 (and PPR:$Ps1, PPR:$Ps2)),
             (AND_PPzPP (PTRUE_B 31), PPR:$Ps1, PPR:$Ps2)>;
   def : Pat<(nxv8i1 (and PPR:$Ps1, PPR:$Ps2)),
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -390,7 +390,7 @@
 class SVE_1_Op_AllActive_Pat<ValueType vtd, SDPatternOperator op, ValueType vt1,
                              Instruction inst, Instruction ptrue>
 : Pat<(vtd (op vt1:$Op1)),
-      (inst (IMPLICIT_DEF), (ptrue 31), $Op1)>;
+      (inst (vt1 (IMPLICIT_DEF)), (ptrue 31), $Op1)>;
 
 class SVE_2_Op_AllActive_Pat
Index: llvm/test/CodeGen/AArch64/sve-fixed-width-arith.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-fixed-width-arith.ll
@@ -0,0 +1,11 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -aarch64-isel-sve-vector-bits=512 < %s | FileCheck %s
+
+define <16 x float> @fadd_float_512b(<16 x float> %a, <16 x float> %b) {
+; CHECK-LABEL: fadd_float_512b:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    fadd z0.s, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %fadd = fadd <16 x float> %a, %b
+  ret <16 x float> %fadd
+}
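
Not part of the patch above: the custom LOAD/STORE lowering added in LowerFixedWidthOperation is only exercised indirectly by the fadd test, so a companion load/store test might be worth adding. The sketch below is hand-written rather than produced by update_llc_test_checks.py, and the file name, function name, and CHECK lines are assumptions about the expected ptrue/ld1w/st1w sequence, not verified output.

; Hypothetical llvm/test/CodeGen/AArch64/sve-fixed-width-ldst.ll
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -aarch64-isel-sve-vector-bits=512 < %s | FileCheck %s

define void @copy_float_512b(<16 x float>* %src, <16 x float>* %dst) {
; CHECK-LABEL: copy_float_512b:
; CHECK: ptrue p{{[0-9]+}}.s
; CHECK: ld1w { z{{[0-9]+}}.s }
; CHECK: st1w { z{{[0-9]+}}.s }
; CHECK: ret
  %val = load <16 x float>, <16 x float>* %src
  store <16 x float> %val, <16 x float>* %dst
  ret void
}

Because v16f32 exactly fills a 512-bit SVE register, getSVEPredicate should select the all-active pattern here, so the masked load and store introduced by the lowering are expected to come out as plain predicated ld1w/st1w.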