diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -3463,11 +3463,19 @@ if (Scalar.getOpcode() == ISD::EXTRACT_VECTOR_ELT && isNullConstant(Scalar.getOperand(1))) { - MVT ExtractedVT = Scalar.getOperand(0).getSimpleValueType(); - if (ExtractedVT.bitsLE(VT)) - return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, - Scalar.getOperand(0), DAG.getConstant(0, DL, XLenVT)); - return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Scalar.getOperand(0), + SDValue ExtractedVal = Scalar.getOperand(0); + MVT ExtractedVT = ExtractedVal.getSimpleValueType(); + MVT ExtractedContainerVT = ExtractedVT; + if (ExtractedContainerVT.isFixedLengthVector()) { + ExtractedContainerVT = getContainerForFixedLengthVector( + DAG, ExtractedContainerVT, Subtarget); + ExtractedVal = convertToScalableVector(ExtractedContainerVT, ExtractedVal, + DAG, Subtarget); + } + if (ExtractedContainerVT.bitsLE(VT)) + return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, Passthru, ExtractedVal, + DAG.getConstant(0, DL, XLenVT)); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, ExtractedVal, DAG.getConstant(0, DL, XLenVT)); } @@ -7667,25 +7675,6 @@ DAG.getConstant(0, DL, XLenVT)); } -// Function to extract the first element of Vec. For fixed vector Vec, this -// converts it to a scalable vector before extraction, so subsequent -// optimizations don't have to handle fixed vectors. 
-static SDValue getFirstElement(SDValue Vec, SelectionDAG &DAG, - const RISCVSubtarget &Subtarget) { - SDLoc DL(Vec); - MVT XLenVT = Subtarget.getXLenVT(); - MVT VecVT = Vec.getSimpleValueType(); - MVT VecEltVT = VecVT.getVectorElementType(); - - MVT ContainerVT = VecVT; - if (VecVT.isFixedLengthVector()) { - ContainerVT = getContainerForFixedLengthVector(DAG, VecVT, Subtarget); - Vec = convertToScalableVector(ContainerVT, Vec, DAG, Subtarget); - } - return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec, - DAG.getConstant(0, DL, XLenVT)); -} - SDValue RISCVTargetLowering::lowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const { SDLoc DL(Op); @@ -7728,7 +7717,9 @@ case ISD::UMIN: case ISD::SMAX: case ISD::SMIN: - StartV = getFirstElement(Vec, DAG, Subtarget); + MVT XLenVT = Subtarget.getXLenVT(); + StartV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VecEltVT, Vec, + DAG.getConstant(0, DL, XLenVT)); } return lowerReductionSeq(RVVOpcode, Op.getSimpleValueType(), StartV, Vec, Mask, VL, DL, DAG, Subtarget); @@ -7756,11 +7747,16 @@ return std::make_tuple(RISCVISD::VECREDUCE_SEQ_FADD_VL, Op.getOperand(1), Op.getOperand(0)); case ISD::VECREDUCE_FMIN: - return std::make_tuple(RISCVISD::VECREDUCE_FMIN_VL, Op.getOperand(0), - getFirstElement(Op.getOperand(0), DAG, Subtarget)); - case ISD::VECREDUCE_FMAX: - return std::make_tuple(RISCVISD::VECREDUCE_FMAX_VL, Op.getOperand(0), - getFirstElement(Op.getOperand(0), DAG, Subtarget)); + case ISD::VECREDUCE_FMAX: { + MVT XLenVT = Subtarget.getXLenVT(); + SDValue Front = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Op.getOperand(0), + DAG.getConstant(0, DL, XLenVT)); + unsigned RVVOpc = (Opcode == ISD::VECREDUCE_FMIN) + ? 
RISCVISD::VECREDUCE_FMIN_VL + : RISCVISD::VECREDUCE_FMAX_VL; + return std::make_tuple(RVVOpc, Op.getOperand(0), Front); + } } } diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll @@ -1875,3 +1875,32 @@ %r = call i8 @llvm.vp.reduce.mul.v64i8(i8 %s, <64 x i8> %v, <64 x i1> %m, i32 %evl) ret i8 %r } + +; Test start value is the first element of a vector. +define zeroext i8 @front_ele_v4i8(<4 x i8> %v, <4 x i1> %m, i32 zeroext %evl) { +; CHECK-LABEL: front_ele_v4i8: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, mf4, ta, ma +; CHECK-NEXT: vredand.vs v8, v8, v8, v0.t +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: andi a0, a0, 255 +; CHECK-NEXT: ret + %s = extractelement <4 x i8> %v, i64 0 + %r = call i8 @llvm.vp.reduce.and.v4i8(i8 %s, <4 x i8> %v, <4 x i1> %m, i32 %evl) + ret i8 %r +} + +; Test start value is the first element of a vector which is longer than M1. +declare i8 @llvm.vp.reduce.and.v32i8(i8, <32 x i8>, <32 x i1>, i32) +define zeroext i8 @front_ele_v32i8(<32 x i8> %v, <32 x i1> %m, i32 zeroext %evl) { +; CHECK-LABEL: front_ele_v32i8: +; CHECK: # %bb.0: +; CHECK-NEXT: vsetvli zero, a0, e8, m2, ta, ma +; CHECK-NEXT: vredand.vs v8, v8, v8, v0.t +; CHECK-NEXT: vmv.x.s a0, v8 +; CHECK-NEXT: andi a0, a0, 255 +; CHECK-NEXT: ret + %s = extractelement <32 x i8> %v, i64 0 + %r = call i8 @llvm.vp.reduce.and.v32i8(i8 %s, <32 x i8> %v, <32 x i1> %m, i32 %evl) + ret i8 %r +}