Index: utils/TableGen/GlobalISelEmitter.cpp =================================================================== --- utils/TableGen/GlobalISelEmitter.cpp +++ utils/TableGen/GlobalISelEmitter.cpp @@ -110,16 +110,6 @@ //===- Matchers -----------------------------------------------------------===// -struct Matcher { - virtual ~Matcher() {} - virtual void emit(raw_ostream &OS) const = 0; -}; - -raw_ostream &operator<<(raw_ostream &S, const Matcher &M) { - M.emit(S); - return S; -} - struct MatchAction { virtual ~MatchAction() {} virtual void emit(raw_ostream &OS) const = 0; @@ -130,46 +120,211 @@ return S; } -struct MatchOpcode : public Matcher { - MatchOpcode(const CodeGenInstruction *I) : I(I) {} - const CodeGenInstruction *I; +template +class WithPredicates { +private: + typedef std::vector> PredicateVec; + PredicateVec Predicates; - virtual void emit(raw_ostream &OS) const { - OS << "I.getOpcode() == " << I->Namespace << "::" << I->TheDef->getName(); +public: + /// Construct a new predicate of type Kind, add it to the matcher, and return a reference to it. + template + Kind &addPredicate(Args&&... args) { + Predicates.emplace_back(make_unique(std::forward(args)...)); + return *static_cast(Predicates.back().get()); + } + + typename PredicateVec::const_iterator predicates_begin() const { return Predicates.begin(); } + typename PredicateVec::const_iterator predicates_end() const { return Predicates.end(); } + iterator_range predicates() const { + return make_range(predicates_begin(), predicates_end()); } + +#if 0 + /// Disabled variadic form of emitCxxPredicatesExpr; the fixed-arity overloads under #else are used instead. + template + void emitCxxPredicatesExpr(raw_ostream &OS, Args&&... 
args) const { + if (Predicates.empty()) { + OS << "true"; + return; + } + + StringRef Separator = ""; + for (const auto &Predicate : predicates()) { + OS << Separator << "("; + Predicate->emitCxxPredicateExpr(OS, std::forward(args)...); + OS << ")"; + Separator = " && "; + } + } +#else + /// Emit a C++ expression that tests whether all the predicates are met ("true" if there are none). + template + void emitCxxPredicatesExpr(raw_ostream &OS, Arg1&& arg1) const { + if (Predicates.empty()) { + OS << "true"; + return; + } + + StringRef Separator = ""; + for (const auto &Predicate : predicates()) { + OS << Separator << "("; + Predicate->emitCxxPredicateExpr(OS, std::forward(arg1)); + OS << ")"; + Separator = " && "; + } + } + + template + void emitCxxPredicatesExpr(raw_ostream &OS, Arg1&& arg1, Arg2&& arg2) const { + if (Predicates.empty()) { + OS << "true"; + return; + } + + StringRef Separator = ""; + for (const auto &Predicate : predicates()) { + OS << Separator << "("; + Predicate->emitCxxPredicateExpr(OS, std::forward(arg1), + std::forward(arg2)); + OS << ")"; + Separator = " && "; + } + } +#endif }; -struct MatchRegOpType : public Matcher { - MatchRegOpType(unsigned OpIdx, std::string Ty) - : OpIdx(OpIdx), Ty(Ty) {} - unsigned OpIdx; +/// Generates code to check a predicate of an operand. +/// +/// Typical predicates include: +/// * Operand is a particular register. +/// * Operand is assigned a particular register bank. +/// * Operand is an MBB. +class OperandPredicateMatcher { +public: + virtual ~OperandPredicateMatcher() {} + + /// Emit a C++ expression that checks the predicate for the OpIdx operand of + /// the instruction given in InsnVarName. + virtual void emitCxxPredicateExpr(raw_ostream &OS, + const StringRef InsnVarName, + unsigned OpIdx) const = 0; +}; + +/// Generates code to check that a register operand's type (as reported by MRI.getType) is a particular LLT. 
+class LLTOperandMatcher : public OperandPredicateMatcher { +protected: std::string Ty; - virtual void emit(raw_ostream &OS) const { - OS << "MRI.getType(I.getOperand(" << OpIdx << ").getReg()) == (" << Ty - << ")"; +public: + LLTOperandMatcher(std::string Ty) : Ty(Ty) {} + + void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName, + unsigned OpIdx) const override { + OS << "MRI.getType(" << InsnVarName << ".getOperand(" << OpIdx + << ").getReg()) == (" << Ty << ")"; } }; -struct MatchRegOpBank : public Matcher { - MatchRegOpBank(unsigned OpIdx, const CodeGenRegisterClass &RC) - : OpIdx(OpIdx), RC(RC) {} - unsigned OpIdx; +/// Generates code to check that an operand is in the register bank that RBI derives from a particular register class. +class RegisterBankOperandMatcher : public OperandPredicateMatcher { +protected: const CodeGenRegisterClass &RC; - virtual void emit(raw_ostream &OS) const { +public: + RegisterBankOperandMatcher(const CodeGenRegisterClass &RC) : RC(RC) {} + + void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName, + unsigned OpIdx) const override { OS << "(&RBI.getRegBankFromRegClass(" << RC.getQualifiedName() - << "RegClass) == RBI.getRegBank(I.getOperand(" << OpIdx - << ").getReg(), MRI, TRI))"; + << "RegClass) == RBI.getRegBank(" << InsnVarName << ".getOperand(" + << OpIdx << ").getReg(), MRI, TRI))"; } }; -struct MatchMBBOp : public Matcher { - MatchMBBOp(unsigned OpIdx) : OpIdx(OpIdx) {} +/// Generates code to check that an operand is a basic block. +class MBBOperandMatcher : public OperandPredicateMatcher { +public: + void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName, + unsigned OpIdx) const override { + OS << InsnVarName << ".getOperand(" << OpIdx << ").isMBB()"; + } +}; + +/// Generates code to check that all of a set of predicates match a particular +/// operand. 
+class OperandMatcher : public WithPredicates { +protected: unsigned OpIdx; - virtual void emit(raw_ostream &OS) const { - OS << "I.getOperand(" << OpIdx << ").isMBB()"; +public: + OperandMatcher(unsigned OpIdx) : OpIdx(OpIdx) {} + + /// Emit a parenthesized C++ expression that tests whether operand OpIdx of the + /// instruction named in InsnVarName matches all of this operand's predicates. + void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName) const { + OS << "("; + emitCxxPredicatesExpr(OS, InsnVarName, OpIdx); + OS << ")"; + } +}; + +/// Generates code to check a predicate on an instruction. +/// +/// Typical predicates include: +/// * The opcode of the instruction is a particular value. +/// * The nsw/nuw flag is/isn't set. +class InstructionPredicateMatcher { +public: + virtual ~InstructionPredicateMatcher() {} + + /// Emit a C++ expression that tests whether the instruction named in + /// InsnVarName matches the predicate. + virtual void emitCxxPredicateExpr(raw_ostream &OS, + const StringRef InsnVarName) const = 0; +}; + +/// Generates code to check the opcode of an instruction. +class InstructionOpcodeMatcher : public InstructionPredicateMatcher { +protected: + const CodeGenInstruction *I; + +public: + InstructionOpcodeMatcher(const CodeGenInstruction *I) : I(I) {} + + void emitCxxPredicateExpr(raw_ostream &OS, + const StringRef InsnVarName) const override { + OS << InsnVarName << ".getOpcode() == " << I->Namespace + << "::" << I->TheDef->getName(); + } +}; + +/// Generates code to check that a set of predicates and operands match for a +/// particular instruction. +/// +/// Typical predicates include: +/// * Has a specific opcode. +/// * Has an nsw/nuw flag or doesn't. +class InstructionMatcher : public WithPredicates { +protected: + std::vector Operands; + +public: + /// Add a matcher for operand OpIdx. Operands is stored by value, so the returned reference may be invalidated by a later addOperand call. 
+ OperandMatcher &addOperand(unsigned OpIdx) { + Operands.emplace_back(OpIdx); + return Operands.back(); + } + + /// Emit a C++ expression that tests whether the instruction named in + /// InsnVarName matches all the predicates and all the operands. + void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName) const { + emitCxxPredicatesExpr(OS, InsnVarName); + for (const auto &Operand : Operands) { + OS << " && ("; + Operand.emitCxxPredicateExpr(OS, InsnVarName); + OS << ")"; + } } }; @@ -183,14 +338,25 @@ } }; -class MatcherEmitter { +/// Generates code to check that a match rule matches. +/// +/// This currently supports a single match position but could be extended to +/// support multiple positions to support div/rem fusion or load-multiple +/// instructions. +class RuleMatcher { const PatternToMatch &P; + std::vector> Matchers; + public: - std::vector> Matchers; std::vector> Actions; - MatcherEmitter(const PatternToMatch &P) : P(P) {} + RuleMatcher(const PatternToMatch &P) : P(P) {} + + InstructionMatcher &addInstructionMatcher() { + Matchers.emplace_back(new InstructionMatcher()); + return *Matchers.back(); + } void emit(raw_ostream &OS) { if (Matchers.empty()) @@ -199,9 +365,18 @@ OS << " // Src: " << *P.getSrcPattern() << "\n" << " // Dst: " << *P.getDstPattern() << "\n"; - OS << " if ((" << *Matchers.front() << ")"; - for (auto &MA : makeArrayRef(Matchers).drop_front()) - OS << " &&\n (" << *MA << ")"; + // The representation supports rules that require multiple roots such as: + // %ptr(p0) = ... + // %elt0(s32) = G_LOAD %ptr + // %1(p0) = G_ADD %ptr, 4 + // %elt1(s32) = G_LOAD p0 %1 + // which could be usefully folded into: + // %ptr(p0) = ... + // %elt0(s32), %elt1(s32) = TGT_LOAD_PAIR %ptr + // on some targets but we don't need to make use of that yet. 
+ assert(Matchers.size() == 1 && "Cannot handle multi-root matchers yet"); + OS << " if ("; + Matchers.front()->emitCxxPredicateExpr(OS, "I"); OS << ") {\n"; for (auto &MA : Actions) @@ -235,7 +410,7 @@ GlobalISelEmitter::runOnPattern(const PatternToMatch &P, raw_ostream &OS) { // Keep track of the matchers and actions to emit. - MatcherEmitter M(P); + RuleMatcher M(P); // First, analyze the whole pattern. // If the entire pattern has a predicate (e.g., target features), ignore it. @@ -268,7 +443,8 @@ auto &SrcGI = *SrcGIOrNull; // The operators look good: match the opcode and mutate it to the new one. - M.Matchers.emplace_back(new MatchOpcode(&SrcGI)); + InstructionMatcher &InsnMatcher = M.addInstructionMatcher(); + InsnMatcher.addPredicate(&SrcGI); M.Actions.emplace_back(new MutateOpcode(&DstI)); // Next, analyze the children, only accepting patterns that don't require @@ -291,9 +467,10 @@ if (!OpTyOrNone) return SkipReason{"Dst operand has an unsupported type"}; - M.Matchers.emplace_back(new MatchRegOpType(OpIdx, *OpTyOrNone)); - M.Matchers.emplace_back( - new MatchRegOpBank(OpIdx, Target.getRegisterClass(DstIOpRec))); + OperandMatcher &OM = InsnMatcher.addOperand(OpIdx); + OM.addPredicate(*OpTyOrNone); + OM.addPredicate( + Target.getRegisterClass(DstIOpRec)); ++OpIdx; } @@ -316,7 +493,7 @@ if (SrcChild->getOperator()->isSubClassOf("SDNode")) { auto &ChildSDNI = CGP.getSDNodeInfo(SrcChild->getOperator()); if (ChildSDNI.getSDClassName() == "BasicBlockSDNode") { - M.Matchers.emplace_back(new MatchMBBOp(OpIdx++)); + InsnMatcher.addOperand(OpIdx++).addPredicate(); continue; } } @@ -341,9 +518,10 @@ if (!OpTyOrNone) return SkipReason{"Src operand has an unsupported type"}; - M.Matchers.emplace_back(new MatchRegOpType(OpIdx, *OpTyOrNone)); - M.Matchers.emplace_back( - new MatchRegOpBank(OpIdx, Target.getRegisterClass(ChildRec))); + OperandMatcher &OM = InsnMatcher.addOperand(OpIdx); + OM.addPredicate(*OpTyOrNone); + OM.addPredicate( 
+ Target.getRegisterClass(ChildRec)); ++OpIdx; }