Review notes for the CodeGenDAGPatterns.h / CodeGenDAGPatterns.cpp changes below (commentary only; not part of the patch).
Summary:
* Removes the Predicate class, which wrapped either a .td record derived from "Predicate" or a HW-mode feature string, and removes CodeGenDAGPatterns::makePredList.
* PatternToMatch now stores the pattern's top-level predicates as the ListInit * passed in by the caller (ParseOnePattern passes Preds straight through), plus a separate HwModeFeatures string that holds an already formatted C++ condition.
* New PatternToMatch::getPredicateRecords() collects the DefInit entries of that list, calls llvm_unreachable for any def that is not a subclass of "Predicate", and sorts the result with LessRecord.
* getPredicateCheck() joins each non-empty CondString as "(cond)" with " && ", then appends HwModeFeatures if it is set. HW-mode checks therefore still come after record predicates, as the old Predicate::operator< ordering placed them.
* HW-mode expansion now builds its checks as strings directly: "(MF->getSubtarget().checkFeatures(\"<features>\"))" for each mode. For the default mode it builds a " && "-joined list of the negated form.
* Variant generation groups patterns by both the predicate list and HwModeFeatures, and copied variants now carry ForceMode and HwModeFeatures.
NOTE(review): points to confirm:
* getPredicateRecords() silently skips list entries that are not DefInit. The removed makePredList() called llvm_unreachable("Non-def on the list") for them instead.
* Variant grouping compares ListInit pointers, whereas the old code compared sorted predicate vectors. Two patterns that list the same predicates in a different order may therefore no longer be grouped. This presumably relies on ListInits being uniqued, so that identical lists share one pointer; verify.
* Default-mode negations are emitted in the loop order over Modes rather than sorted by feature string as before. The generated text may be ordered differently; the && conjunction itself is unchanged.
* Variants previously got ForceMode 0 through the constructor default (setmode = 0). Passing getForceMode() is a behavior change that should be called out.
diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.h b/llvm/utils/TableGen/CodeGenDAGPatterns.h --- a/llvm/utils/TableGen/CodeGenDAGPatterns.h +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.h @@ -1047,92 +1047,43 @@ TreePatternNodePtr getResultPattern() const { return ResultPattern; } }; -/// This class represents a condition that has to be satisfied for a pattern -/// to be tried. It is a generalization of a class "Pattern" from Target.td: -/// in addition to the Target.td's predicates, this class can also represent -/// conditions associated with HW modes. Both types will eventually become -/// strings containing C++ code to be executed, the difference is in how -/// these strings are generated. -class Predicate { -public: - Predicate(Record *R, bool C = true) : Def(R), IfCond(C), IsHwMode(false) { - assert(R->isSubClassOf("Predicate") && - "Predicate objects should only be created for records derived" - "from Predicate class"); - } - Predicate(StringRef FS, bool C = true) : Def(nullptr), Features(FS.str()), - IfCond(C), IsHwMode(true) {} - - /// Return a string which contains the C++ condition code that will serve - /// as a predicate during instruction selection. - std::string getCondString() const { - // The string will excute in a subclass of SelectionDAGISel. - // Cast to std::string explicitly to avoid ambiguity with StringRef. - std::string C = IsHwMode - ? std::string("MF->getSubtarget().checkFeatures(\"" + - Features + "\")") - : std::string(Def->getValueAsString("CondString")); - if (C.empty()) - return ""; - return IfCond ? 
C : "!("+C+')'; - } - - bool operator==(const Predicate &P) const { - return IfCond == P.IfCond && IsHwMode == P.IsHwMode && Def == P.Def && - Features == P.Features; - } - bool operator<(const Predicate &P) const { - if (IsHwMode != P.IsHwMode) - return IsHwMode < P.IsHwMode; - assert(!Def == !P.Def && "Inconsistency between Def and IsHwMode"); - if (IfCond != P.IfCond) - return IfCond < P.IfCond; - if (Def) - return LessRecord()(Def, P.Def); - return Features < P.Features; - } - Record *Def; ///< Predicate definition from .td file, null for - ///< HW modes. - std::string Features; ///< Feature string for HW mode. - bool IfCond; ///< The boolean value that the condition has to - ///< evaluate to for this predicate to be true. - bool IsHwMode; ///< Does this predicate correspond to a HW mode? -}; - /// PatternToMatch - Used by CodeGenDAGPatterns to keep tab of patterns /// processed to produce isel. class PatternToMatch { Record *SrcRecord; // Originating Record for the pattern. + ListInit *Predicates; // Top level predicate conditions to match. TreePatternNodePtr SrcPattern; // Source pattern to match. TreePatternNodePtr DstPattern; // Resulting pattern. - std::vector Predicates; // Top level predicate conditions - // to match. std::vector Dstregs; // Physical register defs being matched. + std::string HwModeFeatures; int AddedComplexity; // Add to matching pattern complexity. unsigned ID; // Unique ID for the record. unsigned ForceMode; // Force this mode in type inference when set. 
public: - PatternToMatch(Record *srcrecord, std::vector preds, - TreePatternNodePtr src, TreePatternNodePtr dst, - std::vector dstregs, int complexity, - unsigned uid, unsigned setmode = 0) - : SrcRecord(srcrecord), SrcPattern(src), DstPattern(dst), - Predicates(std::move(preds)), Dstregs(std::move(dstregs)), - AddedComplexity(complexity), ID(uid), ForceMode(setmode) {} + PatternToMatch(Record *srcrecord, ListInit *preds, TreePatternNodePtr src, + TreePatternNodePtr dst, std::vector dstregs, + int complexity, unsigned uid, unsigned setmode = 0, + const Twine &hwmodefeatures = "") + : SrcRecord(srcrecord), Predicates(preds), SrcPattern(src), + DstPattern(dst), Dstregs(std::move(dstregs)), + HwModeFeatures(hwmodefeatures.str()), AddedComplexity(complexity), + ID(uid), ForceMode(setmode) {} Record *getSrcRecord() const { return SrcRecord; } + ListInit *getPredicates() const { return Predicates; } TreePatternNode *getSrcPattern() const { return SrcPattern.get(); } TreePatternNodePtr getSrcPatternShared() const { return SrcPattern; } TreePatternNode *getDstPattern() const { return DstPattern.get(); } TreePatternNodePtr getDstPatternShared() const { return DstPattern; } const std::vector &getDstRegs() const { return Dstregs; } + StringRef getHwModeFeatures() const { return HwModeFeatures; } int getAddedComplexity() const { return AddedComplexity; } - const std::vector &getPredicates() const { return Predicates; } unsigned getID() const { return ID; } unsigned getForceMode() const { return ForceMode; } std::string getPredicateCheck() const; + void getPredicateRecords(SmallVectorImpl &PredicateRecs) const; /// Compute the complexity metric for the input pattern. This roughly /// corresponds to the number of nodes that are covered. 
@@ -1290,8 +1241,6 @@ void GenerateVariants(); void VerifyInstructionFlags(); - std::vector makePredList(ListInit *L); - void ParseOnePattern(Record *TheDef, TreePattern &Pattern, TreePattern &Result, const std::vector &InstImpResults); diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp --- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp +++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp @@ -1440,24 +1440,50 @@ return getPatternSize(getSrcPattern(), CGP) + getAddedComplexity(); } +void PatternToMatch::getPredicateRecords( + SmallVectorImpl &PredicateRecs) const { + for (Init *I : Predicates->getValues()) { + if (DefInit *Pred = dyn_cast(I)) { + Record *Def = Pred->getDef(); + if (!Def->isSubClassOf("Predicate")) { +#ifndef NDEBUG + Def->dump(); +#endif + llvm_unreachable("Unknown predicate type!"); + } + PredicateRecs.push_back(Def); + } + } + // Sort so that different orders get canonicalized to the same string. + llvm::sort(PredicateRecs, LessRecord()); +} + /// getPredicateCheck - Return a single string containing all of this /// pattern's predicates concatenated with "&&" operators. 
/// std::string PatternToMatch::getPredicateCheck() const { - SmallVector PredList; - for (const Predicate &P : Predicates) { - if (!P.getCondString().empty()) - PredList.push_back(&P); + SmallVector PredicateRecs; + getPredicateRecords(PredicateRecs); + + SmallString<128> PredicateCheck; + for (Record *Pred : PredicateRecs) { + StringRef CondString = Pred->getValueAsString("CondString"); + if (CondString.empty()) + continue; + if (!PredicateCheck.empty()) + PredicateCheck += " && "; + PredicateCheck += "("; + PredicateCheck += CondString; + PredicateCheck += ")"; } - llvm::sort(PredList, deref>()); - std::string Check; - for (unsigned i = 0, e = PredList.size(); i != e; ++i) { - if (i != 0) - Check += " && "; - Check += '(' + PredList[i]->getCondString() + ')'; + if (!HwModeFeatures.empty()) { + if (!PredicateCheck.empty()) + PredicateCheck += " && "; + PredicateCheck += HwModeFeatures; } - return Check; + + return std::string(PredicateCheck); } //===----------------------------------------------------------------------===// @@ -3930,20 +3956,6 @@ } } -std::vector CodeGenDAGPatterns::makePredList(ListInit *L) { - std::vector Preds; - for (Init *I : L->getValues()) { - if (DefInit *Pred = dyn_cast(I)) - Preds.push_back(Pred->getDef()); - else - llvm_unreachable("Non-def on the list"); - } - - // Sort so that different orders get canonicalized to the same string. - llvm::sort(Preds); - return Preds; -} - void CodeGenDAGPatterns::AddPatternToMatch(TreePattern *Pattern, PatternToMatch &&PTM) { // Do some sanity checking on the pattern we're about to match. 
@@ -4254,8 +4266,7 @@ for (const auto &T : Pattern.getTrees()) if (T->hasPossibleType()) AddPatternToMatch(&Pattern, - PatternToMatch(TheDef, makePredList(Preds), - T, Temp.getOnlyTree(), + PatternToMatch(TheDef, Preds, T, Temp.getOnlyTree(), InstImpResults, Complexity, TheDef->getID())); } @@ -4310,20 +4321,17 @@ PatternsToMatch.swap(Copy); auto AppendPattern = [this](PatternToMatch &P, unsigned Mode, - ArrayRef Check) { + StringRef Check) { TreePatternNodePtr NewSrc = P.getSrcPattern()->clone(); TreePatternNodePtr NewDst = P.getDstPattern()->clone(); if (!NewSrc->setDefaultMode(Mode) || !NewDst->setDefaultMode(Mode)) { return; } - std::vector Preds = P.getPredicates(); - llvm::append_range(Preds, Check); - PatternsToMatch.emplace_back(P.getSrcRecord(), std::move(Preds), + PatternsToMatch.emplace_back(P.getSrcRecord(), P.getPredicates(), std::move(NewSrc), std::move(NewDst), - P.getDstRegs(), - P.getAddedComplexity(), Record::getNewUID(), - Mode); + P.getDstRegs(), P.getAddedComplexity(), + Record::getNewUID(), Mode, Check); }; for (PatternToMatch &P : Copy) { @@ -4354,7 +4362,7 @@ // duplicated patterns with different predicate checks, construct the // default check as a negation of all predicates that are actually present // in the source/destination patterns. - SmallVector DefaultCheck; + SmallString<128> DefaultCheck; for (unsigned M : Modes) { if (M == DefaultMode) @@ -4362,10 +4370,14 @@ // Fill the map entry for this mode. const HwMode &HM = CGH.getMode(M); - AppendPattern(P, M, Predicate(HM.Features, true)); + AppendPattern(P, M, "(MF->getSubtarget().checkFeatures(\"" + HM.Features + "\"))"); // Add negations of the HM's predicates to the default predicate. 
- DefaultCheck.push_back(Predicate(HM.Features, false)); + if (!DefaultCheck.empty()) + DefaultCheck += " && "; + DefaultCheck += "(!(MF->getSubtarget().checkFeatures(\""; + DefaultCheck += HM.Features; + DefaultCheck += "\")))"; } bool HasDefault = Modes.count(DefaultMode); @@ -4685,8 +4697,8 @@ if (MatchedPatterns[i]) continue; - const std::vector &Predicates = - PatternsToMatch[i].getPredicates(); + ListInit *Predicates = PatternsToMatch[i].getPredicates(); + StringRef HwModeFeatures = PatternsToMatch[i].getHwModeFeatures(); BitVector &Matches = MatchedPredicates[i]; MatchedPatterns.set(i); @@ -4695,7 +4707,8 @@ // Don't test patterns that have already been cached - it won't match. for (unsigned p = 0; p != NumOriginalPatterns; ++p) if (!MatchedPatterns[p]) - Matches[p] = (Predicates == PatternsToMatch[p].getPredicates()); + Matches[p] = (Predicates == PatternsToMatch[p].getPredicates()) && + (HwModeFeatures == PatternsToMatch[p].getHwModeFeatures()); // Copy this to all the matching patterns. for (int p = Matches.find_first(); p != -1; p = Matches.find_next(p)) @@ -4739,7 +4752,9 @@ PatternsToMatch[i].getSrcRecord(), PatternsToMatch[i].getPredicates(), Variant, PatternsToMatch[i].getDstPatternShared(), PatternsToMatch[i].getDstRegs(), - PatternsToMatch[i].getAddedComplexity(), Record::getNewUID()); + PatternsToMatch[i].getAddedComplexity(), Record::getNewUID(), + PatternsToMatch[i].getForceMode(), + PatternsToMatch[i].getHwModeFeatures().str()); MatchedPredicates.push_back(Matches); // Add a new match the same as this pattern. 
Review notes for the GlobalISelEmitter.cpp changes below (commentary only; not part of the patch).
* importRulePredicates() now takes Record pointers, which the caller gathers with PatternToMatch::getPredicateRecords(). Records whose "CondString" is empty are still skipped.
* HW-mode conditions still do not reach GlobalISel. The old loop skipped Predicate entries without a Def (the HW-mode ones), and the hunks below never read PatternToMatch::getHwModeFeatures().
diff --git a/llvm/utils/TableGen/GlobalISelEmitter.cpp b/llvm/utils/TableGen/GlobalISelEmitter.cpp --- a/llvm/utils/TableGen/GlobalISelEmitter.cpp +++ b/llvm/utils/TableGen/GlobalISelEmitter.cpp @@ -3536,7 +3536,7 @@ const CodeGenInstruction *getEquivNode(Record &Equiv, const TreePatternNode *N) const; - Error importRulePredicates(RuleMatcher &M, ArrayRef Predicates); + Error importRulePredicates(RuleMatcher &M, ArrayRef Predicates); Expected createAndImportSelDAGMatcher(RuleMatcher &Rule, InstructionMatcher &InsnMatcher, @@ -3723,14 +3723,13 @@ //===- Emitter ------------------------------------------------------------===// -Error -GlobalISelEmitter::importRulePredicates(RuleMatcher &M, - ArrayRef Predicates) { - for (const Predicate &P : Predicates) { - if (!P.Def || P.getCondString().empty()) +Error GlobalISelEmitter::importRulePredicates(RuleMatcher &M, + ArrayRef Predicates) { + for (Record *Pred : Predicates) { + if (Pred->getValueAsString("CondString").empty()) continue; - declareSubtargetFeature(P.Def); - M.addRequiredFeature(P.Def); + declareSubtargetFeature(Pred); + M.addRequiredFeature(Pred); } return Error::success(); @@ -5042,7 +5041,9 @@ " => " + llvm::to_string(*P.getDstPattern())); - if (auto Error = importRulePredicates(M, P.getPredicates())) + SmallVector Predicates; + P.getPredicateRecords(Predicates); + if (auto Error = importRulePredicates(M, Predicates)) return std::move(Error); // Next, analyze the pattern operators.