diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -5796,6 +5796,15 @@
   return DAG.getNode(BaseOpc, DL, XLenVT, SetCC, Op.getOperand(0));
 }
 
+/// Return true if \p AVL is known to be non-zero: either the X0 register
+/// (used as the VLMAX request) or a constant of at least 1.
+static bool hasNonZeroAVL(SDValue AVL) {
+  auto *RegisterAVL = dyn_cast<RegisterSDNode>(AVL);
+  auto *ImmAVL = dyn_cast<ConstantSDNode>(AVL);
+  return (RegisterAVL && RegisterAVL->getReg() == RISCV::X0) ||
+         (ImmAVL && ImmAVL->getZExtValue() >= 1);
+}
+
 /// Helper to lower a reduction sequence of the form:
 /// scalar = reduce_op vec, scalar_start
 static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
@@ -5808,7 +5817,11 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue,
   SDValue InitialSplat =
       lowerScalarSplat(SDValue(), StartValue, DAG.getConstant(1, DL, XLenVT),
                        M1VT, DL, DAG, Subtarget);
-  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, DAG.getUNDEF(M1VT), Vec,
+  // With VL == 0 the reduction does not write its destination, so element 0
+  // of the passthru becomes the result: pass the start value through unless
+  // VL is provably non-zero, in which case the passthru is dead.
+  SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat;
+  SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec,
                                   InitialSplat, Mask, VL);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
                      DAG.getConstant(0, DL, XLenVT));
@@ -5951,29 +5964,17 @@
     return SDValue();
 
   MVT VecVT = VecEVT.getSimpleVT();
-  MVT VecEltVT = VecVT.getVectorElementType();
   unsigned RVVOpcode = getRVVVPReductionOp(Op.getOpcode());
 
-  MVT ContainerVT = VecVT;
   if (VecVT.isFixedLengthVector()) {
-    ContainerVT = getContainerForFixedLengthVector(VecVT);
+    MVT ContainerVT = getContainerForFixedLengthVector(VecVT);
     Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget);
   }
 
   SDValue VL = Op.getOperand(3);
   SDValue Mask = Op.getOperand(2);
-
-  MVT M1VT = getLMUL1VT(ContainerVT);
-  MVT XLenVT = Subtarget.getXLenVT();
-  MVT ResVT = !VecVT.isInteger() || VecEltVT.bitsGE(XLenVT) ? VecEltVT : XLenVT;
-
-  SDValue StartSplat = lowerScalarSplat(SDValue(), Op.getOperand(0),
-                                        DAG.getConstant(1, DL, XLenVT),
-                                        M1VT, DL, DAG, Subtarget);
-  SDValue Reduction =
-      DAG.getNode(RVVOpcode, DL, M1VT, StartSplat, Vec, StartSplat, Mask, VL);
-  SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction,
-                             DAG.getConstant(0, DL, XLenVT));
+  SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL,
+                                   DL, DAG, Subtarget);
   if (!VecVT.isInteger())
     return Elt0;
   return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType());