diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.h b/llvm/utils/TableGen/CodeGenDAGPatterns.h --- a/llvm/utils/TableGen/CodeGenDAGPatterns.h +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.h @@ -1047,92 +1047,43 @@ TreePatternNodePtr getResultPattern() const { return ResultPattern; } }; -/// This class represents a condition that has to be satisfied for a pattern -/// to be tried. It is a generalization of a class "Pattern" from Target.td: -/// in addition to the Target.td's predicates, this class can also represent -/// conditions associated with HW modes. Both types will eventually become -/// strings containing C++ code to be executed, the difference is in how -/// these strings are generated. -class Predicate { -public: - Predicate(Record *R, bool C = true) : Def(R), IfCond(C), IsHwMode(false) { - assert(R->isSubClassOf("Predicate") && - "Predicate objects should only be created for records derived" - "from Predicate class"); - } - Predicate(StringRef FS, bool C = true) : Def(nullptr), Features(FS.str()), - IfCond(C), IsHwMode(true) {} - - /// Return a string which contains the C++ condition code that will serve - /// as a predicate during instruction selection. - std::string getCondString() const { - // The string will excute in a subclass of SelectionDAGISel. - // Cast to std::string explicitly to avoid ambiguity with StringRef. - std::string C = IsHwMode - ? std::string("MF->getSubtarget().checkFeatures(\"" + - Features + "\")") - : std::string(Def->getValueAsString("CondString")); - if (C.empty()) - return ""; - return IfCond ? 
C : "!("+C+')'; - } - - bool operator==(const Predicate &P) const { - return IfCond == P.IfCond && IsHwMode == P.IsHwMode && Def == P.Def && - Features == P.Features; - } - bool operator<(const Predicate &P) const { - if (IsHwMode != P.IsHwMode) - return IsHwMode < P.IsHwMode; - assert(!Def == !P.Def && "Inconsistency between Def and IsHwMode"); - if (IfCond != P.IfCond) - return IfCond < P.IfCond; - if (Def) - return LessRecord()(Def, P.Def); - return Features < P.Features; - } - Record *Def; ///< Predicate definition from .td file, null for - ///< HW modes. - std::string Features; ///< Feature string for HW mode. - bool IfCond; ///< The boolean value that the condition has to - ///< evaluate to for this predicate to be true. - bool IsHwMode; ///< Does this predicate correspond to a HW mode? -}; - /// PatternToMatch - Used by CodeGenDAGPatterns to keep tab of patterns /// processed to produce isel. class PatternToMatch { Record *SrcRecord; // Originating Record for the pattern. + ListInit *Predicates; // Top level predicate conditions to match. TreePatternNodePtr SrcPattern; // Source pattern to match. TreePatternNodePtr DstPattern; // Resulting pattern. - std::vector<Predicate> Predicates; // Top level predicate conditions - // to match. std::vector<Record*> Dstregs; // Physical register defs being matched. + std::string HwModeFeatures; int AddedComplexity; // Add to matching pattern complexity. unsigned ID; // Unique ID for the record. unsigned ForceMode; // Force this mode in type inference when set. 
public: - PatternToMatch(Record *srcrecord, std::vector<Predicate> preds, - TreePatternNodePtr src, TreePatternNodePtr dst, - std::vector<Record *> dstregs, int complexity, - unsigned uid, unsigned setmode = 0) - : SrcRecord(srcrecord), SrcPattern(src), DstPattern(dst), - Predicates(std::move(preds)), Dstregs(std::move(dstregs)), - AddedComplexity(complexity), ID(uid), ForceMode(setmode) {} + PatternToMatch(Record *srcrecord, ListInit *preds, TreePatternNodePtr src, + TreePatternNodePtr dst, std::vector<Record *> dstregs, + int complexity, unsigned uid, unsigned setmode = 0, + std::string hwmodefeatures = std::string()) + : SrcRecord(srcrecord), Predicates(preds), SrcPattern(src), + DstPattern(dst), Dstregs(std::move(dstregs)), + HwModeFeatures(std::move(hwmodefeatures)), AddedComplexity(complexity), + ID(uid), ForceMode(setmode) {} Record *getSrcRecord() const { return SrcRecord; } + ListInit *getPredicates() const { return Predicates; } TreePatternNode *getSrcPattern() const { return SrcPattern.get(); } TreePatternNodePtr getSrcPatternShared() const { return SrcPattern; } TreePatternNode *getDstPattern() const { return DstPattern.get(); } TreePatternNodePtr getDstPatternShared() const { return DstPattern; } const std::vector<Record*> &getDstRegs() const { return Dstregs; } + const std::string &getHwModeFeatures() const { return HwModeFeatures; } int getAddedComplexity() const { return AddedComplexity; } - const std::vector<Predicate> &getPredicates() const { return Predicates; } unsigned getID() const { return ID; } unsigned getForceMode() const { return ForceMode; } std::string getPredicateCheck() const; + SmallVector<Record *, 4> getPredicateRecords() const; /// Compute the complexity metric for the input pattern. This roughly /// corresponds to the number of nodes that are covered.
@@ -1290,8 +1241,6 @@ void GenerateVariants(); void VerifyInstructionFlags(); - std::vector<Predicate> makePredList(ListInit *L); - void ParseOnePattern(Record *TheDef, TreePattern &Pattern, TreePattern &Result, const std::vector<Record *> &InstImpResults); diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp --- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp @@ -1433,24 +1433,51 @@ return getPatternSize(getSrcPattern(), CGP) + getAddedComplexity(); } +SmallVector<Record *, 4> PatternToMatch::getPredicateRecords() const { + SmallVector<Record *, 4> PredicateRecs; + for (Init *I : Predicates->getValues()) { + if (DefInit *Pred = dyn_cast<DefInit>(I)) { + Record *Def = Pred->getDef(); + if (!Def->isSubClassOf("Predicate")) { +#ifndef NDEBUG + Def->dump(); +#endif + llvm_unreachable("Unknown predicate type!"); + } + PredicateRecs.push_back(Def); + } + } + // Sort so that different orders get canonicalized to the same string. + llvm::sort(PredicateRecs, LessRecord()); + + return PredicateRecs; +} + /// getPredicateCheck - Return a single string containing all of this /// pattern's predicates concatenated with "&&" operators.
/// std::string PatternToMatch::getPredicateCheck() const { - SmallVector<const Predicate*,4> PredList; - for (const Predicate &P : Predicates) { - if (!P.getCondString().empty()) - PredList.push_back(&P); + SmallString<128> PredicateCheck; + for (Record *Pred : getPredicateRecords()) { + StringRef CondString = Pred->getValueAsString("CondString"); + if (CondString.empty()) + continue; + if (!PredicateCheck.empty()) + PredicateCheck += " && "; + PredicateCheck += "("; + PredicateCheck += CondString; + PredicateCheck += ")"; } - llvm::sort(PredList, deref<std::less<>>()); - std::string Check; - for (unsigned i = 0, e = PredList.size(); i != e; ++i) { - if (i != 0) - Check += " && "; - Check += '(' + PredList[i]->getCondString() + ')'; + if (!HwModeFeatures.empty()) { + if (!PredicateCheck.empty()) + PredicateCheck += " && "; + PredicateCheck += "(MF->getSubtarget().checkFeatures(\""; + PredicateCheck += HwModeFeatures; + PredicateCheck += "\"))"; } - return Check; + + return std::string(PredicateCheck); } //===----------------------------------------------------------------------===// @@ -3923,20 +3950,6 @@ } } -std::vector<Predicate> CodeGenDAGPatterns::makePredList(ListInit *L) { - std::vector<Predicate> Preds; - for (Init *I : L->getValues()) { - if (DefInit *Pred = dyn_cast<DefInit>(I)) - Preds.push_back(Pred->getDef()); - else - llvm_unreachable("Non-def on the list"); - } - - // Sort so that different orders get canonicalized to the same string. - llvm::sort(Preds); - return Preds; -} - void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern, PatternToMatch &&PTM) { // Do some sanity checking on the pattern we're about to match.
@@ -4247,8 +4260,7 @@ for (const auto &T : Pattern.getTrees()) if (T->hasPossibleType()) AddPatternToMatch(&Pattern, - PatternToMatch(TheDef, makePredList(Preds), - T, Temp.getOnlyTree(), + PatternToMatch(TheDef, Preds, T, Temp.getOnlyTree(), InstImpResults, Complexity, TheDef->getID())); } @@ -4299,7 +4311,7 @@ void CodeGenDAGPatterns::ExpandHwModeBasedTypes() { const CodeGenHwModes &CGH = getTargetInfo().getHwModes(); - std::map<unsigned, std::vector<Predicate>> ModeChecks; + std::map<unsigned, std::string> ModeChecks; std::vector<PatternToMatch> Copy; PatternsToMatch.swap(Copy); @@ -4310,14 +4322,10 @@ return; } - std::vector<Predicate> Preds = P.getPredicates(); - const std::vector<Predicate> &MC = ModeChecks[Mode]; - llvm::append_range(Preds, MC); - PatternsToMatch.emplace_back(P.getSrcRecord(), std::move(Preds), + PatternsToMatch.emplace_back(P.getSrcRecord(), P.getPredicates(), std::move(NewSrc), std::move(NewDst), - P.getDstRegs(), - P.getAddedComplexity(), Record::getNewUID(), - Mode); + P.getDstRegs(), P.getAddedComplexity(), + Record::getNewUID(), Mode, ModeChecks[Mode]); }; for (PatternToMatch &P : Copy) { @@ -4337,19 +4345,6 @@ if (DstP) collectModes(Modes, DstP.get()); - // The predicate for the default mode needs to be constructed for each - // pattern separately. - // Since not all modes must be present in each pattern, if a mode m is - // absent, then there is no point in constructing a check for m. If such - // a check was created, it would be equivalent to checking the default - // mode, except not all modes' predicates would be a part of the checking - // code. The subsequently generated check for the default mode would then - // have the exact same patterns, but a different predicate code. To avoid - // duplicated patterns with different predicate checks, construct the - // default check as a negation of all predicates that are actually present - // in the source/destination patterns. - std::vector<Predicate> DefaultPred; - for (unsigned M : Modes) { if (M == DefaultMode) continue; @@ -4358,10 +4353,7 @@ // Fill the map entry for this mode.
const HwMode &HM = CGH.getMode(M); - ModeChecks[M].emplace_back(Predicate(HM.Features, true)); - - // Add negations of the HM's predicates to the default predicate. - DefaultPred.emplace_back(Predicate(HM.Features, false)); + ModeChecks[M] = HM.Features; } for (unsigned M : Modes) { @@ -4687,8 +4679,8 @@ if (MatchedPatterns[i]) continue; - const std::vector<Predicate> &Predicates = - PatternsToMatch[i].getPredicates(); + ListInit *Predicates = PatternsToMatch[i].getPredicates(); + const std::string &HwModeFeatures = PatternsToMatch[i].getHwModeFeatures(); BitVector &Matches = MatchedPredicates[i]; MatchedPatterns.set(i); @@ -4697,7 +4689,8 @@ // Don't test patterns that have already been cached - it won't match. for (unsigned p = 0; p != NumOriginalPatterns; ++p) if (!MatchedPatterns[p]) - Matches[p] = (Predicates == PatternsToMatch[p].getPredicates()); + Matches[p] = (Predicates == PatternsToMatch[p].getPredicates()) && + (HwModeFeatures == PatternsToMatch[p].getHwModeFeatures()); // Copy this to all the matching patterns. for (int p = Matches.find_first(); p != -1; p = Matches.find_next(p)) @@ -4741,7 +4734,9 @@ PatternsToMatch[i].getSrcRecord(), PatternsToMatch[i].getPredicates(), Variant, PatternsToMatch[i].getDstPatternShared(), PatternsToMatch[i].getDstRegs(), - PatternsToMatch[i].getAddedComplexity(), Record::getNewUID()); + PatternsToMatch[i].getAddedComplexity(), Record::getNewUID(), + PatternsToMatch[i].getForceMode(), + PatternsToMatch[i].getHwModeFeatures()); MatchedPredicates.push_back(Matches); // Add a new match the same as this pattern.
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -3536,7 +3536,7 @@ const CodeGenInstruction *getEquivNode(Record &Equiv, const TreePatternNode *N) const; - Error importRulePredicates(RuleMatcher &M, ArrayRef<Predicate> Predicates); + Error importRulePredicates(RuleMatcher &M, ArrayRef<Record *> Predicates); Expected<InstructionMatcher &> createAndImportSelDAGMatcher(RuleMatcher &Rule, InstructionMatcher &InsnMatcher, @@ -3723,14 +3723,13 @@ //===- Emitter ------------------------------------------------------------===// -Error -GlobalISelEmitter::importRulePredicates(RuleMatcher &M, - ArrayRef<Predicate> Predicates) { - for (const Predicate &P : Predicates) { - if (!P.Def || P.getCondString().empty()) +Error GlobalISelEmitter::importRulePredicates(RuleMatcher &M, + ArrayRef<Record *> Predicates) { + for (Record *Pred : Predicates) { + if (Pred->getValueAsString("CondString").empty()) continue; - declareSubtargetFeature(P.Def); - M.addRequiredFeature(P.Def); + declareSubtargetFeature(Pred); + M.addRequiredFeature(Pred); } return Error::success(); @@ -5043,7 +5042,7 @@ " => " + llvm::to_string(*P.getDstPattern())); - if (auto Error = importRulePredicates(M, P.getPredicates())) + if (auto Error = importRulePredicates(M, P.getPredicateRecords())) return std::move(Error); // Next, analyze the pattern operators.