diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -117,6 +117,7 @@
 private:
   bool doPeepholeLoadStoreADDI(SDNode *Node);
   bool doPeepholeSExtW(SDNode *Node);
+  bool doPeepholeMaskedRVV(SDNode *Node);
 };
 
 namespace RISCV {
@@ -187,6 +188,12 @@
   uint16_t Pseudo;
 };
 
+struct RISCVMaskedPseudoInfo {
+  uint16_t MaskedPseudo;
+  uint16_t UnmaskedPseudo;
+  uint8_t MaskOpIdx;
+};
+
 #define GET_RISCVVSSEGTable_DECL
 #define GET_RISCVVLSEGTable_DECL
 #define GET_RISCVVLXSEGTable_DECL
@@ -195,6 +202,7 @@
 #define GET_RISCVVSETable_DECL
 #define GET_RISCVVLXTable_DECL
 #define GET_RISCVVSXTable_DECL
+#define GET_RISCVMaskedPseudosTable_DECL
 #include "RISCVGenSearchableTables.inc"
 } // namespace RISCV
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -37,6 +37,7 @@
 #define GET_RISCVVSETable_IMPL
 #define GET_RISCVVLXTable_IMPL
 #define GET_RISCVVSXTable_IMPL
+#define GET_RISCVMaskedPseudosTable_IMPL
 #include "RISCVGenSearchableTables.inc"
 } // namespace RISCV
 } // namespace llvm
@@ -123,6 +124,7 @@
 
     MadeChange |= doPeepholeSExtW(N);
     MadeChange |= doPeepholeLoadStoreADDI(N);
+    MadeChange |= doPeepholeMaskedRVV(N);
   }
 
   if (MadeChange)
@@ -2139,6 +2141,119 @@
   return false;
 }
 
+// Optimize masked RVV pseudo instructions with a known all-ones mask to their
+// corresponding "unmasked" pseudo versions. The mask we're interested in will
+// take the form of a V0 physical register operand, with a glued
+// register-setting instruction.
+//
+// With an all-ones mask there are no masked-off elements, but the merge
+// operand can still supply the tail elements of a tail-undisturbed masked
+// pseudo. The unmasked pseudo has no merge operand, so the transform is only
+// performed when dropping the merge operand cannot change the result.
+bool RISCVDAGToDAGISel::doPeepholeMaskedRVV(SDNode *N) {
+  const RISCV::RISCVMaskedPseudoInfo *I =
+      RISCV::getMaskedPseudoInfo(N->getMachineOpcode());
+  if (!I)
+    return false;
+
+  unsigned MaskOpIdx = I->MaskOpIdx;
+
+  // Check that we're using V0 as a mask register.
+  if (!isa<RegisterSDNode>(N->getOperand(MaskOpIdx)) ||
+      cast<RegisterSDNode>(N->getOperand(MaskOpIdx))->getReg() != RISCV::V0)
+    return false;
+
+  // N's glue operand must be produced by the CopyToReg that defines V0.
+  const auto *Glued = N->getGluedNode();
+
+  if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
+    return false;
+
+  // Check that we're defining V0 as a mask register.
+  if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
+      cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
+    return false;
+
+  // Check the instruction defining V0; it needs to be a VMSET pseudo.
+  SDValue MaskSetter = Glued->getOperand(2);
+
+  const auto IsVMSet = [](unsigned Opc) {
+    return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
+           Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
+           Opc == RISCV::PseudoVMSET_M_B4 || Opc == RISCV::PseudoVMSET_M_B64 ||
+           Opc == RISCV::PseudoVMSET_M_B8;
+  };
+
+  // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
+  // undefined behaviour if it's the wrong bitwidth, so we could choose to
+  // assume that it's all-ones? Same applies to its VL.
+  if (!MaskSetter->isMachineOpcode() ||
+      !IsVMSet(MaskSetter.getMachineOpcode()))
+    return false;
+
+  // Retrieve the tail policy operand index, if any.
+  Optional<unsigned> TailPolicyOpIdx;
+  const RISCVInstrInfo *TII = static_cast<const RISCVInstrInfo *>(
+      CurDAG->getSubtarget().getInstrInfo());
+
+  const MCInstrDesc &MaskedMCID = TII->get(N->getMachineOpcode());
+
+  if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags)) {
+    // The last operand of the pseudo is the policy op, but it may be followed
+    // by a chain and/or a trailing Glue operand.
+    TailPolicyOpIdx = N->getNumOperands() - 1;
+    if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Glue)
+      (*TailPolicyOpIdx)--;
+    if (N->getOperand(*TailPolicyOpIdx).getValueType() == MVT::Other)
+      (*TailPolicyOpIdx)--;
+  }
+
+  // Bail out unless the tail elements taken from the merge operand are
+  // irrelevant: either the merge operand is undefined (UNDEF, or the
+  // IMPLICIT_DEF it is selected to), or the policy operand requests a
+  // tail-agnostic result.
+  SDValue Merge = N->getOperand(0);
+  bool MergeIsUndef = Merge.isUndef() ||
+                      (Merge.isMachineOpcode() &&
+                       Merge.getMachineOpcode() == TargetOpcode::IMPLICIT_DEF);
+  bool IsTailAgnostic =
+      TailPolicyOpIdx &&
+      (N->getConstantOperandVal(*TailPolicyOpIdx) & RISCVII::TAIL_AGNOSTIC);
+  if (!MergeIsUndef && !IsTailAgnostic)
+    return false;
+
+  const MCInstrDesc &UnmaskedMCID = TII->get(I->UnmaskedPseudo);
+
+  // Check that we're dropping the merge operand, the mask operand, and any
+  // policy operand when we transform to this unmasked pseudo.
+  assert(!RISCVII::hasMergeOp(UnmaskedMCID.TSFlags) &&
+         RISCVII::hasDummyMaskOp(UnmaskedMCID.TSFlags) &&
+         !RISCVII::hasVecPolicyOp(UnmaskedMCID.TSFlags) &&
+         "Unexpected pseudo to transform to");
+
+  SmallVector<SDValue, 8> Ops;
+  // Skip the merge operand at index 0. The index is named OpIdx rather than I
+  // so that it does not shadow the table entry I used after the loop.
+  for (unsigned OpIdx = 1, E = N->getNumOperands(); OpIdx != E; ++OpIdx) {
+    // Skip the mask, the policy, and the Glue.
+    SDValue Op = N->getOperand(OpIdx);
+    if (OpIdx == MaskOpIdx || OpIdx == TailPolicyOpIdx ||
+        Op.getValueType() == MVT::Glue)
+      continue;
+    Ops.push_back(Op);
+  }
+
+  // Transitively apply any node glued to our new node.
+  if (auto *TGlued = Glued->getGluedNode())
+    Ops.push_back(SDValue(TGlued, TGlued->getNumValues() - 1));
+
+  SDNode *Result =
+      CurDAG->getMachineNode(I->UnmaskedPseudo, SDLoc(N), N->getVTList(), Ops);
+  ReplaceUses(N, Result);
+
+  return true;
+}
+
 // This pass converts a legalized DAG into a RISCV-specific DAG, ready
 // for instruction scheduling.
FunctionPass *llvm::createRISCVISelDag(RISCVTargetMachine &TM) {
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -424,6 +424,24 @@
   let PrimaryKeyName = "getRISCVVIntrinsicInfo";
 }
 
+// Marks a masked pseudo (named with a "_MASK" suffix) that may be replaced by
+// its unmasked counterpart, found by removing "_MASK" from its name, when its
+// mask operand is known to be all ones. MaskIdx is the index of the mask
+// operand among the masked pseudo's input operands.
+class RISCVMaskedPseudo<bits<4> MaskIdx> {
+  Pseudo MaskedPseudo = !cast<Pseudo>(NAME);
+  Pseudo UnmaskedPseudo = !cast<Pseudo>(!subst("_MASK", "", NAME));
+  bits<4> MaskOpIdx = MaskIdx;
+}
+
+def RISCVMaskedPseudosTable : GenericTable {
+  let FilterClass = "RISCVMaskedPseudo";
+  let CppTypeName = "RISCVMaskedPseudoInfo";
+  let Fields = ["MaskedPseudo", "UnmaskedPseudo", "MaskOpIdx"];
+  let PrimaryKey = ["MaskedPseudo"];
+  let PrimaryKeyName = "getMaskedPseudoInfo";
+}
+
 class RISCVVLE<bit M, bit TU, bit Str, bit F, bits<3> S, bits<3> L> {
   bits<1> Masked = M;
   bits<1> IsTU = TU;
@@ -1639,7 +1657,8 @@
     def "_" # MInfo.MX : VPseudoBinaryNoMask<RetClass, Op1Class, Op2Class,
                                              Constraint>;
     def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskTA<RetClass, Op1Class, Op2Class,
-                                                       Constraint>;
+                                                       Constraint>,
+                                   RISCVMaskedPseudo</*MaskOpIdx*/ 3>;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -275,15 +275,6 @@
                            int sew,
                            LMULInfo vlmul,
                            VReg op_reg_class> {
-  def : Pat<(result_type (vop
-                         (op_type op_reg_class:$rs1),
-                         (op_type op_reg_class:$rs2),
-                         (mask_type true_mask),
-                         VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_VV_"# vlmul.MX)
-                         op_reg_class:$rs1,
-                         op_reg_class:$rs2,
-                         GPR:$vl, sew)>;
   def : Pat<(result_type (vop
                          (op_type op_reg_class:$rs1),
                          (op_type op_reg_class:$rs2),
@@ -307,15 +298,6 @@
                           VReg vop_reg_class,
                           ComplexPattern SplatPatKind,
                           DAGOperand xop_kind> {
-  def : Pat<(result_type (vop
-                     (vop_type vop_reg_class:$rs1),
-                     (vop_type (SplatPatKind (XLenVT xop_kind:$rs2))),
-                     (mask_type true_mask),
-                     VLOpFrag)),
-        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX)
-                     vop_reg_class:$rs1,
-                     xop_kind:$rs2,
-                     GPR:$vl, sew)>;
   def : Pat<(result_type (vop
                      (vop_type vop_reg_class:$rs1),
                      (vop_type (SplatPatKind (XLenVT xop_kind:$rs2))),
@@ -373,14 +355,6 @@
                                LMULInfo vlmul,
                                VReg vop_reg_class,
                                RegisterClass scalar_reg_class> {
-  def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
-                         (vop_type (SplatFPOp scalar_reg_class:$rs2)),
-                         (mask_type true_mask),
-                         VLOpFrag)),
-            (!cast<Instruction>(instruction_name#"_"#vlmul.MX)
-                         vop_reg_class:$rs1,
-                         scalar_reg_class:$rs2,
-                         GPR:$vl, sew)>;
   def : Pat<(result_type (vop (vop_type vop_reg_class:$rs1),
                          (vop_type (SplatFPOp scalar_reg_class:$rs2)),
                          (mask_type V0),
@@ -405,13 +379,6 @@
 
 multiclass VPatBinaryFPVL_R_VF<SDNode vop, string instruction_name> {
   foreach fvti = AllFloatVectors in {
-    def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
-                                fvti.RegClass:$rs1,
-                                (fvti.Mask true_mask),
-                                VLOpFrag)),
-              (!cast<Instruction>(instruction_name#"_V"#fvti.ScalarSuffix#"_"#fvti.LMul.MX)
-                           fvti.RegClass:$rs1, fvti.ScalarRegClass:$rs2,
-                           GPR:$vl, fvti.Log2SEW)>;
     def : Pat<(fvti.Vector (vop (SplatFPOp fvti.ScalarRegClass:$rs2),
                                 fvti.RegClass:$rs1,
                                 (fvti.Mask V0),
@@ -698,22 +665,12 @@
 // Handle VRSUB specially since it's the only integer binary op with reversed
 // pattern operands
 foreach vti = AllIntegerVectors in {
-  def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
-                          (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
-                          VLOpFrag),
-            (!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX)
-                 vti.RegClass:$rs1, GPR:$rs2, GPR:$vl, vti.Log2SEW)>;
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat (XLenVT GPR:$rs2))),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
                           VLOpFrag),
             (!cast<Instruction>("PseudoVRSUB_VX_"# vti.LMul.MX#"_MASK")
                  (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, GPR:$rs2,
                  (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-  def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
-                          (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask),
-                          VLOpFrag),
-            (!cast<Instruction>("PseudoVRSUB_VI_"# vti.LMul.MX)
-                 vti.RegClass:$rs1, simm5:$rs2, GPR:$vl, vti.Log2SEW)>;
   def : Pat<(riscv_sub_vl (vti.Vector (SplatPat_simm5 simm5:$rs2)),
                           (vti.Vector vti.RegClass:$rs1), (vti.Mask V0),
                           VLOpFrag),