diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h --- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h @@ -85,8 +85,9 @@ /// legal and the surrounding code makes it useful. bool tryCombineIndexedLoadStore(MachineInstr &MI); - bool matchCombineBr(MachineInstr &MI); - bool tryCombineBr(MachineInstr &MI); + bool matchElideBrByInvertingCond(MachineInstr &MI); + void applyElideBrByInvertingCond(MachineInstr &MI); + bool tryElideBrByInvertingCond(MachineInstr &MI); /// Optimize memcpy intrinsics et al, e.g. constant len calls. /// /p MaxLen if non-zero specifies the max length of a mem libcall to inline. diff --git a/llvm/include/llvm/TableGen/Error.h b/llvm/include/llvm/TableGen/Error.h --- a/llvm/include/llvm/TableGen/Error.h +++ b/llvm/include/llvm/TableGen/Error.h @@ -18,6 +18,7 @@ namespace llvm { +void PrintNote(const Twine &Msg); void PrintNote(ArrayRef NoteLoc, const Twine &Msg); void PrintWarning(ArrayRef WarningLoc, const Twine &Msg); diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td --- a/llvm/include/llvm/Target/GlobalISel/Combine.td +++ b/llvm/include/llvm/Target/GlobalISel/Combine.td @@ -10,8 +10,91 @@ // //===----------------------------------------------------------------------===// +// Common base class for GICombineRule and GICombineGroup. +class GICombine { + // See GICombineGroup. We only declare it here to make the tablegen pass + // simpler. + list Rules = ?; +} + +// A group of combine rules that can be added to a GICombiner or another group. +class GICombineGroup rules> : GICombine { + // The rules contained in this group. The rules in a group are flattened into + // a single list and sorted into whatever order is most efficient. However, + // they will never be re-ordered such that behaviour differs from the + // specified order. 
It is therefore possible to use the order of rules in this + // list to describe priorities. + let Rules = rules; +} + // Declares a combiner helper class -class GICombinerHelper { +class GICombinerHelper rules> + : GICombineGroup { // The class name to use in the generated output. string Classname = classname; } +class GICombineRule : GICombine { + /// Defines the external interface of the match rule. This includes: + /// * The names of the root nodes (requires at least one) + /// See GIDefKind for details. + dag Defs = defs; + + /// Defines the things which must be true for the pattern to match + /// See GIMatchKind for details. + dag Match = match; + + /// Defines the things which happen after the decision is made to apply a + /// combine rule. + /// See GIApplyKind for details. + dag Apply = apply; +} + +/// The operator at the root of a GICombineRule.Defs dag. +def defs; + +/// All arguments of the defs operator must be subclasses of GIDefKind or +/// sub-dags whose operator is GIDefKindWithArgs. +class GIDefKind; +class GIDefKindWithArgs; +/// Declare a root node. There must be at least one of these in every combine +/// rule. +/// TODO: The plan is to elide `root` definitions and determine it from the DAG +/// itself with an override for situations where the usual determination +/// is incorrect. +def root : GIDefKind; + +/// The operator at the root of a GICombineRule.Match dag. +def match; +/// All arguments of the match operator must be either: +/// * A subclass of GIMatchKind +/// * A subclass of GIMatchKindWithArgs +/// * A MIR code block (deprecated) +/// The GIMatchKind and GIMatchKindWithArgs cases are described in more detail +/// in their definitions below. +/// For the Instruction case, these are collected into a DAG where operand names +/// that occur multiple times introduce edges. +class GIMatchKind; +class GIMatchKindWithArgs; + +/// The operator at the root of a GICombineRule.Apply dag. 
+def apply; +/// All arguments of the apply operator must be subclasses of GIApplyKind, or +/// sub-dags whose operator is GIApplyKindWithArgs, or an MIR block +/// (deprecated). +class GIApplyKind; +class GIApplyKindWithArgs; + +def copy_prop : GICombineRule< + (defs root:$d), + (match [{ return Helper.matchCombineCopy(${d}); }]), + (apply [{ Helper.applyCombineCopy(${d}); }])>; +def trivial_combines : GICombineGroup<[copy_prop]>; + +// FIXME: Is there a reason this wasn't in tryCombine? I've left it out of +// all_combines because it wasn't there. +def elide_br_by_inverting_cond : GICombineRule< + (defs root:$d), + (match [{ return Helper.matchElideBrByInvertingCond(${d}); }]), + (apply [{ Helper.applyElideBrByInvertingCond(${d}); }])>; + +def all_combines : GICombineGroup<[trivial_combines]>; diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -561,8 +561,10 @@ return true; } -bool CombinerHelper::matchCombineBr(MachineInstr &MI) { - assert(MI.getOpcode() == TargetOpcode::G_BR && "Expected a G_BR"); +bool CombinerHelper::matchElideBrByInvertingCond(MachineInstr &MI) { + if (MI.getOpcode() != TargetOpcode::G_BR) + return false; + // Try to match the following: // bb1: // %c(s32) = G_ICMP pred, %a, %b @@ -599,9 +601,14 @@ return true; } -bool CombinerHelper::tryCombineBr(MachineInstr &MI) { - if (!matchCombineBr(MI)) +bool CombinerHelper::tryElideBrByInvertingCond(MachineInstr &MI) { + if (!matchElideBrByInvertingCond(MI)) return false; + applyElideBrByInvertingCond(MI); + return true; +} + +void CombinerHelper::applyElideBrByInvertingCond(MachineInstr &MI) { MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB(); MachineBasicBlock::iterator BrIt(MI); MachineInstr *BrCond = &*std::prev(BrIt); @@ -620,7 +627,6 @@ BrCond->getOperand(1).setMBB(BrTarget); Observer.changedInstr(*BrCond); 
MI.eraseFromParent(); - return true; } static bool shouldLowerMemFuncForSize(const MachineFunction &MF) { diff --git a/llvm/lib/TableGen/Error.cpp b/llvm/lib/TableGen/Error.cpp --- a/llvm/lib/TableGen/Error.cpp +++ b/llvm/lib/TableGen/Error.cpp @@ -39,6 +39,8 @@ "instantiated from multiclass"); } +void PrintNote(const Twine &Msg) { WithColor::note() << Msg << "\n"; } + void PrintNote(ArrayRef NoteLoc, const Twine &Msg) { PrintMessage(NoteLoc, SourceMgr::DK_Note, Msg); } diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -12,4 +12,5 @@ include "llvm/Target/GlobalISel/Combine.td" def AArch64PreLegalizerCombinerHelper: GICombinerHelper< - "AArch64GenPreLegalizerCombinerHelper">; + "AArch64GenPreLegalizerCombinerHelper", [all_combines, + elide_br_by_inverting_cond]>; diff --git a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp --- a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp +++ b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp @@ -58,12 +58,6 @@ CombinerHelper Helper(Observer, B, KB, MDT); switch (MI.getOpcode()) { - default: - return false; - case TargetOpcode::COPY: - return Helper.tryCombineCopy(MI); - case TargetOpcode::G_BR: - return Helper.tryCombineBr(MI); case TargetOpcode::G_LOAD: case TargetOpcode::G_SEXTLOAD: case TargetOpcode::G_ZEXTLOAD: { @@ -118,7 +112,7 @@ private: bool IsOptNone; }; -} +} // end anonymous namespace void AArch64PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired(); diff --git a/llvm/utils/TableGen/GICombinerEmitter.cpp b/llvm/utils/TableGen/GICombinerEmitter.cpp --- a/llvm/utils/TableGen/GICombinerEmitter.cpp +++ b/llvm/utils/TableGen/GICombinerEmitter.cpp @@ -11,16 +11,23 @@ /// //===----------------------------------------------------------------------===// +#include 
"llvm/ADT/Statistic.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Timer.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" #include "CodeGenTarget.h" +#include "GlobalISel/CodeExpander.h" +#include "GlobalISel/CodeExpansions.h" using namespace llvm; #define DEBUG_TYPE "gicombiner-emitter" +// FIXME: Use ALWAYS_ENABLED_STATISTIC once it's available. +unsigned NumPatternTotal = 0; +STATISTIC(NumPatternTotalStatistic, "Total number of patterns"); + cl::OptionCategory GICombinerEmitterCat("Options for -gen-global-isel-combiner"); static cl::list @@ -32,12 +39,162 @@ cl::cat(GICombinerEmitterCat)); namespace { +typedef uint64_t RuleID; + +class RootInfo { + StringRef PatternSymbol; + +public: + RootInfo(StringRef PatternSymbol) : PatternSymbol(PatternSymbol) {} + + StringRef getPatternSymbol() const { return PatternSymbol; } +}; + +class CombineRule { +protected: + /// A unique ID for this rule + /// ID's are used for debugging and run-time disabling of rules among other + /// things. + RuleID ID; + + /// The record defining this rule. + const Record &TheDef; + + /// The roots of a match. These are the leaves of the DAG that are closest to + /// the end of the function. I.e. the nodes that are encountered without + /// following any edges of the DAG described by the pattern as we work our way + /// from the bottom of the function to the top. + std::vector Roots; + + /// A block of arbitrary C++ to finish testing the match. 
+ /// FIXME: This is a temporary measure until we have actual pattern matching + const CodeInit *MatchingFixupCode = nullptr; +public: + CombineRule(const CodeGenTarget &Target, RuleID ID, const Record &R) + : ID(ID), TheDef(R) {} + bool parseDefs(); + bool parseMatcher(const CodeGenTarget &Target); + + const Record &getDef() const { return TheDef; } + const CodeInit *getMatchingFixupCode() const { return MatchingFixupCode; } + size_t getNumRoots() const { return Roots.size(); } + + using const_root_iterator = std::vector::const_iterator; + const_root_iterator roots_begin() const { return Roots.begin(); } + const_root_iterator roots_end() const { return Roots.end(); } + iterator_range roots() const { + return llvm::make_range(Roots.begin(), Roots.end()); + } +}; + +/// A convenience function to check that an Init refers to a specific def. This +/// is primarily useful for testing for defs and similar in DagInit's since +/// DagInit's support any type inside them. +static bool isSpecificDef(const Init &N, StringRef Def) { + if (const DefInit *OpI = dyn_cast(&N)) + if (OpI->getDef()->getName() == Def) + return true; + return false; +} + +/// A convenience function to check that an Init refers to a def that is a +/// subclass of the given class and coerce it to a def if it is. This is +/// primarily useful for testing for subclasses of GIMatchKind and similar in +/// DagInit's since DagInit's support any type inside them. 
+static Record *getDefOfSubClass(const Init &N, StringRef Cls) { + if (const DefInit *OpI = dyn_cast(&N)) + if (OpI->getDef()->isSubClassOf(Cls)) + return OpI->getDef(); + return nullptr; +} + +bool CombineRule::parseDefs() { + NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing", + "Time spent on rule parsing", TimeRegions); + DagInit *Defs = TheDef.getValueAsDag("Defs"); + + if (Defs->getOperatorAsDef(TheDef.getLoc())->getName() != "defs") { + PrintError(TheDef.getLoc(), "Expected defs operator"); + return false; + } + + for (unsigned I = 0, E = Defs->getNumArgs(); I < E; ++I) { + // Roots should be collected into Roots + if (isSpecificDef(*Defs->getArg(I), "root")) { + Roots.emplace_back(Defs->getArgNameStr(I)); + continue; + } + + // Otherwise emit an appropriate error message. + if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind")) + PrintError(TheDef.getLoc(), + "This GIDefKind not implemented in tablegen"); + else if (getDefOfSubClass(*Defs->getArg(I), "GIDefKindWithArgs")) + PrintError(TheDef.getLoc(), + "This GIDefKindWithArgs not implemented in tablegen"); + else + PrintError(TheDef.getLoc(), + "Expected a subclass of GIDefKind or a sub-dag whose " + "operator is of type GIDefKindWithArgs"); + return false; + } + + if (Roots.empty()) { + PrintError(TheDef.getLoc(), "Combine rules must have at least one root"); + return false; + } + return true; +} + +bool CombineRule::parseMatcher(const CodeGenTarget &Target) { + NamedRegionTimer T("parseMatcher", "Time spent parsing the matcher", + "Rule Parsing", "Time spent on rule parsing", TimeRegions); + DagInit *Matchers = TheDef.getValueAsDag("Match"); + + if (Matchers->getOperatorAsDef(TheDef.getLoc())->getName() != "match") { + PrintError(TheDef.getLoc(), "Expected match operator"); + return false; + } + + if (Matchers->getNumArgs() == 0) { + PrintError(TheDef.getLoc(), "Matcher is empty"); + return false; + } + + // The match section consists of a list of matchers and predicates. 
Parse each + // one and add the equivalent GIMatchDag nodes, predicates, and edges. + for (unsigned I = 0; I < Matchers->getNumArgs(); ++I) { + + // Parse arbitrary C++ code we have in lieu of supporting MIR matching + if (const CodeInit *CodeI = dyn_cast(Matchers->getArg(I))) { + assert(!MatchingFixupCode && + "Only one block of arbitrary code is currently permitted"); + MatchingFixupCode = CodeI; + continue; + } + + PrintError(TheDef.getLoc(), + "Expected a subclass of GIMatchKind or a sub-dag whose " + "operator is either of a GIMatchKindWithArgs or Instruction"); + PrintNote("Pattern was `" + Matchers->getArg(I)->getAsString() + "'"); + return false; + } + return true; +} + class GICombinerEmitter { StringRef Name; + const CodeGenTarget &Target; Record *Combiner; + std::vector> Rules; + std::unique_ptr makeCombineRule(const Record &R); + + void gatherRules(std::vector> &ActiveRules, + const std::vector &&RulesAndGroups); + public: - explicit GICombinerEmitter(RecordKeeper &RK, StringRef Name, - Record *Combiner); + explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target, + StringRef Name, Record *Combiner); ~GICombinerEmitter() {} StringRef getClassName() const { @@ -45,13 +202,109 @@ } void run(raw_ostream &OS); + void generateCodeForRule(raw_ostream &OS, const CombineRule *Rule, + StringRef Indent) const; }; -GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK, StringRef Name, - Record *Combiner) - : Name(Name), Combiner(Combiner) {} +GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK, + const CodeGenTarget &Target, + StringRef Name, Record *Combiner) + : Name(Name), Target(Target), Combiner(Combiner) {} + +std::unique_ptr +GICombinerEmitter::makeCombineRule(const Record &TheDef) { + std::unique_ptr Rule = + std::make_unique(Target, NumPatternTotal, TheDef); + + if (!Rule->parseDefs()) + return nullptr; + if (!Rule->parseMatcher(Target)) + return nullptr; + // For now, don't support multi-root rules. 
We'll come back to this later + // once we have the algorithm changes to support it. + if (Rule->getNumRoots() > 1) { + PrintError(TheDef.getLoc(), "Multi-root matches are not supported (yet)"); + return nullptr; + } + return Rule; +} + +/// Recurse into GICombineGroup's and flatten the ruleset into a simple list. +void GICombinerEmitter::gatherRules( + std::vector> &ActiveRules, + const std::vector &&RulesAndGroups) { + for (Record *R : RulesAndGroups) { + if (R->isValueUnset("Rules")) { + std::unique_ptr Rule = makeCombineRule(*R); + if (Rule == nullptr) { + PrintError(R->getLoc(), "Failed to parse rule"); + continue; + } + ActiveRules.emplace_back(std::move(Rule)); + ++NumPatternTotal; + } else + gatherRules(ActiveRules, R->getValueAsListOfDefs("Rules")); + } +} + +void GICombinerEmitter::generateCodeForRule(raw_ostream &OS, + const CombineRule *Rule, + StringRef Indent) const { + { + const Record &RuleDef = Rule->getDef(); + + OS << Indent << "// Rule: " << RuleDef.getName() << "\n" + << Indent << "{\n"; + + CodeExpansions Expansions; + for (const RootInfo &Root : Rule->roots()) { + Expansions.declare(Root.getPatternSymbol(), "MI"); + } + DagInit *Applyer = RuleDef.getValueAsDag("Apply"); + if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() != + "apply") { + PrintError(RuleDef.getLoc(), "Expected apply operator"); + return; + } + + OS << Indent << " if (1\n"; + + if (Rule->getMatchingFixupCode() && + !Rule->getMatchingFixupCode()->getValue().empty()) { + // FIXME: Single-use lambdas like this are a serious compile-time + // performance and memory issue. It's convenient for this early stage to + // defer some work to successive patches but we need to eliminate this + // before the ruleset grows to small-moderate size. Last time, it became + // a big problem for low-mem systems around the 500 rule mark but by the + // time we grow that large we should have merged the ISel match table + // mechanism with the Combiner. 
+ OS << Indent << " && [&]() {\n" + << Indent << " " + << CodeExpander(Rule->getMatchingFixupCode()->getValue(), Expansions, + Rule->getMatchingFixupCode()->getLoc(), ShowExpansions) + << "\n" + << Indent << " return true;\n" + << Indent << " }()"; + } + OS << ") {\n" << Indent << " "; + + if (const CodeInit *Code = dyn_cast(Applyer->getArg(0))) { + OS << CodeExpander(Code->getAsUnquotedString(), Expansions, + Code->getLoc(), ShowExpansions) + << "\n" + << Indent << " return true;\n" + << Indent << " }\n"; + } else { + PrintError(RuleDef.getLoc(), "Expected apply code block"); + return; + } + + OS << Indent << "}\n"; + } +} void GICombinerEmitter::run(raw_ostream &OS) { + gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules")); NamedRegionTimer T("Emit", "Time spent emitting the combiner", "Code Generation", "Time spent generating code", TimeRegions); @@ -74,10 +327,14 @@ << " GISelChangeObserver &Observer,\n" << " MachineInstr &MI,\n" << " MachineIRBuilder &B) const {\n" + << " CombinerHelper Helper(Observer, B);\n" << " MachineBasicBlock *MBB = MI.getParent();\n" << " MachineFunction *MF = MBB->getParent();\n" << " MachineRegisterInfo &MRI = MF->getRegInfo();\n" << " (void)MBB; (void)MF; (void)MRI;\n\n"; + + for (const auto &Rule : Rules) + generateCodeForRule(OS, Rule.get(), " "); OS << "\n return false;\n" << "}\n" << "#endif // ifdef " << Name.upper() << "_GENCOMBINERHELPER_CPP\n"; @@ -98,8 +355,9 @@ Record *CombinerDef = RK.getDef(Combiner); if (!CombinerDef) PrintFatalError("Could not find " + Combiner); - GICombinerEmitter(RK, Combiner, CombinerDef).run(OS); + GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS); } + NumPatternTotalStatistic = NumPatternTotal; } } // namespace llvm