diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -901,6 +901,23 @@
   }
   void SetWidenedVector(SDValue Op, SDValue Result);
 
+  /// Given a mask Mask, returns the larger vector into which Mask was widened.
+  SDValue GetWidenedMask(SDValue Mask, ElementCount EC) {
+    // For VP operations, we must also widen the mask. Note that the mask type
+    // may not actually need widening, leading it to be split along with the VP
+    // operation.
+    // FIXME: This could lead to an infinite split/widen loop. We only handle
+    // the case where the mask needs widening to an identically-sized type as
+    // the vector inputs.
+    assert(getTypeAction(Mask.getValueType()) ==
+               TargetLowering::TypeWidenVector &&
+           "Unable to widen binary VP op");
+    Mask = GetWidenedVector(Mask);
+    assert(Mask.getValueType().getVectorElementCount() == EC &&
+           "Unable to widen binary VP op");
+    return Mask;
+  }
+
   // Widen Vector Result Promotion.
   void WidenVectorResult(SDNode *N, unsigned ResNo);
   SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
@@ -960,6 +977,7 @@
   SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
   SDValue WidenVecOp_VECREDUCE(SDNode *N);
   SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
+  SDValue WidenVecOp_VP_REDUCE(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -3397,20 +3397,8 @@
   assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
   assert(N->isVPOpcode() && "Expected VP opcode");
 
-  // For VP operations, we must also widen the mask. Note that the mask type
-  // may not actually need widening, leading it be split along with the VP
-  // operation.
-  // FIXME: This could lead to an infinite split/widen loop. We only handle the
-  // case where the mask needs widening to an identically-sized type as the
-  // vector inputs.
-  SDValue Mask = N->getOperand(2);
-  assert(getTypeAction(Mask.getValueType()) ==
-             TargetLowering::TypeWidenVector &&
-         "Unable to widen binary VP op");
-  Mask = GetWidenedVector(Mask);
-  assert(Mask.getValueType().getVectorElementCount() ==
-             WidenVT.getVectorElementCount() &&
-         "Unable to widen binary VP op");
+  SDValue Mask =
+      GetWidenedMask(N->getOperand(2), WidenVT.getVectorElementCount());
   return DAG.getNode(N->getOpcode(), dl, WidenVT,
                      {InOp1, InOp2, Mask, N->getOperand(3)}, N->getFlags());
 }
@@ -4930,6 +4918,23 @@
   case ISD::VECREDUCE_SEQ_FMUL:
     Res = WidenVecOp_VECREDUCE_SEQ(N);
     break;
+  case ISD::VP_REDUCE_FADD:
+  case ISD::VP_REDUCE_SEQ_FADD:
+  case ISD::VP_REDUCE_FMUL:
+  case ISD::VP_REDUCE_SEQ_FMUL:
+  case ISD::VP_REDUCE_ADD:
+  case ISD::VP_REDUCE_MUL:
+  case ISD::VP_REDUCE_AND:
+  case ISD::VP_REDUCE_OR:
+  case ISD::VP_REDUCE_XOR:
+  case ISD::VP_REDUCE_SMAX:
+  case ISD::VP_REDUCE_SMIN:
+  case ISD::VP_REDUCE_UMAX:
+  case ISD::VP_REDUCE_UMIN:
+  case ISD::VP_REDUCE_FMAX:
+  case ISD::VP_REDUCE_FMIN:
+    Res = WidenVecOp_VP_REDUCE(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -5523,6 +5528,19 @@
   return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VP_REDUCE(SDNode *N) {
+  assert(N->isVPOpcode() && "Expected VP opcode");
+
+  SDLoc dl(N);
+  SDValue Op = GetWidenedVector(N->getOperand(1));
+  SDValue Mask = GetWidenedMask(N->getOperand(2),
+                                Op.getValueType().getVectorElementCount());
+
+  return DAG.getNode(N->getOpcode(), dl, N->getValueType(0),
+                     {N->getOperand(0), Op, Mask, N->getOperand(3)},
+                     N->getFlags());
+}
+
 SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
   // This only gets called in the case that the left and right inputs and
   // result are of a legal odd vector type, and the condition is illegal i1 of
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll
@@ -144,6 +144,34 @@
   ret double %r
 }
 
+declare double @llvm.vp.reduce.fadd.v3f64(double, <3 x double>, <3 x i1>, i32)
+
+define double @vpreduce_fadd_v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_fadd_v3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v10, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m2, tu, mu
+; CHECK-NEXT:    vfredusum.vs v10, v8, v10, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v10
+; CHECK-NEXT:    ret
+  %r = call reassoc double @llvm.vp.reduce.fadd.v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+define double @vpreduce_ord_fadd_v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_ord_fadd_v3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v10, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m2, tu, mu
+; CHECK-NEXT:    vfredosum.vs v10, v8, v10, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v10
+; CHECK-NEXT:    ret
+  %r = call double @llvm.vp.reduce.fadd.v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
 declare double @llvm.vp.reduce.fadd.v4f64(double, <4 x double>, <4 x i1>, i32)
 
 define double @vpreduce_fadd_v4f64(double %s, <4 x double> %v, <4 x i1> %m, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll
@@ -172,6 +172,22 @@
   ret i8 %r
 }
 
+declare i8 @llvm.vp.reduce.umin.v3i8(i8, <3 x i8>, <3 x i1>, i32)
+
+define signext i8 @vpreduce_umin_v3i8(i8 signext %s, <3 x i8> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_umin_v3i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    andi a0, a0, 255
+; CHECK-NEXT:    vsetivli zero, 1, e8, m1, ta, mu
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf4, tu, mu
+; CHECK-NEXT:    vredminu.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v9
+; CHECK-NEXT:    ret
+  %r = call i8 @llvm.vp.reduce.umin.v3i8(i8 %s, <3 x i8> %v, <3 x i1> %m, i32 %evl)
+  ret i8 %r
+}
+
 declare i8 @llvm.vp.reduce.umin.v4i8(i8, <4 x i8>, <4 x i1>, i32)
 
 define signext i8 @vpreduce_umin_v4i8(i8 signext %s, <4 x i8> %v, <4 x i1> %m, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll
@@ -212,6 +212,23 @@
   ret i1 %r
 }
 
+declare i1 @llvm.vp.reduce.and.v10i1(i1, <10 x i1>, <10 x i1>, i32)
+
+define signext i1 @vpreduce_and_v10i1(i1 signext %s, <10 x i1> %v, <10 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_and_v10i1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, mu
+; CHECK-NEXT:    vmnand.mm v9, v0, v0
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    vcpop.m a1, v9, v0.t
+; CHECK-NEXT:    seqz a1, a1
+; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    neg a0, a0
+; CHECK-NEXT:    ret
+  %r = call i1 @llvm.vp.reduce.and.v10i1(i1 %s, <10 x i1> %v, <10 x i1> %m, i32 %evl)
+  ret i1 %r
+}
+
 declare i1 @llvm.vp.reduce.and.v16i1(i1, <16 x i1>, <16 x i1>, i32)
 
 define signext i1 @vpreduce_and_v16i1(i1 signext %s, <16 x i1> %v, <16 x i1> %m, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll
@@ -228,6 +228,35 @@
   ret double %r
 }
 
+declare double @llvm.vp.reduce.fadd.nxv3f64(double, <vscale x 3 x double>, <vscale x 3 x i1>, i32)
+
+define double @vpreduce_fadd_nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_fadd_nxv3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v12, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m4, tu, mu
+; CHECK-NEXT:    vfredusum.vs v12, v8, v12, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v12
+; CHECK-NEXT:    ret
+  %r = call reassoc double @llvm.vp.reduce.fadd.nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+define double @vpreduce_ord_fadd_nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_ord_fadd_nxv3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v12, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m4, tu, mu
+; CHECK-NEXT:    vfredosum.vs v12, v8, v12, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v12
+; CHECK-NEXT:    ret
+  %r = call double @llvm.vp.reduce.fadd.nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+
 declare double @llvm.vp.reduce.fadd.nxv4f64(double, <vscale x 4 x double>, <vscale x 4 x i1>, i32)
 
 define double @vpreduce_fadd_nxv4f64(double %s, <vscale x 4 x double> %v, <vscale x 4 x i1> %m, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll
@@ -248,6 +248,21 @@
   ret i8 %r
 }
 
+declare i8 @llvm.vp.reduce.smax.nxv3i8(i8, <vscale x 3 x i8>, <vscale x 3 x i1>, i32)
+
+define signext i8 @vpreduce_smax_nxv3i8(i8 signext %s, <vscale x 3 x i8> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_smax_nxv3i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e8, m1, ta, mu
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf2, tu, mu
+; CHECK-NEXT:    vredmax.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v9
+; CHECK-NEXT:    ret
+  %r = call i8 @llvm.vp.reduce.smax.nxv3i8(i8 %s, <vscale x 3 x i8> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret i8 %r
+}
+
 declare i8 @llvm.vp.reduce.add.nxv4i8(i8, <vscale x 4 x i8>, <vscale x 4 x i1>, i32)
 
 define signext i8 @vpreduce_add_nxv4i8(i8 signext %s, <vscale x 4 x i8> %v, <vscale x 4 x i1> %m, i32 zeroext %evl) {
diff --git a/llvm/test/CodeGen/RISCV/rvv/vreductions-mask-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vreductions-mask-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vreductions-mask-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vreductions-mask-vp.ll
@@ -331,6 +331,24 @@
   ret i1 %r
 }
 
+declare i1 @llvm.vp.reduce.or.nxv40i1(i1, <vscale x 40 x i1>, <vscale x 40 x i1>, i32)
+
+define signext i1 @vpreduce_or_nxv40i1(i1 signext %s, <vscale x 40 x i1> %v, <vscale x 40 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_or_nxv40i1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v9, v0
+; CHECK-NEXT:    vsetvli zero, a1, e8, m8, ta, mu
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    vcpop.m a1, v9, v0.t
+; CHECK-NEXT:    snez a1, a1
+; CHECK-NEXT:    or a0, a1, a0
+; CHECK-NEXT:    andi a0, a0, 1
+; CHECK-NEXT:    neg a0, a0
+; CHECK-NEXT:    ret
+  %r = call i1 @llvm.vp.reduce.or.nxv40i1(i1 %s, <vscale x 40 x i1> %v, <vscale x 40 x i1> %m, i32 %evl)
+  ret i1 %r
+}
+
 declare i1 @llvm.vp.reduce.or.nxv64i1(i1, <vscale x 64 x i1>, <vscale x 64 x i1>, i32)
 
 define signext i1 @vpreduce_or_nxv64i1(i1 signext %s, <vscale x 64 x i1> %v, <vscale x 64 x i1> %m, i32 zeroext %evl) {