diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAG.cpp
@@ -3249,18 +3249,40 @@
   uint64_t TrueTSFlags = TII->get(TrueOpc).TSFlags;
   bool HasMergeOp = RISCVII::hasMergeOp(TrueTSFlags);

+  bool IsMasked = false;
+  const RISCV::RISCVMaskedPseudoInfo *Info =
+      RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
+  if (!Info && HasMergeOp) {
+    Info = RISCV::getMaskedPseudoInfo(TrueOpc);
+    IsMasked = true;
+  }
+
+  if (!Info)
+    return false;
+
   if (HasMergeOp) {
     // The vmerge instruction must be TU.
+    // FIXME: This could be relaxed, but we need to handle the policy for the
+    // resulting op correctly.
     if (IsTA)
       return false;
-    SDValue MergeOpN = N->getOperand(0);
     SDValue MergeOpTrue = True->getOperand(0);
     // Both the vmerge instruction and the True instruction must have the same
-    // merge operand. The vmerge instruction must have an all 1s mask since
-    // we're going to keep the mask from the True instruction.
+    // merge operand.
+    if (False != MergeOpTrue)
+      return false;
+  }
+
+  if (IsMasked) {
+    assert(HasMergeOp && "Expected merge op");
+    // The vmerge instruction must be TU.
+    if (IsTA)
+      return false;
+    // The vmerge instruction must have an all 1s mask since we're going to keep
+    // the mask from the True instruction.
     // FIXME: Support mask agnostic True instruction which would have an
     // undef merge operand.
-    if (MergeOpN != MergeOpTrue || !usesAllOnesMask(N, /* MaskOpIdx */ 3))
+    if (!usesAllOnesMask(N, /* MaskOpIdx */ 3))
       return false;
   }

@@ -3269,13 +3291,6 @@
   if (TII->get(TrueOpc).hasUnmodeledSideEffects())
     return false;

-  const RISCV::RISCVMaskedPseudoInfo *Info =
-      HasMergeOp ? RISCV::getMaskedPseudoInfo(TrueOpc)
-                 : RISCV::lookupMaskedIntrinsicByUnmaskedTA(TrueOpc);
-
-  if (!Info)
-    return false;
-
   // The last operand of a masked instruction may be glued.
   bool HasGlueOp = True->getGluedNode() != nullptr;

@@ -3324,14 +3339,15 @@
          "Expected instructions with mask have merge operand.");

   SmallVector<SDValue, 8> Ops;
-  if (HasMergeOp) {
+  if (IsMasked) {
     Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
     Ops.append({VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
     Ops.push_back(
         CurDAG->getTargetConstant(Policy, DL, Subtarget->getXLenVT()));
     Ops.append(True->op_begin() + TrueVLIndex + 3, True->op_end());
   } else {
-    Ops.push_back(False);
+    if (!HasMergeOp)
+      Ops.push_back(False);
     Ops.append(True->op_begin(), True->op_begin() + TrueVLIndex);
     Ops.append({Mask, VL, /* SEW */ True.getOperand(TrueVLIndex + 1)});
     Ops.push_back(
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -7,6 +7,7 @@
 declare <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1>, <vscale x 1 x float>, <vscale x 1 x float>, i32)

 define <vscale x 1 x float> @vfmacc_vv_nxv1f32(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vfmacc_vv_nxv1f32:
@@ -36,6 +37,56 @@
   ret <vscale x 1 x float> %v
 }

+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+; FIXME: Support this case?
+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_masked__tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, ta, ma
+; CHECK-NEXT:    vmv1r.v v11, v10
+; CHECK-NEXT:    vfwmacc.vv v11, v8, v9, v0.t
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, tu, ma
+; CHECK-NEXT:    vmerge.vvm v10, v10, v11, v0
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %m, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %m, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vv_nxv1f32_unmasked_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vv_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmacc.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 define <vscale x 1 x float> @vfmacc_vf_nxv1f32(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vfmacc_vf_nxv1f32:
 ; CHECK:       # %bb.0:
@@ -83,6 +134,60 @@
   ret <vscale x 1 x float> %v
 }

+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vaext, <vscale x 1 x float> %vbext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_commute_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_commute_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vbext, <vscale x 1 x float> %vaext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vfmacc_vf_nxv1f32_unmasked_tu(<vscale x 1 x half> %va, half %b, <vscale x 1 x float> %vc, i32 zeroext %evl) {
+; CHECK-LABEL: vfmacc_vf_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmacc.vf v9, fa0, v8
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %elt.head = insertelement <vscale x 1 x half> poison, half %b, i32 0
+  %vb = shufflevector <vscale x 1 x half> %elt.head, <vscale x 1 x half> poison, <vscale x 1 x i32> zeroinitializer
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %vaext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %va, <vscale x 1 x i1> %allones, i32 %evl)
+  %vbext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %vb, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %vaext, <vscale x 1 x float> %vbext, <vscale x 1 x float> %vc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %vc, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 declare <vscale x 2 x float> @llvm.vp.fma.nxv2f32(<vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x float>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x float> @llvm.vp.fneg.nxv2f32(<vscale x 2 x float>, <vscale x 2 x i1>, i32)
 declare <vscale x 2 x float> @llvm.vp.merge.nxv2f32(<vscale x 2 x i1>, <vscale x 2 x float>, <vscale x 2 x float>, i32)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vfwmsac-vp.ll
@@ -7,6 +7,7 @@
 declare <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float>, <vscale x 1 x i1>, i32)
 declare <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1>, <vscale x 1 x float>, <vscale x 1 x float>, i32)

 define <vscale x 1 x float> @vmfsac_vv_nxv1f32(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vmfsac_vv_nxv1f32:
@@ -38,6 +39,40 @@
   ret <vscale x 1 x float> %v
 }

+define <vscale x 1 x float> @vmfsac_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vmfsac_vv_nxv1f32_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, mu
+; CHECK-NEXT:    vfwmsac.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %negc = call <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %negc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
+define <vscale x 1 x float> @vmfsac_vv_nxv1f32_unmasked_tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, i32 zeroext %evl) {
+; CHECK-LABEL: vmfsac_vv_nxv1f32_unmasked_tu:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e16, mf4, tu, ma
+; CHECK-NEXT:    vfwmsac.vv v10, v8, v9
+; CHECK-NEXT:    vmv1r.v v8, v10
+; CHECK-NEXT:    ret
+  %splat = insertelement <vscale x 1 x i1> poison, i1 -1, i32 0
+  %allones = shufflevector <vscale x 1 x i1> %splat, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
+  %aext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %a, <vscale x 1 x i1> %allones, i32 %evl)
+  %bext = call <vscale x 1 x float> @llvm.vp.fpext.nxv1f32.nxv1f16(<vscale x 1 x half> %b, <vscale x 1 x i1> %allones, i32 %evl)
+  %negc = call <vscale x 1 x float> @llvm.vp.fneg.nxv1f32(<vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
+  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %aext, <vscale x 1 x float> %bext, <vscale x 1 x float> %negc, <vscale x 1 x i1> %allones, i32 %evl)
+  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %allones, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
+  ret <vscale x 1 x float> %u
+}
+
 define <vscale x 1 x float> @vmfsac_vf_nxv1f32(<vscale x 1 x half> %a, half %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vmfsac_vf_nxv1f32:
 ; CHECK:       # %bb.0:
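
Note: the pattern all of the new `_tu` tests reduce to is sketched below in condensed form (illustrative only, not part of the patch; it drops the widening fpext and uses plain vp.fma, and @sketch and its operand names are invented). The key condition is the `False != MergeOpTrue` bail-out added above: the vp.merge false operand must be the same value as the FMA addend, which becomes the pseudo's tied merge operand, so the merge should fold into a single tail-undisturbed masked pseudo.

; Hedged sketch, not a test from this patch.
declare <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x float>, <vscale x 1 x i1>, i32)
declare <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1>, <vscale x 1 x float>, <vscale x 1 x float>, i32)

define <vscale x 1 x float> @sketch(<vscale x 1 x float> %a, <vscale x 1 x float> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
  ; Unmasked (all-ones) FMA with %c as the addend...
  %head = insertelement <vscale x 1 x i1> poison, i1 true, i32 0
  %allones = shufflevector <vscale x 1 x i1> %head, <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer
  %v = call <vscale x 1 x float> @llvm.vp.fma.nxv1f32(<vscale x 1 x float> %a, <vscale x 1 x float> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %allones, i32 %evl)
  ; ...merged under %m with %c again as the false operand: inactive lanes keep
  ; %c, which is exactly the semantics of a TU masked FMA with %c as the
  ; destination, so the vmerge can be absorbed into the pseudo.
  %u = call <vscale x 1 x float> @llvm.vp.merge.nxv1f32(<vscale x 1 x i1> %m, <vscale x 1 x float> %v, <vscale x 1 x float> %c, i32 %evl)
  ret <vscale x 1 x float> %u
}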