diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -85,29 +85,42 @@ /// Node predicates - /// If N is a BUILD_VECTOR node whose elements are all the same constant or - /// undefined, return true and return the constant value in \p SplatValue. - bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); - - /// Return true if the specified node is a BUILD_VECTOR where all of the - /// elements are ~0 or undef. - bool isBuildVectorAllOnes(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR where all of the - /// elements are 0 or undef. - bool isBuildVectorAllZeros(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR node of all - /// ConstantSDNode or undef. - bool isBuildVectorOfConstantSDNodes(const SDNode *N); - - /// Return true if the specified node is a BUILD_VECTOR node of all - /// ConstantFPSDNode or undef. - bool isBuildVectorOfConstantFPSDNodes(const SDNode *N); - - /// Return true if the node has at least one operand and all operands of the - /// specified node are ISD::UNDEF. - bool allOperandsUndef(const SDNode *N); +/// If N is a BUILD_VECTOR or SPLAT_VECTOR node whose elements are all the +/// same constant or undefined, return true and return the constant value in +/// \p SplatValue. +bool isConstantSplatVector(const SDNode *N, APInt &SplatValue); + +/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where +/// all of the elements are ~0 or undef. If \p BuildVectorOnly is set to +/// true, it only checks BUILD_VECTOR. Bitcasts are looked through. +bool isConstantSplatVectorAllOnes(const SDNode *N, + bool BuildVectorOnly = false); + +/// Return true if the specified node is a BUILD_VECTOR or SPLAT_VECTOR where +/// all of the elements are 0 or undef. If \p BuildVectorOnly is set to true, it +/// only checks BUILD_VECTOR. Bitcasts are looked through.
+bool isConstantSplatVectorAllZeros(const SDNode *N, + bool BuildVectorOnly = false); + +/// Return true if the specified node is a BUILD_VECTOR where all of the +/// elements are ~0 or undef. Bitcasts are looked through. +bool isBuildVectorAllOnes(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR where all of the +/// elements are 0 or undef. Bitcasts are looked through. +bool isBuildVectorAllZeros(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR node of all +/// ConstantSDNode or undef. +bool isBuildVectorOfConstantSDNodes(const SDNode *N); + +/// Return true if the specified node is a BUILD_VECTOR node of all +/// ConstantFPSDNode or undef. +bool isBuildVectorOfConstantFPSDNodes(const SDNode *N); + +/// Return true if the node has at least one operand and all operands of the +/// specified node are ISD::UNDEF. +bool allOperandsUndef(const SDNode *N); } // end namespace ISD diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -909,11 +909,13 @@ def vtInt : PatLeaf<(vt), [{ return N->getVT().isInteger(); }]>; def vtFP : PatLeaf<(vt), [{ return N->getVT().isFloatingPoint(); }]>; -// Use ISD::isBuildVectorAllOnes or ISD::isBuildVectorAllZeros to look for -// the corresponding build_vector. Will look through bitcasts except when used -// as a pattern root. -def immAllOnesV; // ISD::isBuildVectorAllOnes -def immAllZerosV; // ISD::isBuildVectorAllZeros +// Use ISD::isConstantSplatVectorAllOnes or ISD::isConstantSplatVectorAllZeros +// to look for the corresponding build_vector or splat_vector. Will look through +// bitcasts and check for either opcode, except when used as a pattern root. +// When used as a pattern root, only fixed-length build_vector and scalable +// splat_vector are supported.
+def immAllOnesV; // ISD::isConstantSplatVectorAllOnes +def immAllZerosV; // ISD::isConstantSplatVectorAllZeros // Other helper fragments. def not : PatFrag<(ops node:$in), (xor node:$in, -1)>; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -164,11 +164,16 @@ // FIXME: AllOnes and AllZeros duplicate a lot of code. Could these be // specializations of the more general isConstantSplatVector()? -bool ISD::isBuildVectorAllOnes(const SDNode *N) { +bool ISD::isConstantSplatVectorAllOnes(const SDNode *N, bool BuildVectorOnly) { // Look through a bit convert. while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt SplatVal; + return isConstantSplatVector(N, SplatVal) && SplatVal.isAllOnesValue(); + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; unsigned i = 0, e = N->getNumOperands(); @@ -208,11 +213,16 @@ return true; } -bool ISD::isBuildVectorAllZeros(const SDNode *N) { +bool ISD::isConstantSplatVectorAllZeros(const SDNode *N, bool BuildVectorOnly) { // Look through a bit convert.
while (N->getOpcode() == ISD::BITCAST) N = N->getOperand(0).getNode(); + if (!BuildVectorOnly && N->getOpcode() == ISD::SPLAT_VECTOR) { + APInt SplatVal; + return isConstantSplatVector(N, SplatVal) && SplatVal.isNullValue(); + } + if (N->getOpcode() != ISD::BUILD_VECTOR) return false; bool IsAllUndef = true; @@ -245,6 +255,14 @@ return true; } +bool ISD::isBuildVectorAllOnes(const SDNode *N) { + return isConstantSplatVectorAllOnes(N, /*BuildVectorOnly=*/true); +} + +bool ISD::isBuildVectorAllZeros(const SDNode *N) { + return isConstantSplatVectorAllZeros(N, /*BuildVectorOnly=*/true); +} + bool ISD::isBuildVectorOfConstantSDNodes(const SDNode *N) { if (N->getOpcode() != ISD::BUILD_VECTOR) return false; diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3202,10 +3202,12 @@ if (!::CheckOrImm(MatcherTable, MatcherIndex, N, *this)) break; continue; case OPC_CheckImmAllOnesV: - if (!ISD::isBuildVectorAllOnes(N.getNode())) break; + if (!ISD::isConstantSplatVectorAllOnes(N.getNode())) + break; continue; case OPC_CheckImmAllZerosV: - if (!ISD::isBuildVectorAllZeros(N.getNode())) break; + if (!ISD::isConstantSplatVectorAllZeros(N.getNode())) + break; continue; case OPC_CheckFoldableChainNode: { diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -35,11 +35,6 @@ def SplatPat_simm5 : ComplexPattern; def SplatPat_uimm5 : ComplexPattern; -// A mask-vector version of the standard 'vnot' fragment but using splat_vector -// rather than (the implicit) build_vector -def riscv_m_vnot : PatFrag<(ops node:$in), - (xor node:$in, (splat_vector (XLenVT 1)))>; - multiclass VPatUSLoadStoreSDNode("PseudoVMXOR_MM_"#mti.LMul.MX)
VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (and VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (and VR:$rs1, VR:$rs2))), (!cast("PseudoVMNAND_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (or VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (or VR:$rs1, VR:$rs2))), (!cast("PseudoVMNOR_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (riscv_m_vnot (xor VR:$rs1, VR:$rs2))), + def : Pat<(mti.Mask (vnot (xor VR:$rs1, VR:$rs2))), (!cast("PseudoVMXNOR_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (and VR:$rs1, (riscv_m_vnot VR:$rs2))), + def : Pat<(mti.Mask (and VR:$rs1, (vnot VR:$rs2))), (!cast("PseudoVMANDNOT_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; - def : Pat<(mti.Mask (or VR:$rs1, (riscv_m_vnot VR:$rs2))), + def : Pat<(mti.Mask (or VR:$rs1, (vnot VR:$rs2))), (!cast("PseudoVMORNOT_MM_"#mti.LMul.MX) VR:$rs1, VR:$rs2, VLMax, mti.SEW)>; } @@ -233,9 +228,9 @@ } foreach mti = AllMasks in { - def : Pat<(mti.Mask (splat_vector (XLenVT 1))), + def : Pat<(mti.Mask immAllOnesV), (!cast("PseudoVMSET_M_"#mti.BX) VLMax, mti.SEW)>; - def : Pat<(mti.Mask (splat_vector (XLenVT 0))), + def : Pat<(mti.Mask immAllZerosV), (!cast("PseudoVMCLR_M_"#mti.BX) VLMax, mti.SEW)>; } } // Predicates = [HasStdExtV] diff --git a/llvm/utils/TableGen/DAGISelMatcher.h b/llvm/utils/TableGen/DAGISelMatcher.h --- a/llvm/utils/TableGen/DAGISelMatcher.h +++ b/llvm/utils/TableGen/DAGISelMatcher.h @@ -763,8 +763,8 @@ } }; -/// CheckImmAllOnesVMatcher - This check if the current node is an build vector -/// of all ones. +/// CheckImmAllOnesVMatcher - This checks if the current node is a build_vector +/// or splat_vector of all ones.
class CheckImmAllOnesVMatcher : public Matcher { public: CheckImmAllOnesVMatcher() : Matcher(CheckImmAllOnesV) {} @@ -779,8 +779,8 @@ bool isContradictoryImpl(const Matcher *M) const override; }; -/// CheckImmAllZerosVMatcher - This check if the current node is an build vector -/// of all zeros. +/// CheckImmAllZerosVMatcher - This checks if the current node is a +/// build_vector or splat_vector of all zeros. class CheckImmAllZerosVMatcher : public Matcher { public: CheckImmAllZerosVMatcher() : Matcher(CheckImmAllZerosV) {} diff --git a/llvm/utils/TableGen/DAGISelMatcherGen.cpp b/llvm/utils/TableGen/DAGISelMatcherGen.cpp --- a/llvm/utils/TableGen/DAGISelMatcherGen.cpp +++ b/llvm/utils/TableGen/DAGISelMatcherGen.cpp @@ -282,7 +282,9 @@ // check to ensure that this gets folded into the normal top-level // OpcodeSwitch. if (N == Pattern.getSrcPattern()) { - const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector")); + MVT VT = N->getSimpleType(0); + StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector"; + const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name)); AddMatcher(new CheckOpcodeMatcher(NI)); } return AddMatcher(new CheckImmAllOnesVMatcher()); @@ -292,7 +294,9 @@ // check to ensure that this gets folded into the normal top-level // OpcodeSwitch. if (N == Pattern.getSrcPattern()) { - const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed("build_vector")); + MVT VT = N->getSimpleType(0); + StringRef Name = VT.isScalableVector() ? "splat_vector" : "build_vector"; + const SDNodeInfo &NI = CGP.getSDNodeInfo(CGP.getSDNodeNamed(Name)); AddMatcher(new CheckOpcodeMatcher(NI)); } return AddMatcher(new CheckImmAllZerosVMatcher());