diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2327,6 +2327,16 @@
   return SDValue(); // Don't custom lower most intrinsics.
 }
 
+// Return the LMUL=1 scalable vector type with VT's element type, i.e. the type
+// whose known-minimum size is exactly one RVV block (RISCV::RVVBitsPerBlock).
+static MVT getLMUL1VT(MVT VT) {
+  assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
+         "Unexpected vector MVT");
+  return MVT::getScalableVectorVT(
+      VT.getVectorElementType(),
+      RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
+}
+
 static std::pair<unsigned, uint64_t>
 getRVVReductionOpAndIdentityVal(unsigned ISDOpcode, unsigned EltSizeBits) {
   switch (ISDOpcode) {
@@ -2360,18 +2370,15 @@
   assert(Op.getValueType().isSimple() &&
          Op.getOperand(0).getValueType().isSimple() &&
          "Unexpected vector-reduce lowering");
-  MVT VecEltVT = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+  MVT VecVT = Op.getOperand(0).getSimpleValueType();
+  MVT VecEltVT = VecVT.getVectorElementType();
   unsigned RVVOpcode;
   uint64_t IdentityVal;
   std::tie(RVVOpcode, IdentityVal) =
       getRVVReductionOpAndIdentityVal(Op.getOpcode(), VecEltVT.getSizeInBits());
-  // We have to perform a bit of a dance to get from our vector type to the
-  // correct LMUL=1 vector type. We divide our minimum VLEN (64) by the vector
-  // element type to find the type which fills a single register. Be careful to
-  // use the operand's vector element type rather than the reduction's value
-  // type, as that has likely been extended to XLEN.
-  unsigned NumElts = 64 / VecEltVT.getSizeInBits();
-  MVT M1VT = MVT::getScalableVectorVT(VecEltVT, NumElts);
+  // Derive the LMUL=1 type from the operand's vector type rather than the
+  // reduction's value type, as the latter has likely been extended to XLEN.
+  MVT M1VT = getLMUL1VT(VecVT);
   SDValue IdentitySplat =
       DAG.getSplatVector(M1VT, DL, DAG.getConstant(IdentityVal, DL, VecEltVT));
   SDValue Reduction =
@@ -2403,30 +2410,19 @@
                                               SelectionDAG &DAG) const {
   SDLoc DL(Op);
   MVT VecEltVT = Op.getSimpleValueType();
-  // We have to perform a bit of a dance to get from our vector type to the
-  // correct LMUL=1 vector type. See above for an explanation.
-  unsigned NumElts = 64 / VecEltVT.getSizeInBits();
-  MVT M1VT = MVT::getScalableVectorVT(VecEltVT, NumElts);
 
   unsigned RVVOpcode;
   SDValue VectorVal, ScalarVal;
   std::tie(RVVOpcode, VectorVal, ScalarVal) =
       getRVVFPReductionOpAndOperands(Op, DAG, VecEltVT);
+  MVT M1VT = getLMUL1VT(VectorVal.getSimpleValueType());
 
   SDValue ScalarSplat = DAG.getSplatVector(M1VT, DL, ScalarVal);
   SDValue Reduction = DAG.getNode(RVVOpcode, DL, M1VT, VectorVal, ScalarSplat);
   return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Reduction,
                      DAG.getConstant(0, DL, Subtarget.getXLenVT()));
 }
 
-static MVT getLMUL1VT(MVT VT) {
-  assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
-         "Unexpected vector MVT");
-  return MVT::getScalableVectorVT(
-      VT.getVectorElementType(),
-      RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
-}
-
 SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
                                                    SelectionDAG &DAG) const {
   SDValue Vec = Op.getOperand(0);