diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td --- a/llvm/include/llvm/Target/GlobalISel/Combine.td +++ b/llvm/include/llvm/Target/GlobalISel/Combine.td @@ -66,11 +66,20 @@ /// is incorrect. def root : GIDefKind; +/// Declares data that is passed from the match stage to the apply stage. +class GIDefMatchData<string type> : GIDefKind { + /// A C++ type name (the template argument) indicating the storage type. + string Type = type; +} + +def extending_load_matchdata : GIDefMatchData<"PreferredTuple">; + /// The operator at the root of a GICombineRule.Match dag. def match; /// All arguments of the match operator must be either: /// * A subclass of GIMatchKind /// * A subclass of GIMatchKindWithArgs +/// * A subclass of Instruction /// * A MIR code block (deprecated) /// The GIMatchKind and GIMatchKindWithArgs cases are described in more detail /// in their definitions below. @@ -93,6 +102,11 @@ (apply [{ Helper.applyCombineCopy(${d}); }])>; def trivial_combines : GICombineGroup<[copy_prop]>; +def extending_loads : GICombineRule< + (defs root:$root, extending_load_matchdata:$matchinfo), + (match [{ return Helper.matchCombineExtendingLoads(${root}, ${matchinfo}); }]), + (apply [{ Helper.applyCombineExtendingLoads(${root}, ${matchinfo}); }])>; + // FIXME: Is there a reason this wasn't in tryCombine? I've left it out of // all_combines because it wasn't there. 
def elide_br_by_inverting_cond : GICombineRule< @@ -100,4 +114,6 @@ (match [{ return Helper.matchElideBrByInvertingCond(${d}); }]), (apply [{ Helper.applyElideBrByInvertingCond(${d}); }])>; -def all_combines : GICombineGroup<[trivial_combines]>; +def combines_for_extload : GICombineGroup<[extending_loads]>; + +def all_combines : GICombineGroup<[trivial_combines, combines_for_extload]>; diff --git a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp --- a/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp +++ b/llvm/lib/Target/AArch64/AArch64PreLegalizerCombiner.cpp @@ -62,20 +62,6 @@ CombinerHelper Helper(Observer, B, KB, MDT); switch (MI.getOpcode()) { - case TargetOpcode::G_CONCAT_VECTORS: - return Helper.tryCombineConcatVectors(MI); - case TargetOpcode::G_SHUFFLE_VECTOR: - return Helper.tryCombineShuffleVector(MI); - case TargetOpcode::G_LOAD: - case TargetOpcode::G_SEXTLOAD: - case TargetOpcode::G_ZEXTLOAD: { - bool Changed = false; - Changed |= Helper.tryCombineExtendingLoads(MI); - Changed |= Helper.tryCombineIndexedLoadStore(MI); - return Changed; - } - case TargetOpcode::G_STORE: - return Helper.tryCombineIndexedLoadStore(MI); case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS: switch (MI.getIntrinsicID()) { case Intrinsic::memcpy: @@ -96,6 +82,19 @@ if (Generated.tryCombineAll(Observer, MI, B)) return true; + // Fall back to hand-written combines that are not (yet) GICombineRules. + switch (MI.getOpcode()) { + case TargetOpcode::G_CONCAT_VECTORS: + return Helper.tryCombineConcatVectors(MI); + case TargetOpcode::G_SHUFFLE_VECTOR: + return Helper.tryCombineShuffleVector(MI); + case TargetOpcode::G_LOAD: + case TargetOpcode::G_SEXTLOAD: + case TargetOpcode::G_ZEXTLOAD: + case TargetOpcode::G_STORE: + return Helper.tryCombineIndexedLoadStore(MI); + } + return false; } diff --git a/llvm/utils/TableGen/GICombinerEmitter.cpp b/llvm/utils/TableGen/GICombinerEmitter.cpp --- a/llvm/utils/TableGen/GICombinerEmitter.cpp +++ b/llvm/utils/TableGen/GICombinerEmitter.cpp @@ -61,6 
+61,24 @@ return StrTab.insert(S).first->first(); } +/// Declares data that is passed from the match stage to the apply stage. +class MatchDataInfo { + /// The symbol used in the tablegen patterns + StringRef PatternSymbol; + /// The data type for the variable + StringRef Type; + /// The name of the variable as declared in the generated matcher. + std::string VariableName; + +public: + MatchDataInfo(StringRef PatternSymbol, StringRef Type, StringRef VariableName) + : PatternSymbol(PatternSymbol), Type(Type), VariableName(VariableName) {} + + StringRef getPatternSymbol() const { return PatternSymbol; } + StringRef getType() const { return Type; } + StringRef getVariableName() const { return VariableName; } +}; + class RootInfo { StringRef PatternSymbol; @@ -71,6 +89,10 @@ }; class CombineRule { +public: + + using const_matchdata_iterator = std::vector<MatchDataInfo>::const_iterator; + struct VarInfo { const GIMatchDagInstr *N; const GIMatchDagOperand *Op; @@ -108,6 +130,33 @@ /// FIXME: This is a temporary measure until we have actual pattern matching const CodeInit *MatchingFixupCode = nullptr; + /// The MatchData defined by the match stage and required by the apply stage. + /// This allows the plumbing of arbitrary data from C++ predicates between the + /// stages. + /// + /// For example, suppose you have: + /// %A = <bunch of instructions> + /// %0 = G_ADD %1, %A + /// you could define a GIMatchPredicate that walks %A, constant folds as much + /// as possible and returns an APInt containing the discovered constant. 
You + /// could then declare: + /// def apint : GIDefMatchData<"APInt">; + /// add it to the rule with: + /// (defs root:$root, apint:$constant) + /// evaluate it in the pattern with a C++ function that takes a + /// MachineOperand& and an APInt& with: + /// (match [{MIR %root = G_ADD %0, %A }], + /// (constantfold operand:$A, apint:$constant)) + /// and finally use it in the apply stage with: + /// (apply (create_operand + /// [{ MachineOperand::CreateImm(${constant}.getZExtValue()); + /// }], apint:$constant), + /// [{MIR %root = FOO %0, %constant }]) + std::vector<MatchDataInfo> MatchDataDecls; + + void declareMatchData(StringRef PatternSymbol, StringRef Type, + StringRef VarName); + bool parseInstructionMatcher(const CodeGenTarget &Target, StringInit *ArgName, const Init &Arg, StringMap<std::vector<VarInfo>> &NamedEdgeDefs, @@ -139,6 +188,16 @@ return llvm::make_range(Roots.begin(), Roots.end()); } + iterator_range<const_matchdata_iterator> matchdata_decls() const { + return make_range(MatchDataDecls.begin(), MatchDataDecls.end()); + } + + /// Declare each match data pattern symbol as expanding to its C++ variable. + void declareExpansions(CodeExpansions &Expansions) const { + for (const auto &I : matchdata_decls()) + Expansions.declare(I.getPatternSymbol(), I.getVariableName()); + } + /// The matcher will begin from the roots and will perform the match by /// traversing the edges to cover the whole DAG. This function reverses DAG /// edges such that everything is reachable from a root. This is part of the 
This is part of the @@ -243,6 +302,11 @@ to_string(format("__anonpred%d_%d", Rule.getID(), Rule.allocUID()))); } +void CombineRule::declareMatchData(StringRef PatternSymbol, StringRef Type, + StringRef VarName) { + MatchDataDecls.emplace_back(PatternSymbol, Type, VarName); +} + bool CombineRule::parseDefs() { NamedRegionTimer T("parseDefs", "Time spent parsing the defs", "Rule Parsing", "Time spent on rule parsing", TimeRegions); @@ -260,6 +324,17 @@ continue; } + // Subclasses of GIDefMatchData should declare that this rule needs to pass + // data from the match stage to the apply stage, and ensure that the + // generated matcher has a suitable variable for it to do so. + if (Record *MatchDataRec = + getDefOfSubClass(*Defs->getArg(I), "GIDefMatchData")) { + declareMatchData(Defs->getArgNameStr(I), + MatchDataRec->getValueAsString("Type"), + llvm::to_string(llvm::format("MatchData%d", ID))); + continue; + } + // Otherwise emit an appropriate error message. if (getDefOfSubClass(*Defs->getArg(I), "GIDefKind")) PrintError(TheDef.getLoc(), @@ -556,6 +631,8 @@ for (const RootInfo &Root : Rule->roots()) { Expansions.declare(Root.getPatternSymbol(), "MI"); } + Rule->declareExpansions(Expansions); + DagInit *Applyer = RuleDef.getValueAsDag("Apply"); if (Applyer->getOperatorAsDef(RuleDef.getLoc())->getName() != "apply") { @@ -695,6 +772,12 @@ << " MachineRegisterInfo &MRI = MF->getRegInfo();\n" << " (void)MBB; (void)MF; (void)MRI;\n\n"; + OS << " // Match data\n"; + for (const auto &Rule : Rules) + for (const auto &I : Rule->matchdata_decls()) + OS << " " << I.getType() << " " << I.getVariableName() << ";\n"; + OS << "\n"; + for (const auto &Rule : Rules) generateCodeForRule(OS, Rule.get(), " "); OS << "\n return false;\n"