diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3158,10 +3158,15 @@
   return true;
 }
 
-// Try to fold VMERGE_VVM with unmasked intrinsic to masked intrinsic. The
-// peephole only deals with VMERGE_VVM which is TU and has false operand same as
-// its true operand now. E.g. (VMERGE_VVM_M1_TU False, False, (VADD_M1 ...),
-// ...) -> (VADD_VV_M1_MASK)
+// Try to fold away VMERGE_VVM instructions. We handle these cases:
+// -Masked TU VMERGE_VVM combined with an unmasked TA instruction
+// folds to a masked TU instruction. VMERGE_VVM must have merge operand
+// same as false operand.
+// -Masked TA VMERGE_VVM combined with an unmasked TA instruction folds to a
+// masked TA instruction.
+// -Unmasked TU VMERGE_VVM combined with a masked MU TA instruction folds to
+// masked TU instruction. Both instructions must have the same merge operand.
+// VMERGE_VVM must have merge operand same as false operand.
 bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N, bool IsTA) {
   unsigned Offset = IsTA ? 0 : 1;
   uint64_t Policy = IsTA ? RISCVII::TAIL_AGNOSTIC : /*TUMU*/ 0;
@@ -3184,25 +3189,44 @@
   unsigned TrueOpc = True.getMachineOpcode();
 
   // Skip if True has merge operand.
-  // TODO: Deal with True having same merge operand with N.
-  if (RISCVII::hasMergeOp(TII->get(TrueOpc).TSFlags))
-    return false;
+  uint64_t TrueTSFlags = TII->get(TrueOpc).TSFlags;
+  bool HasMergeOp = RISCVII::hasMergeOp(TrueTSFlags);
+
+  if (HasMergeOp) {
+    // The vmerge instruction must be TU.
+    if (IsTA)
+      return false;
+    SDValue MergeOpN = N->getOperand(0);
+    SDValue MergeOpTrue = True->getOperand(0);
+    // Both the vmerge instruction and the True instruction must have the same
+    // merge operand. The vmerge instruction must have an all 1s mask since
+    // we're going to keep the mask from the True instruction.
+    // FIXME: Support mask agnostic True instruction which would have an
+    // undef merge operand.
+    if (MergeOpN != MergeOpTrue || !usesAllOnesMask(N, /* MaskOpIdx */ 3))
+      return false;
+  }
 
   // Skip if True has side effect.
-  // TODO: Support velff and vlsegff.
+  // TODO: Support vleff and vlsegff.
   if (TII->get(TrueOpc).hasUnmodeledSideEffects())
     return false;
 
-  // Only deal with True when True is unmasked intrinsic now.
   const RISCV::RISCVMaskedPseudoInfo *Info =
-      RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
+      HasMergeOp ? RISCV::getMaskedPseudoInfo(TrueOpc)
+                 : RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
   if (!Info)
     return false;
 
-  // The last operand of unmasked intrinsic should be sew or chain.
+  // The last operand of a masked instruction may be glued.
+  bool HasGlueOp = True->getGluedNode() != nullptr;
+
+  // The chain operand may exist either before the glued operands or in the last
+  // position.
+  unsigned TrueChainOpIdx = True.getNumOperands() - HasGlueOp - 1;
   bool HasChainOp =
-      True.getOperand(True.getNumOperands() - 1).getValueType() == MVT::Other;
+      True.getOperand(TrueChainOpIdx).getValueType() == MVT::Other;
 
   if (HasChainOp) {
     // Avoid creating cycles in the DAG. We must ensure that none of the other
@@ -3218,8 +3242,10 @@
       return false;
   }
 
-  // Need True has same VL with N.
-  unsigned TrueVLIndex = True.getNumOperands() - HasChainOp - 2;
+  // The vector policy operand may be present for masked intrinsics.
+  bool HasVecPolicyOp = RISCVII::hasVecPolicyOp(TrueTSFlags);
+  unsigned TrueVLIndex =
+      True.getNumOperands() - HasVecPolicyOp - HasChainOp - HasGlueOp - 2;
   SDValue TrueVL = True.getOperand(TrueVLIndex);
 
   auto IsNoFPExcept = [this](SDValue N) {
@@ -3241,18 +3267,26 @@
          "Expected instructions with mask have merge operand.");
 
   SmallVector<SDValue, 8> Ops;
-  Ops.push_back(False);
-  Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
-  Ops.append({Mask, VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
-  Ops.push_back(CurDAG->getTargetConstant(Policy, DL, Subtarget->getXLenVT()));
-
-  // Result node should have chain operand of True.
-  if (HasChainOp)
-    Ops.push_back(True.getOperand(True.getNumOperands() - 1));
-
-  // Result node should take over glued node of N.
-  if (N->getGluedNode())
-    Ops.push_back(N->getOperand(N->getNumOperands() - 1));
+  if (HasMergeOp) {
+    Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
+    Ops.append({VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
+    Ops.push_back(
+        CurDAG->getTargetConstant(Policy, DL, Subtarget->getXLenVT()));
+    Ops.append(True->op_begin() + TrueVLIndex + 3, True->op_end());
+  } else {
+    Ops.push_back(False);
+    Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
+    Ops.append({Mask, VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
+    Ops.push_back(
+        CurDAG->getTargetConstant(Policy, DL, Subtarget->getXLenVT()));
+
+    // Result node should have chain operand of True.
+    if (HasChainOp)
+      Ops.push_back(True.getOperand(TrueChainOpIdx));
+
+    if (N->getGluedNode())
+      Ops.push_back(N->getOperand(N->getNumOperands() - 1));
+  }
 
   SDNode *Result =
       CurDAG->getMachineNode(MaskedOpc, DL, True->getVTList(), Ops);
diff --git a/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s
+
+declare <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64);
+declare <vscale x 2 x float> @llvm.riscv.vmerge.nxv2f32.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x i1>, i64);
+
+define <vscale x 2 x i32> @vpmerge_vadd(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vadd:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %mask, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
+declare <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64, i64)
+
+define <vscale x 2 x i32> @vpmerge_vsub(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vsub:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vsub.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vsub.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %mask, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
+declare <vscale x 2 x i32> @llvm.riscv.vsub.mask.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64, i64)
+
+define <vscale x 2 x float> @vpmerge_vfadd(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vfadd:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vfadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x float> @llvm.riscv.vfadd.mask.nxv2f32.nxv2f32(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x float> @llvm.riscv.vmerge.nxv2f32.nxv2f32(<vscale x 2 x float> %passthru, <vscale x 2 x float> %passthru, <vscale x 2 x float> %a, <vscale x 2 x i1> %mask, i64 %vl)
+
+  ret <vscale x 2 x float> %b
+}
+declare <vscale x 2 x float> @llvm.riscv.vfadd.mask.nxv2f32.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x i1>, i64, i64)
+
+define <vscale x 2 x float> @vpmerge_vfsub(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vfsub:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e32, m1, tu, mu
+; CHECK-NEXT:    vfsub.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x float> @llvm.riscv.vfsub.mask.nxv2f32.nxv2f32(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x float> @llvm.riscv.vmerge.nxv2f32.nxv2f32(<vscale x 2 x float> %passthru, <vscale x 2 x float> %passthru, <vscale x 2 x float> %a, <vscale x 2 x i1> %mask, i64 %vl)
+  ret <vscale x 2 x float> %b
+}
+declare <vscale x 2 x float> @llvm.riscv.vfsub.mask.nxv2f32.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x i1>, i64, i64)
+
+define <vscale x 2 x i32> @vpmerge_vwadd(<vscale x 2 x i32> %passthru, <vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vwadd:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf2, tu, mu
+; CHECK-NEXT:    vwadd.vv v8, v9, v10, v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vwadd.mask.nxv2i32.nxv2i16.nxv2i16(<vscale x 2 x i32> %passthru, <vscale x 2 x i16> %x, <vscale x 2 x i16> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %mask, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
+declare <vscale x 2 x i32> @llvm.riscv.vwadd.mask.nxv2i32.nxv2i16.nxv2i16(<vscale x 2 x i32>, <vscale x 2 x i16>, <vscale x 2 x i16>, <vscale x 2 x i1>, i64, i64)
+
+define <vscale x 2 x i32> @vpmerge_vle(<vscale x 2 x i32> %passthru, <vscale x 2 x i32>* %p, <vscale x 2 x i1> %m, i64 %vl) {
+; CHECK-LABEL: vpmerge_vle:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e32, m1, tu, mu
+; CHECK-NEXT:    vle32.v v8, (a0), v0.t
+; CHECK-NEXT:    ret
+  %a = call <vscale x 2 x i32> @llvm.riscv.vle.mask.nxv2i32(
+    <vscale x 2 x i32> %passthru,
+    <vscale x 2 x i32>* %p,
+    <vscale x 2 x i1> %m,
+    i64 %vl, i64 1)
+  %splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
+  %mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
+  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %mask, i64 %vl)
+  ret <vscale x 2 x i32> %b
+}
+declare <vscale x 2 x i32> @llvm.riscv.vle.mask.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>*, <vscale x 2 x i1>, i64, i64)