Index: llvm/include/llvm/CodeGen/TargetLowering.h
===================================================================
--- llvm/include/llvm/CodeGen/TargetLowering.h
+++ llvm/include/llvm/CodeGen/TargetLowering.h
@@ -4447,6 +4447,9 @@
   /// only the first Count elements of the vector are used.
   SDValue expandVecReduce(SDNode *Node, SelectionDAG &DAG) const;
 
+  /// Expand a VECREDUCE_SEQ_* into an explicit ordered calculation.
+  SDValue expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const;
+
   /// Expand an SREM or UREM using SDIV/UDIV or SDIVREM/UDIVREM, if legal.
   /// Returns true if the expansion was successful.
   bool expandREM(SDNode *Node, SDValue &Result, SelectionDAG &DAG) const;
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1165,6 +1165,10 @@
     Action = TLI.getOperationAction(
         Node->getOpcode(), Node->getOperand(0).getValueType());
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Action = TLI.getOperationAction(
+        Node->getOpcode(), Node->getOperand(1).getValueType());
+    break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
       Action = TargetLowering::Legal;
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -774,6 +774,7 @@
   SDValue ScalarizeVecOp_FP_ROUND(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_STRICT_FP_ROUND(SDNode *N, unsigned OpNo);
   SDValue ScalarizeVecOp_VECREDUCE(SDNode *N);
+  SDValue ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N);
 
   //===--------------------------------------------------------------------===//
   // Vector Splitting Support: LegalizeVectorTypes.cpp
@@ -829,6 +830,7 @@
   bool SplitVectorOperand(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_VSELECT(SDNode *N, unsigned OpNo);
   SDValue SplitVecOp_VECREDUCE(SDNode *N, unsigned OpNo);
+  SDValue SplitVecOp_VECREDUCE_SEQ(SDNode *N);
   SDValue SplitVecOp_UnaryOp(SDNode *N);
   SDValue SplitVecOp_TruncateHelper(SDNode *N);
 
@@ -915,6 +917,7 @@
   SDValue WidenVecOp_Convert(SDNode *N);
   SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
   SDValue WidenVecOp_VECREDUCE(SDNode *N);
+  SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -471,10 +471,6 @@
                                               Node->getValueType(0), Scale);
     break;
   }
-  case ISD::VECREDUCE_SEQ_FADD:
-    Action = TLI.getOperationAction(Node->getOpcode(),
-                                    Node->getOperand(1).getValueType());
-    break;
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
   case ISD::VECREDUCE_ADD:
@@ -493,6 +489,10 @@
     Action = TLI.getOperationAction(Node->getOpcode(),
                                     Node->getOperand(0).getValueType());
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Action = TLI.getOperationAction(Node->getOpcode(),
+                                    Node->getOperand(1).getValueType());
+    break;
   }
 
   LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG));
@@ -874,6 +874,9 @@
   case ISD::VECREDUCE_FMIN:
     Results.push_back(TLI.expandVecReduce(Node, DAG));
     return;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Results.push_back(TLI.expandVecReduceSeq(Node, DAG));
+    return;
   case ISD::SREM:
   case ISD::UREM:
     ExpandREM(Node, Results);
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -623,6 +623,9 @@
     case ISD::VECREDUCE_FMIN:
       Res = ScalarizeVecOp_VECREDUCE(N);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      Res = ScalarizeVecOp_VECREDUCE_SEQ(N);
+      break;
     }
   }
 
@@ -803,6 +806,30 @@
   return Res;
 }
 
+SDValue DAGTypeLegalizer::ScalarizeVecOp_VECREDUCE_SEQ(SDNode *N) {
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+  SDNodeFlags Flags = N->getFlags();
+
+  unsigned BaseOpc = 0;
+  switch (N->getOpcode()) {
+  default:
+#ifndef NDEBUG
+    dbgs() << "ScalarizeVectorOp: ";
+    N->dump(&DAG);
+    dbgs() << "\n";
+#endif
+    report_fatal_error("Do not know how to scalarize this operator!\n");
+  case ISD::VECREDUCE_SEQ_FADD:
+    BaseOpc = ISD::FADD;
+    break;
+  }
+
+  SDValue Op = GetScalarizedVector(VecOp);
+  return DAG.getNode(BaseOpc, SDLoc(N), N->getValueType(0),
+                     AccOp, Op, Flags);
+}
+
 //===----------------------------------------------------------------------===//
 //  Result Vector Splitting
 //===----------------------------------------------------------------------===//
@@ -2075,6 +2102,9 @@
     case ISD::VECREDUCE_FMIN:
       Res = SplitVecOp_VECREDUCE(N, OpNo);
       break;
+    case ISD::VECREDUCE_SEQ_FADD:
+      Res = SplitVecOp_VECREDUCE_SEQ(N);
+      break;
     }
   }
 
@@ -2168,6 +2198,27 @@
   return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, N->getFlags());
 }
 
+SDValue DAGTypeLegalizer::SplitVecOp_VECREDUCE_SEQ(SDNode *N) {
+  EVT ResVT = N->getValueType(0);
+  SDValue Lo, Hi;
+  SDLoc dl(N);
+
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+  EVT VecVT = VecOp.getValueType();
+  assert(VecVT.isVector() && "Can only split reduce vector operand");
+  GetSplitVector(VecOp, Lo, Hi);
+  EVT LoOpVT, HiOpVT;
+  std::tie(LoOpVT, HiOpVT) = DAG.GetSplitDestVTs(VecVT);
+
+  // Reduce low half.
+  SDValue Partial = DAG.getNode(N->getOpcode(), dl, ResVT,
+                                AccOp, Lo, N->getFlags());
+
+  // Reduce high half, using low half result as initial value.
+  return DAG.getNode(N->getOpcode(), dl, ResVT, Partial, Hi, N->getFlags());
+}
+
 SDValue DAGTypeLegalizer::SplitVecOp_UnaryOp(SDNode *N) {
   // The result has a legal vector type, but the input needs splitting.
   EVT ResVT = N->getValueType(0);
@@ -4336,6 +4387,9 @@
   case ISD::VECREDUCE_FMIN:
     Res = WidenVecOp_VECREDUCE(N);
     break;
+  case ISD::VECREDUCE_SEQ_FADD:
+    Res = WidenVecOp_VECREDUCE_SEQ(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -4828,6 +4882,41 @@
   return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), Op, Flags);
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
+  SDLoc dl(N);
+  SDValue AccOp = N->getOperand(0);
+  SDValue VecOp = N->getOperand(1);
+  SDValue Op = GetWidenedVector(VecOp);
+
+  EVT OrigVT = VecOp.getValueType();
+  EVT WideVT = Op.getValueType();
+  EVT ElemVT = OrigVT.getVectorElementType();
+  SDNodeFlags Flags = N->getFlags();
+
+  SDValue NeutralElem;
+  switch (N->getOpcode()) {
+  default:
+#ifndef NDEBUG
+    dbgs() << "WidenVectorOp: ";
+    N->dump(&DAG);
+    dbgs() << "\n";
+#endif
+    report_fatal_error("Do not know how to widen this operator!\n");
+  case ISD::VECREDUCE_SEQ_FADD:
+    NeutralElem = DAG.getConstantFP(-0.0, dl, ElemVT);
+    break;
+  }
+
+  // Pad the vector with the neutral element.
+  unsigned OrigElts = OrigVT.getVectorNumElements();
+  unsigned WideElts = WideVT.getVectorNumElements();
+  for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
+    Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
+                     DAG.getVectorIdxConstant(Idx, dl));
+
+  return DAG.getNode(N->getOpcode(), dl, N->getValueType(0), AccOp, Op, Flags);
+}
+
 SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
   // This only gets called in the case that the left and right inputs and
   // result are of a legal odd vector type, and the condition is illegal i1 of
Index: llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -8043,6 +8043,31 @@
   return Res;
 }
 
+SDValue TargetLowering::expandVecReduceSeq(SDNode *Node, SelectionDAG &DAG) const {
+  SDLoc dl(Node);
+  SDValue AccOp = Node->getOperand(0);
+  SDValue VecOp = Node->getOperand(1);
+
+  EVT VT = VecOp.getValueType();
+  EVT EltVT = VT.getVectorElementType();
+  unsigned NumElts = VT.getVectorNumElements();
+
+  SmallVector<SDValue, 8> Ops;
+  DAG.ExtractVectorElements(VecOp, Ops, 0, NumElts);
+
+  unsigned BaseOpcode = 0;
+  switch (Node->getOpcode()) {
+  default: llvm_unreachable("Expected VECREDUCE_SEQ opcode");
+  case ISD::VECREDUCE_SEQ_FADD: BaseOpcode = ISD::FADD; break;
+  }
+
+  SDValue Res = AccOp;
+  for (unsigned i = 0; i < NumElts; i++)
+    Res = DAG.getNode(BaseOpcode, dl, EltVT, Res, Ops[i], Node->getFlags());
+
+  return Res;
+}
+
 bool TargetLowering::expandREM(SDNode *Node, SDValue &Result,
                                SelectionDAG &DAG) const {
   EVT VT = Node->getValueType(0);
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -780,14 +780,6 @@
     return !useSVEForFixedLengthVectors();
   }
 
-  // FIXME: Move useSVEForFixedLengthVectors*() back to private scope once
-  // reduction legalization is complete.
-  bool useSVEForFixedLengthVectors() const;
-  // Normally SVE is only used for byte size vectors that do not fit within a
-  // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
-  // used for 64bit and 128bit vectors as well.
-  bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
-
 private:
   /// Keep a pointer to the AArch64Subtarget around so that we can
   /// make the right decision when generating code for different targets.
@@ -1015,6 +1007,12 @@
   bool shouldLocalize(const MachineInstr &MI,
                       const TargetTransformInfo *TTI) const override;
+
+  bool useSVEForFixedLengthVectors() const;
+  // Normally SVE is only used for byte size vectors that do not fit within a
+  // NEON vector. This changes when OverrideNEON is true, allowing SVE to be
+  // used for 64bit and 128bit vectors as well.
+  bool useSVEForFixedLengthVectorVT(EVT VT, bool OverrideNEON = false) const;
 };
 
 namespace AArch64 {
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -861,6 +861,7 @@
     setOperationAction(ISD::SELECT, MVT::v1f64, Expand);
     setOperationAction(ISD::SELECT_CC, MVT::v1f64, Expand);
     setOperationAction(ISD::FP_EXTEND, MVT::v1f64, Expand);
+    setOperationAction(ISD::VECREDUCE_SEQ_FADD, MVT::v1f64, Expand);
 
     setOperationAction(ISD::FP_TO_SINT, MVT::v1i64, Expand);
     setOperationAction(ISD::FP_TO_UINT, MVT::v1i64, Expand);
@@ -925,6 +926,7 @@
                     MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
       setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
+      setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Expand);
     }
 
     for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
                     MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -222,17 +222,8 @@
   bool shouldExpandReduction(const IntrinsicInst *II) const {
     switch (II->getIntrinsicID()) {
-    case Intrinsic::vector_reduce_fadd: {
-      Value *VecOp = II->getArgOperand(1);
-      EVT VT = TLI->getValueType(getDataLayout(), VecOp->getType());
-      if (ST->hasSVE() &&
-          TLI->useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
-        return false;
-
-      return !II->getFastMathFlags().allowReassoc();
-    }
     case Intrinsic::vector_reduce_fmul:
-      // We don't have legalization support for ordered FP reductions.
+      // We don't have legalization support for ordered FMUL reductions.
       return !II->getFastMathFlags().allowReassoc();
 
     default:
Index: llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
+++ llvm/test/CodeGen/AArch64/sve-fixed-length-fp-reduce.ll
@@ -63,8 +63,13 @@
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-32: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].h, vl16
+; VBITS_EQ_256-DAG: ld1h { [[LO:z[0-9]+]].h }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1h { [[HI:z[0-9]+]].h }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[LO]].h
+; VBITS_EQ_256-NEXT: fadda h0, [[PG]], h0, [[HI]].h
+; VBITS_EQ_256-NEXT: ret
   %op = load <32 x half>, <32 x half>* %a
   %res = call half @llvm.vector.reduce.fadd.v32f16(half %start, <32 x half> %op)
   ret half %res
@@ -131,8 +136,13 @@
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-16: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].s, vl8
+; VBITS_EQ_256-DAG: ld1w { [[LO:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1w { [[HI:z[0-9]+]].s }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[LO]].s
+; VBITS_EQ_256-NEXT: fadda s0, [[PG]], s0, [[HI]].s
+; VBITS_EQ_256-NEXT: ret
   %op = load <16 x float>, <16 x float>* %a
   %res = call float @llvm.vector.reduce.fadd.v16f32(float %start, <16 x float> %op)
   ret float %res
@@ -199,8 +209,13 @@
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-COUNT-8: fadd
-; VBITS_EQ_256: ret
+; VBITS_EQ_256: add x8, x0, #32
+; VBITS_EQ_256-NEXT: ptrue [[PG:p[0-9]+]].d, vl4
+; VBITS_EQ_256-DAG: ld1d { [[LO:z[0-9]+]].d }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1d { [[HI:z[0-9]+]].d }, [[PG]]/z, [x8]
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[LO]].d
+; VBITS_EQ_256-NEXT: fadda d0, [[PG]], d0, [[HI]].d
+; VBITS_EQ_256-NEXT: ret
   %op = load <8 x double>, <8 x double>* %a
   %res = call double @llvm.vector.reduce.fadd.v8f64(double %start, <8 x double> %op)
   ret double %res
Index: llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
===================================================================
--- llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
+++ llvm/test/CodeGen/AArch64/vecreduce-fadd-legalization-strict.ll
@@ -93,36 +93,36 @@
 define float @test_v16f32(<16 x float> %a) nounwind {
 ; CHECK-LABEL: test_v16f32:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov s4, wzr
-; CHECK-NEXT:    mov s5, v0.s[1]
-; CHECK-NEXT:    fadd s4, s0, s4
-; CHECK-NEXT:    fadd s4, s4, s5
-; CHECK-NEXT:    mov s5, v0.s[2]
-; CHECK-NEXT:    mov s0, v0.s[3]
-; CHECK-NEXT:    fadd s4, s4, s5
-; CHECK-NEXT:    fadd s0, s4, s0
-; CHECK-NEXT:    mov s5, v1.s[1]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s4, v1.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v1.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s5, v2.s[1]
-; CHECK-NEXT:    fadd s0, s0, s2
-; CHECK-NEXT:    mov s4, v2.s[2]
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v2.s[3]
-; CHECK-NEXT:    fadd s0, s0, s4
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    mov s2, v3.s[1]
-; CHECK-NEXT:    fadd s0, s0, s3
-; CHECK-NEXT:    mov s5, v3.s[2]
-; CHECK-NEXT:    fadd s0, s0, s2
-; CHECK-NEXT:    fadd s0, s0, s5
-; CHECK-NEXT:    mov s1, v3.s[3]
-; CHECK-NEXT:    fadd s0, s0, s1
-; CHECK-NEXT:    ret
+; CHECK-NEXT:    fmov s24, wzr
+; CHECK-NEXT:    mov s21, v0.s[3]
+; CHECK-NEXT:    mov s22, v0.s[2]
+; CHECK-NEXT:    mov s23, v0.s[1]
+; CHECK-NEXT:    fadd s0, s0, s24
+; CHECK-NEXT:    fadd s0, s0, s23
+; CHECK-NEXT:    fadd s0, s0, s22
+; CHECK-NEXT:    fadd s0, s0, s21
+; CHECK-NEXT:    mov s20, v1.s[1]
+; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    mov s19, v1.s[2]
+; CHECK-NEXT:    fadd s0, s0, s20
+; CHECK-NEXT:    mov s18, v1.s[3]
+; CHECK-NEXT:    fadd s0, s0, s19
+; CHECK-NEXT:    fadd s0, s0, s18
+; CHECK-NEXT:    mov s17, v2.s[1]
+; CHECK-NEXT:    fadd s0, s0, s2
+; CHECK-NEXT:    mov s16, v2.s[2]
+; CHECK-NEXT:    fadd s0, s0, s17
+; CHECK-NEXT:    mov s7, v2.s[3]
+; CHECK-NEXT:    fadd s0, s0, s16
+; CHECK-NEXT:    fadd s0, s0, s7
+; CHECK-NEXT:    mov s6, v3.s[1]
+; CHECK-NEXT:    fadd s0, s0, s3
+; CHECK-NEXT:    mov s5, v3.s[2]
+; CHECK-NEXT:    fadd s0, s0, s6
+; CHECK-NEXT:    mov s4, v3.s[3]
+; CHECK-NEXT:    fadd s0, s0, s5
+; CHECK-NEXT:    fadd s0, s0, s4
+; CHECK-NEXT:    ret
   %b = call float @llvm.vector.reduce.fadd.f32.v16f32(float 0.0, <16 x float> %a)
   ret float %b
 }
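---

Reviewer note (not part of the patch; a minimal sketch with illustrative names):
the kind of reduction this change legalizes. Because the call below carries no
'reassoc' fast-math flag, @llvm.vector.reduce.fadd must accumulate strictly in
element order starting from %start. That is what VECREDUCE_SEQ_FADD models, and
what expandVecReduceSeq expands into the scalar chain
(((start + v0) + v1) + v2) + v3 on targets without native support, while SVE
targets can instead select the ordered fadda instruction as in the tests above.

; Hypothetical example, not taken from this patch's test files.
declare float @llvm.vector.reduce.fadd.f32.v4f32(float, <4 x float>)

define float @ordered_reduce(float %start, <4 x float> %v) {
  ; No 'reassoc' flag, so the reduction is evaluated strictly in order.
  %r = call float @llvm.vector.reduce.fadd.f32.v4f32(float %start, <4 x float> %v)
  ret float %r
}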