diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h @@ -191,6 +191,7 @@ struct RISCVMaskedPseudoInfo { uint16_t MaskedPseudo; uint16_t UnmaskedPseudo; + uint16_t UnmaskedTuPseudo; uint8_t MaskOpIdx; }; diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -2242,6 +2242,7 @@ const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode()); + bool IsTU = false; if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) { // The last operand of the pseudo is the policy op, but we're expecting a // Glue operand last. We may also have a chain. @@ -2251,27 +2252,46 @@ if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other) (*TailPolicyOpIdx)--; - // If the policy isn't TAIL_AGNOSTIC we can't perform this optimization. - if (N->getConstantOperandVal(*TailPolicyOpIdx) != RISCVII::TAIL_AGNOSTIC) - return false; + // If the merge operand is undef when policy is TAIL_UNDISTURBED, + // we can't perform this optimization. + if (N->getConstantOperandVal(*TailPolicyOpIdx) == + RISCVII::TAIL_UNDISTURBED) { + if (N->getOperand(0).isUndef()) + return false; + IsTU = true; + } } - const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo); + if (IsTU) { + const MCInstrDesc &UnmaskedTuMCID = TII->get(I->UnmaskedTuPseudo); - // Check that we're dropping the merge operand, the mask operand, and any - // policy operand when we transform to this unmasked pseudo. - assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) && - RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) && - !RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) && - "Unexpected pseudo to transform to"); - (void)UnmaskedMCID; + // Check that we're dropping the mask operand, and any policy operand + // when we transform to this unmasked tu pseudo. + assert(RISCVII::hasMergeOp(UnmaskedTuMCID.TSFlags) && + RISCVII::hasDummyMaskOp(UnmaskedTuMCID.TSFlags) && + !RISCVII::hasVecPolicyOp(UnmaskedTuMCID.TSFlags) && + "Unexpected pseudo to transform to"); + (void)UnmaskedTuMCID; + } else { + const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo); + + // Check that we're dropping the merge operand, the mask operand, and any + // policy operand when we transform to this unmasked pseudo. + assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) && + RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) && + !RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) && + "Unexpected pseudo to transform to"); + (void)UnmaskedMCID; + } + unsigned Opc = IsTU ? I->UnmaskedTuPseudo : I->UnmaskedPseudo; + unsigned Id = IsTU ? 0 : 1; SmallVector Ops; // Skip the merge operand at index 0. - for (unsigned I = 1, E = N->getNumOperands(); I != E; I++) { + for (unsigned E = N->getNumOperands(); Id != E; Id++) { // Skip the mask, the policy, and the Glue. - SDValue Op = N->getOperand(I); - if (I == MaskOpIdx || I == TailPolicyOpIdx || + SDValue Op = N->getOperand(Id); + if (Id == MaskOpIdx || Id == TailPolicyOpIdx || Op.getValueType() == MVT::Glue) continue; Ops.push_back(Op); @@ -2281,8 +2301,7 @@ if (auto *TGlued = Glued->getGluedNode()) Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1)); - SDNode *Result = - CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops); + SDNode *Result = CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops); ReplaceUses(N, Result); return true; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -4770,11 +4770,12 @@ SDValue VL = Op.getOperand(NumOperands - 1); MVT MaskVT = MVT::getVectorVT(MVT::i1, VT.getVectorElementCount()); SDValue TrueMask = DAG.getNode(RISCVISD::VMSET_VL, DL, MaskVT, VL); + unsigned PolicyImm = Op.getOperand(0).isUndef() ? RISCVII::TAIL_AGNOSTIC + : RISCVII::TAIL_UNDISTURBED; + SDValue Policy = DAG.getTargetConstant(PolicyImm, DL, XLenVT); Ops.push_back(TrueMask); Ops.push_back(VL); - // Since unmasked intrinsics and pseudos have no policy operand, - // we use 0 here for pattern matching. - Ops.push_back(DAG.getConstant(0, DL, XLenVT)); // Policy + Ops.push_back(Policy); } return DAG.getNode(Opc, DL, VTs, Ops); } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -426,13 +426,14 @@ class RISCVMaskedPseudo MaskIdx> { Pseudo MaskedPseudo = !cast(NAME); Pseudo UnmaskedPseudo = !cast(!subst("_MASK", "", NAME)); + Pseudo UnmaskedTuPseudo = !cast(!subst("_MASK", "", NAME # "_TU")); bits<4> MaskOpIdx = MaskIdx; } def RISCVMaskedPseudosTable : GenericTable { let FilterClass = "RISCVMaskedPseudo"; let CppTypeName = "RISCVMaskedPseudoInfo"; - let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"]; + let Fields = ["MaskedPseudo", "UnmaskedPseudo", "UnmaskedTuPseudo", "MaskOpIdx"]; let PrimaryKey = ["MaskedPseudo"]; let PrimaryKeyName = "getMaskedPseudoInfo"; } @@ -2656,7 +2657,8 @@ def "_" # MInfo.MX : VPseudoUnaryNoMask; def "_" # MInfo.MX # "_TU": VPseudoUnaryNoMaskTU; def "_" # MInfo.MX # "_MASK" : VPseudoUnaryMaskTA; + Constraint>, + RISCVMaskedPseudo; } } diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -621,21 +621,6 @@ LMULInfo vlmul, VReg result_reg_class, VReg op2_reg_class> { - def : Pat<(result_type (vop (result_type undef), - (op2_type op2_reg_class:$rs2), - (mask_type true_mask), - VLOpFrag, (XLenVT 0))), - (!cast(inst#"_"#kind#"_"#vlmul.MX) - (op2_type op2_reg_class:$rs2), - GPR:$vl, sew)>; - def : Pat<(result_type (vop (result_type result_reg_class:$merge), - (op2_type op2_reg_class:$rs2), - (mask_type true_mask), - VLOpFrag, (XLenVT 0))), - (!cast(inst#"_"#kind#"_"#vlmul.MX#"_TU") - (result_type result_reg_class:$merge), - (op2_type op2_reg_class:$rs2), - GPR:$vl, sew)>; def : Pat<(result_type (vop (result_type result_reg_class:$merge), (op2_type op2_reg_class:$rs2), (mask_type V0),