[RISCV] Extract reduction results directly at the result type

lowerReductionSeq takes a new MVT ResVT parameter and emits its final
ISD::EXTRACT_VECTOR_ELT at ResVT instead of at the vector element type
(the now-unused VecEltVT local is dropped). The three call sites in
this patch pass Op.getSimpleValueType() and return the helper's result
directly, so the DAG.getSExtOrTrunc post-processing is removed, along
with the `if (!VecVT.isInteger()) return Elt0;` special case in the
last hunk.

The helper also now takes its SDLoc as `const SDLoc &` instead of by
value, matching the usual LLVM convention for SDLoc parameters; the
call sites pass a local SDLoc and need no change.

Review notes (to confirm before landing):
1. Per the ISD::EXTRACT_VECTOR_ELT documentation, an integer result
   wider than the element type is extended with unspecified high bits.
   The removed getSExtOrTrunc sign-extended instead, so this relies on
   no user of these reduction nodes depending on the upper bits when
   the result type is wider than the element type.
2. getSExtOrTrunc could also truncate; EXTRACT_VECTOR_ELT cannot return
   a type narrower than the element type. This assumes ResVT is never
   narrower than the element type at any call site.
3. For floating-point reductions the last hunk previously returned the
   element-typed value unchanged; this assumes the node's result type
   equals the element type in that case.

diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -5805,12 +5805,11 @@ /// Helper to lower a reduction sequence of the form: /// scalar = reduce_op vec, scalar_start -static SDValue lowerReductionSeq(unsigned RVVOpcode, SDValue StartValue, - SDValue Vec, SDValue Mask, SDValue VL, - SDLoc DL, SelectionDAG &DAG, +static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT, + SDValue StartValue, SDValue Vec, SDValue Mask, + SDValue VL, const SDLoc &DL, SelectionDAG &DAG, const RISCVSubtarget &Subtarget) { const MVT VecVT = Vec.getSimpleValueType(); - const MVT VecEltVT = VecVT.getVectorElementType(); const MVT M1VT = getLMUL1VT(VecVT); const MVT XLenVT = Subtarget.getXLenVT(); @@ -5820,7 +5819,7 @@ SDValue PassThru = hasNonZeroAVL(VL) ? DAG.getUNDEF(M1VT) : InitialSplat; SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, PassThru, Vec, InitialSplat, Mask, VL); - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction, + return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Reduction, DAG.getConstant(0, DL, XLenVT)); } @@ -5860,9 +5859,8 @@ SDValue NeutralElem = DAG.getNeutralElement(BaseOpc, DL, VecEltVT, SDNodeFlags()); - SDValue Elt0 = lowerReductionSeq(RVVOpcode, NeutralElem, Vec, Mask, VL, - DL, DAG, Subtarget); - return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType()); + return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), NeutralElem, Vec, + Mask, VL, DL, DAG, Subtarget); } // Given a reduction op, this function returns the matching reduction opcode, @@ -5913,8 +5911,8 @@ } auto [Mask, VL] = getDefaultVLOps(VecVT, ContainerVT, DL, DAG, Subtarget); - return lowerReductionSeq(RVVOpcode, ScalarVal, VectorVal, Mask, VL, DL, DAG, - Subtarget); + return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), ScalarVal, + VectorVal, Mask, VL, DL, DAG, Subtarget); } static unsigned 
getRVVVPReductionOp(unsigned ISDOpcode) { @@ -5969,11 +5967,8 @@ SDValue VL = Op.getOperand(3); SDValue Mask = Op.getOperand(2); - SDValue Elt0 = lowerReductionSeq(RVVOpcode, Op.getOperand(0), Vec, Mask, VL, - DL, DAG, Subtarget); - if (!VecVT.isInteger()) - return Elt0; - return DAG.getSExtOrTrunc(Elt0, DL, Op.getValueType()); + return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), Op.getOperand(0), + Vec, Mask, VL, DL, DAG, Subtarget); } SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,