diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3157,37 +3157,48 @@
   const RISCVInstrInfo &TII = *Subtarget->getInstrInfo();
   const MCInstrDesc &MaskedMCID = TII.get(N->getMachineOpcode());
 
-  bool IsTA = true;
+  // Ternary pseudos are registered with UnmaskedTUPseudo == UnmaskedPseudo:
+  // their unmasked form already has a tied merge operand and a policy
+  // operand, so it serves as its own TU form.
+  const bool IsTernary = I->UnmaskedTUPseudo == I->UnmaskedPseudo;
+  bool UseTUPseudo = false;
   if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
-    TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
-    if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
-          RISCVII::TAIL_AGNOSTIC)) {
-      // Keep the true-masked instruction when there is no unmasked TU
-      // instruction
-      if (I->UnmaskedTUPseudo == I->MaskedPseudo && !N->getOperand(0).isUndef())
-        return false;
-      // We can't use TA if the tie-operand is not IMPLICIT_DEF
-      if (!N->getOperand(0).isUndef())
-        IsTA = false;
+    if (IsTernary) {
+      // Keep both the merge and the policy operand: TailPolicyOpIdx is not set
+      // on this path, so the loop below does not skip the policy operand.
+      UseTUPseudo = true;
+    } else {
+      TailPolicyOpIdx = getVecPolicyOpIdx(N, MaskedMCID);
+      if (!(N->getConstantOperandVal(*TailPolicyOpIdx) &
+            RISCVII::TAIL_AGNOSTIC)) {
+        // We can't use TA if the tie-operand is not IMPLICIT_DEF
+        if (!N->getOperand(0).isUndef()) {
+          // Keep the true-masked instruction when there is no unmasked TU
+          // instruction
+          if (I->UnmaskedTUPseudo == I->MaskedPseudo)
+            return false;
+          UseTUPseudo = true;
+        }
+      }
     }
   }
 
-  unsigned Opc = IsTA ? I->UnmaskedPseudo : I->UnmaskedTUPseudo;
+  unsigned Opc = UseTUPseudo ? I->UnmaskedTUPseudo : I->UnmaskedPseudo;
 
-  // Check that we're dropping the mask operand and any policy operand
-  // when we transform to this unmasked pseudo. Additionally, if this
-  // instruction is tail agnostic, the unmasked instruction should not have a
-  // merge op.
+  // Check that we're dropping the mask operand when we transform to this
+  // unmasked pseudo, and that the unmasked pseudo has a merge op exactly when
+  // we keep operand 0. Only ternary pseudos keep their policy operand; every
+  // other unmasked pseudo must not have one.
   uint64_t TSFlags = TII.get(Opc).TSFlags;
-  assert((IsTA != RISCVII::hasMergeOp(TSFlags)) &&
+  assert((UseTUPseudo == RISCVII::hasMergeOp(TSFlags)) &&
          RISCVII::hasDummyMaskOp(TSFlags) &&
-         !RISCVII::hasVecPolicyOp(TSFlags) &&
+         (IsTernary || !RISCVII::hasVecPolicyOp(TSFlags)) &&
          "Unexpected pseudo to transform to");
   (void)TSFlags;
 
   SmallVector<SDValue, 8> Ops;
-  // Skip the merge operand at index 0 if IsTA
-  for (unsigned I = IsTA, E = N->getNumOperands(); I != E; I++) {
+  // Skip the merge operand at index 0 if !UseTUPseudo.
+  for (unsigned I = !UseTUPseudo, E = N->getNumOperands(); I != E; I++) {
     // Skip the mask, the policy, and the Glue.
     SDValue Op = N->getOperand(I);
     if (I == MaskOpIdx || I == TailPolicyOpIdx ||
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -472,10 +472,17 @@
   let PrimaryKeyName = "getRISCVVIntrinsicInfo";
 }
 
-class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true> {
+// Records the unmasked counterparts of a masked pseudo. UnmaskedTUPseudo is
+// the "_TU" variant when HasTU is set; UnmaskedPseudo itself when IsTernary is
+// set (ternary pseudos already take a tied merge operand and a policy operand,
+// so they are their own TU form); otherwise MaskedPseudo, which marks that no
+// unmasked TU form exists.
+class RISCVMaskedPseudo<bits<4> MaskIdx, bit HasTU = true, bit IsTernary = false> {
   Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
   Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
-  Pseudo UnmaskedTUPseudo = !if(HasTU, !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")), MaskedPseudo);
+  Pseudo UnmaskedTUPseudo = !cond(HasTU : !cast<Pseudo>(!subst("_MASK", "", NAME # "_TU")),
+                                  IsTernary : UnmaskedPseudo,
+                                  true : MaskedPseudo);
   bits<4> MaskOpIdx = MaskIdx;
 }
 
@@ -3192,7 +3199,8 @@
   let VLMul = MInfo.value in {
     let isCommutable = Commutable in
     def "_" # MInfo.MX : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
-    def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
+    def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
+                                   RISCVMaskedPseudo</*MaskOpIdx*/ 3, /*HasTU*/ false, /*IsTernary*/true>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -1459,12 +1459,6 @@
 multiclass VPatFPMulAddVL_VV_VF<SDPatternOperator vop, string instruction_name> {
   foreach vti = AllFloatVectors in {
   defvar suffix = vti.LMul.MX;
-  def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
-                             vti.RegClass:$rs2, (vti.Mask true_mask),
-                             VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_VV_"# suffix)
-                 vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(vti.Vector (vop vti.RegClass:$rs1, vti.RegClass:$rd,
                              vti.RegClass:$rs2, (vti.Mask V0),
                              VLOpFrag)),
@@ -1472,13 +1466,6 @@
                  vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
                  (vti.Mask V0),
                  GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
-                             vti.RegClass:$rd, vti.RegClass:$rs2,
-                             (vti.Mask true_mask),
-                             VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_V" # vti.ScalarSuffix # "_" # suffix)
-                 vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                 GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
   def : Pat<(vti.Vector (vop (SplatFPOp vti.ScalarRegClass:$rs1),
                              vti.RegClass:$rd, vti.RegClass:$rs2,
                              (vti.Mask V0),