Index: include/llvm/CodeGen/GlobalISel/InstructionSelector.h =================================================================== --- include/llvm/CodeGen/GlobalISel/InstructionSelector.h +++ include/llvm/CodeGen/GlobalISel/InstructionSelector.h @@ -19,9 +19,11 @@ #include "llvm/ADT/Optional.h" #include #include +#include namespace llvm { class MachineInstr; +class MachineInstrBuilder; class MachineFunction; class MachineOperand; class MachineRegisterInfo; @@ -74,6 +76,9 @@ virtual bool select(MachineInstr &I) const = 0; protected: + typedef std::function ComplexRendererFn; + typedef Optional OptionalComplexRendererFn; + InstructionSelector(); /// Mutate the newly-selected instruction \p I to constrain its (possibly Index: include/llvm/Target/GlobalISel/Target.td =================================================================== --- include/llvm/Target/GlobalISel/Target.td +++ include/llvm/Target/GlobalISel/Target.td @@ -31,21 +31,13 @@ // Definitions that inherit from this may also inherit from // GIComplexPatternEquiv to enable the import of SelectionDAG patterns involving // those ComplexPatterns. -class GIComplexOperandMatcher { +class GIComplexOperandMatcher { // The expected type of the root of the match. // // TODO: We should probably support, any-type, any-scalar, and multiple types // in the future. LLT Type = type; - // The operands that result from a successful match - // Should be of the form '(ops ty1, ty2, ...)' where ty1/ty2 are definitions - // that inherit from Operand. - // - // FIXME: Which definition is used for ty1/ty2 doesn't actually matter at the - // moment. Only the number of operands is used. - dag Operands = operands; - // The function that determines whether the operand matches. It should be of // the form: // bool select(const MatchOperand &Root, MatchOperand &Result1) Index: lib/Target/AArch64/AArch64InstrFormats.td =================================================================== --- lib/Target/AArch64/AArch64InstrFormats.td +++ lib/Target/AArch64/AArch64InstrFormats.td @@ -693,11 +693,11 @@ def addsub_shifted_imm64_neg : addsub_shifted_imm_neg; def gi_addsub_shifted_imm32 : - GIComplexOperandMatcher, + GIComplexOperandMatcher, GIComplexPatternEquiv; def gi_addsub_shifted_imm64 : - GIComplexOperandMatcher, + GIComplexOperandMatcher, GIComplexPatternEquiv; class neg_addsub_shifted_imm Index: lib/Target/AArch64/AArch64InstructionSelector.cpp =================================================================== --- lib/Target/AArch64/AArch64InstructionSelector.cpp +++ lib/Target/AArch64/AArch64InstructionSelector.cpp @@ -65,8 +65,7 @@ bool selectCompareBranch(MachineInstr &I, MachineFunction &MF, MachineRegisterInfo &MRI) const; - bool selectArithImmed(MachineOperand &Root, MachineOperand &Result1, - MachineOperand &Result2) const; + OptionalComplexRendererFn selectArithImmed(MachineOperand &Root) const; const AArch64TargetMachine &TM; const AArch64Subtarget &STI; @@ -1328,9 +1327,8 @@ /// SelectArithImmed - Select an immediate value that can be represented as /// a 12-bit value shifted left by either 0 or 12. If so, return true with /// Val set to the 12-bit value and Shift set to the shifter operand. -bool AArch64InstructionSelector::selectArithImmed( - MachineOperand &Root, MachineOperand &Result1, - MachineOperand &Result2) const { +InstructionSelector::OptionalComplexRendererFn +AArch64InstructionSelector::selectArithImmed(MachineOperand &Root) const { MachineInstr &MI = *Root.getParent(); MachineBasicBlock &MBB = *MI.getParent(); MachineFunction &MF = *MBB.getParent(); @@ -1349,13 +1347,13 @@ else if (Root.isReg()) { MachineInstr *Def = MRI.getVRegDef(Root.getReg()); if (Def->getOpcode() != TargetOpcode::G_CONSTANT) - return false; + return None; MachineOperand &Op1 = Def->getOperand(1); if (!Op1.isCImm() || Op1.getCImm()->getBitWidth() > 64) - return false; + return None; Immed = Op1.getCImm()->getZExtValue(); } else - return false; + return None; unsigned ShiftAmt; @@ -1365,14 +1363,11 @@ ShiftAmt = 12; Immed = Immed >> 12; } else - return false; + return None; unsigned ShVal = AArch64_AM::getShifterImm(AArch64_AM::LSL, ShiftAmt); - Result1.ChangeToImmediate(Immed); - Result1.clearParent(); - Result2.ChangeToImmediate(ShVal); - Result2.clearParent(); - return true; + return OptionalComplexRendererFn( + [=](MachineInstrBuilder &MIB) { MIB.addImm(Immed).addImm(ShVal); }); } namespace llvm { Index: test/TableGen/GlobalISelEmitter-GIRule.td =================================================================== --- test/TableGen/GlobalISelEmitter-GIRule.td +++ test/TableGen/GlobalISelEmitter-GIRule.td @@ -19,8 +19,7 @@ } def complex : Operand; -def gi_complex : - GIComplexOperandMatcher; +def gi_complex : GIComplexOperandMatcher; def m1 : OperandWithDefaultOps ; def Z : OperandWithDefaultOps ; @@ -68,17 +67,15 @@ // CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && // CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1)))) && +// CHECK-NEXT: ((Renderer0 = selectComplexPattern(MI0.getOperand(2)), Renderer0.hasValue())))) && // CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(3).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(3), TempOp2, TempOp3))))) { +// CHECK-NEXT: ((Renderer1 = selectComplexPattern(MI0.getOperand(3)), Renderer1.hasValue()))))) { // CHECK-NEXT: // Rule1 // CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN2)); // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); -// CHECK-NEXT: MIB.add(TempOp2); -// CHECK-NEXT: MIB.add(TempOp3); -// CHECK-NEXT: MIB.add(TempOp0); -// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: Renderer1.getValue()(MIB); +// CHECK-NEXT: Renderer0.getValue()(MIB); // CHECK-NEXT: for (const auto *FromMI : {&MI0, }) // CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) // CHECK-NEXT: MIB.addMemOperand(MMO); @@ -314,13 +311,12 @@ // CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && // CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1))))) { +// CHECK-NEXT: ((Renderer0 = selectComplexPattern(MI0.getOperand(2)), Renderer0.hasValue()))))) { // CHECK-NEXT: // Rule5 // CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN1)); // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); -// CHECK-NEXT: MIB.add(TempOp0); -// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: Renderer0.getValue()(MIB); // CHECK-NEXT: for (const auto *FromMI : {&MI0, }) // CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) // CHECK-NEXT: MIB.addMemOperand(MMO); @@ -344,6 +340,7 @@ ]>; //===- Test a simple pattern with a default operand. ----------------------===// +// // CHECK-LABEL: if ([&]() { // CHECK-NEXT: MachineInstr &MI0 = I; @@ -471,7 +468,6 @@ // CHECK-NEXT: return false; // CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_BR) && // CHECK-NEXT: ((/* target */ (MI0.getOperand(0).isMBB())))) { - // CHECK-NEXT: // (br (bb:Other):$target) => (BR (bb:Other):$target) // CHECK-NEXT: I.setDesc(TII.get(MyTarget::BR)); // CHECK-NEXT: MachineInstr &NewI = I; Index: test/TableGen/GlobalISelEmitter.td =================================================================== --- test/TableGen/GlobalISelEmitter.td +++ test/TableGen/GlobalISelEmitter.td @@ -22,7 +22,7 @@ let MIOperandInfo = (ops i32imm, i32imm); } def gi_complex : - GIComplexOperandMatcher, + GIComplexOperandMatcher, GIComplexPatternEquiv; def m1 : OperandWithDefaultOps ; @@ -72,17 +72,15 @@ // CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && // CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1)))) && +// CHECK-NEXT: ((Renderer0 = selectComplexPattern(MI0.getOperand(2)), Renderer0.hasValue())))) && // CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(3).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(3), TempOp2, TempOp3))))) { +// CHECK-NEXT: ((Renderer1 = selectComplexPattern(MI0.getOperand(3)), Renderer1.hasValue()))))) { // CHECK-NEXT: // (select:i32 GPR32:i32:$src1, complex:i32:$src2, complex:i32:$src3) => (INSN2:i32 GPR32:i32:$src1, complex:i32:$src3, complex:i32:$src2) // CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN2)); // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); -// CHECK-NEXT: MIB.add(TempOp2); -// CHECK-NEXT: MIB.add(TempOp3); -// CHECK-NEXT: MIB.add(TempOp0); -// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: Renderer1.getValue()(MIB); +// CHECK-NEXT: Renderer0.getValue()(MIB); // CHECK-NEXT: for (const auto *FromMI : {&MI0, }) // CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) // CHECK-NEXT: MIB.addMemOperand(MMO); @@ -263,13 +261,12 @@ // CHECK-NEXT: ((/* src1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && // CHECK-NEXT: ((/* src2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && -// CHECK-NEXT: (selectComplexPattern(MI0.getOperand(2), TempOp0, TempOp1))))) { +// CHECK-NEXT: ((Renderer0 = selectComplexPattern(MI0.getOperand(2)), Renderer0.hasValue()))))) { // CHECK-NEXT: // (sub:i32 GPR32:i32:$src1, complex:i32:$src2) => (INSN1:i32 GPR32:i32:$src1, complex:i32:$src2) // CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::INSN1)); // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); -// CHECK-NEXT: MIB.add(TempOp0); -// CHECK-NEXT: MIB.add(TempOp1); +// CHECK-NEXT: Renderer0.getValue()(MIB); // CHECK-NEXT: for (const auto *FromMI : {&MI0, }) // CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) // CHECK-NEXT: MIB.addMemOperand(MMO); Index: utils/TableGen/GlobalISelEmitter.cpp =================================================================== --- utils/TableGen/GlobalISelEmitter.cpp +++ utils/TableGen/GlobalISelEmitter.cpp @@ -110,35 +110,6 @@ }; class InstructionMatcher; -class OperandPlaceholder { -private: - enum PlaceholderKind { - OP_Temporary, - } Kind; - - struct TemporaryData { - unsigned OpIdx; - }; - - union { - struct TemporaryData Temporary; - }; - - OperandPlaceholder(PlaceholderKind Kind) : Kind(Kind) {} - -public: - ~OperandPlaceholder() {} - - static OperandPlaceholder CreateTemporary(unsigned OpIdx) { - OperandPlaceholder Result(OP_Temporary); - Result.Temporary.OpIdx = OpIdx; - return Result; - } - - void emitCxxValueExpr(raw_ostream &OS) const; - void emitTblgen(raw_ostream &OS, IndentStr Indent) const; -}; - /// Convert an MVT to an equivalent LLT if possible, or the invalid LLT() for /// MVTs that don't map cleanly to an LLT (e.g., iPTR, *any, ...). static Optional MVTToLLT(MVT::SimpleValueType SVT) { @@ -259,7 +230,7 @@ /// Report the maximum number of temporary operands needed by the rule /// matcher. - unsigned countTemporaryOperands() const; + unsigned countRendererFns() const; // FIXME: Remove this as soon as possible InstructionMatcher &insnmatcher_front() const { return *Matchers.front(); } @@ -370,7 +341,7 @@ /// Report the maximum number of temporary operands needed by the predicate /// matcher. - virtual unsigned countTemporaryOperands() const { return 0; } + virtual unsigned countRendererFns() const { return 0; } }; /// Generates code to check that an operand is a particular LLT. @@ -404,10 +375,6 @@ const OperandMatcher &Operand; const Record &TheDef; - unsigned getNumOperands() const { - return TheDef.getValueAsDag("Operands")->getNumArgs(); - } - unsigned getAllocatedTemporariesBaseID() const; public: @@ -422,21 +389,17 @@ void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, StringRef OperandExpr) const override { - OS << TheDef.getValueAsString("MatcherFn") << "(" << OperandExpr; - for (unsigned I = 0; I < getNumOperands(); ++I) { - OS << ", "; - OperandPlaceholder::CreateTemporary(getAllocatedTemporariesBaseID() + I) - .emitCxxValueExpr(OS); - } - OS << ")"; + unsigned ID = getAllocatedTemporariesBaseID(); + OS << "(Renderer" << ID << " = " << TheDef.getValueAsString("MatcherFn") + << "(" << OperandExpr << "), Renderer" << ID << ".hasValue())"; } void emitTblgen(raw_ostream &OS, IndentStr Indent) const override { OS << Indent << "GIMatchComplexPattern<" << TheDef.getName() << ">"; } - unsigned countTemporaryOperands() const override { - return getNumOperands(); + unsigned countRendererFns() const override { + return 1; } }; @@ -540,7 +503,7 @@ /// The index of the first temporary variable allocated to this operand. The /// number of allocated temporaries can be found with - /// countTemporaryOperands(). + /// countRendererFns(). unsigned AllocatedTemporariesBaseID; public: @@ -633,12 +596,12 @@ /// Report the maximum number of temporary operands needed by the operand /// matcher. - unsigned countTemporaryOperands() const { + unsigned countRendererFns() const { return std::accumulate( predicates().begin(), predicates().end(), 0, [](unsigned A, const std::unique_ptr &Predicate) { - return A + Predicate->countTemporaryOperands(); + return A + Predicate->countRendererFns(); }); } @@ -689,7 +652,7 @@ /// Report the maximum number of temporary operands needed by the predicate /// matcher. - virtual unsigned countTemporaryOperands() const { return 0; } + virtual unsigned countRendererFns() const { return 0; } }; /// Generates code to check the opcode of an instruction. @@ -869,17 +832,17 @@ /// Report the maximum number of temporary operands needed by the instruction /// matcher. - unsigned countTemporaryOperands() const { + unsigned countRendererFns() const { return std::accumulate(predicates().begin(), predicates().end(), 0, [](unsigned A, const std::unique_ptr &Predicate) { - return A + Predicate->countTemporaryOperands(); + return A + Predicate->countRendererFns(); }) + std::accumulate( Operands.begin(), Operands.end(), 0, [](unsigned A, const std::unique_ptr &Operand) { - return A + Operand->countTemporaryOperands(); + return A + Operand->countRendererFns(); }); } }; @@ -940,22 +903,6 @@ }; //===- Actions ------------------------------------------------------------===// -void OperandPlaceholder::emitCxxValueExpr(raw_ostream &OS) const { - switch (Kind) { - case OP_Temporary: - OS << "TempOp" << Temporary.OpIdx; - break; - } -} - -void OperandPlaceholder::emitTblgen(raw_ostream &OS, IndentStr Indent) const { - switch (Kind) { - case OP_Temporary: - OS << "TempOp" << Temporary.OpIdx; - break; - } -} - class OperandRenderer { public: enum RendererKind { OR_Copy, OR_Imm, OR_Register, OR_ComplexPattern }; @@ -1059,7 +1006,7 @@ const Record &TheDef; /// The name of the operand. const StringRef SymbolicName; - std::vector Sources; + unsigned RendererID; unsigned getNumOperands() const { return TheDef.getValueAsDag("Operands")->getNumArgs(); @@ -1067,33 +1014,21 @@ public: RenderComplexPatternOperand(const Record &TheDef, StringRef SymbolicName, - const ArrayRef Sources) + unsigned RendererID) : OperandRenderer(OR_ComplexPattern), TheDef(TheDef), - SymbolicName(SymbolicName), Sources(Sources) {} + SymbolicName(SymbolicName), RendererID(RendererID) {} static bool classof(const OperandRenderer *R) { return R->getKind() == OR_ComplexPattern; } void emitCxxRenderStmts(raw_ostream &OS, RuleMatcher &Rule) const override { - assert(Sources.size() == getNumOperands() && "Inconsistent number of operands"); - for (const auto &Source : Sources) { - OS << "MIB.add("; - Source.emitCxxValueExpr(OS); - OS << ");\n"; - } + OS << "Renderer" << RendererID << ".getValue()(MIB);\n"; } void emitTblgen(raw_ostream &OS, IndentStr Indent) const override { OS << Indent << "GIAddComplexOperand<" << TheDef.getName() << ", \"" - << SymbolicName << "\" /* "; - StringRef Separator = ""; - for (const OperandPlaceholder &Placeholder : Sources) { - OS << Separator; - Placeholder.emitTblgen(OS, Indent); - Separator = ", "; - } - OS << " */>"; + << SymbolicName << "\" /* Renderer" << RendererID << " */>"; } }; @@ -1409,11 +1344,11 @@ return false; } -unsigned RuleMatcher::countTemporaryOperands() const { +unsigned RuleMatcher::countRendererFns() const { return std::accumulate( Matchers.begin(), Matchers.end(), 0, [](unsigned A, const std::unique_ptr &Matcher) { - return A + Matcher->countTemporaryOperands(); + return A + Matcher->countRendererFns(); }); } @@ -1625,9 +1560,9 @@ return failedImport("SelectionDAG ComplexPattern (" + ChildRec->getName() + ") not mapped to GlobalISel"); - const auto &Predicate = OM.addPredicate( - OM, *ComplexPattern->second); - TempOpIdx += Predicate.countTemporaryOperands(); + OM.addPredicate(OM, + *ComplexPattern->second); + TempOpIdx++; return Error::success(); } @@ -1698,13 +1633,10 @@ return failedImport( "SelectionDAG ComplexPattern not mapped to GlobalISel"); - SmallVector RenderedOperands; const OperandMatcher &OM = InsnMatcher.getOperand(DstChild->getName()); - for (unsigned I = 0; I < OM.countTemporaryOperands(); ++I) - RenderedOperands.push_back(OperandPlaceholder::CreateTemporary( - OM.getAllocatedTemporariesBaseID() + I)); DstMIBuilder.addRenderer( - *ComplexPattern->second, DstChild->getName(), RenderedOperands); + *ComplexPattern->second, DstChild->getName(), + OM.getAllocatedTemporariesBaseID()); return Error::success(); } @@ -1925,9 +1857,9 @@ if (PredicateMatchDef->isSubClassOf("GIMatchComplexPattern")) { Record *CPDef = PredicateMatchDef->getValueAsDef("Matcher"); - auto &Predicate = OpndMatcher.addPredicate( - OpndMatcher, *CPDef); - TempOpIdx += Predicate.countTemporaryOperands(); + OpndMatcher.addPredicate(OpndMatcher, + *CPDef); + TempOpIdx++; continue; } @@ -2052,14 +1984,10 @@ if (RendererDef->isSubClassOf("GIAddComplexOperand")) { std::string Name = RendererDef->getValueAsString("Name"); - SmallVector RenderedOperands; const OperandMatcher &OM = Matcher.insnmatcher_front().getOperand(Name); - for (unsigned I = 0; I < OM.countTemporaryOperands(); ++I) - RenderedOperands.push_back(OperandPlaceholder::CreateTemporary( - OM.getAllocatedTemporariesBaseID() + I)); - Action.addRenderer( - *RendererDef->getValueAsDef("Matcher"), Name, RenderedOperands); + *RendererDef->getValueAsDef("Matcher"), Name, + OM.getAllocatedTemporariesBaseID()); continue; } @@ -2093,16 +2021,16 @@ unsigned MaxTemporaries = 0; for (const auto &Rule : Rules) - MaxTemporaries = std::max(MaxTemporaries, Rule.countTemporaryOperands()); + MaxTemporaries = std::max(MaxTemporaries, Rule.countRendererFns()); OS << "#ifdef GET_GLOBALISEL_TEMPORARIES_DECL\n"; for (unsigned I = 0; I < MaxTemporaries; ++I) - OS << " mutable MachineOperand TempOp" << I << ";\n"; + OS << " mutable OptionalComplexRendererFn Renderer" << I << ";\n"; OS << "#endif // ifdef GET_GLOBALISEL_TEMPORARIES_DECL\n\n"; OS << "#ifdef GET_GLOBALISEL_TEMPORARIES_INIT\n"; for (unsigned I = 0; I < MaxTemporaries; ++I) - OS << ", TempOp" << I << "(MachineOperand::CreatePlaceholder())\n"; + OS << ", Renderer" << I << "(None)\n"; OS << "#endif // ifdef GET_GLOBALISEL_TEMPORARIES_INIT\n\n"; OS << "#ifdef GET_GLOBALISEL_IMPL\n";