Index: llvm/trunk/include/llvm/CodeGen/GlobalISel/InstructionSelector.h =================================================================== --- llvm/trunk/include/llvm/CodeGen/GlobalISel/InstructionSelector.h +++ llvm/trunk/include/llvm/CodeGen/GlobalISel/InstructionSelector.h @@ -67,6 +67,8 @@ bool isOperandImmEqual(const MachineOperand &MO, int64_t Value, const MachineRegisterInfo &MRI) const; + + bool isObviouslySafeToFold(MachineInstr &MI) const; }; } // End namespace llvm. Index: llvm/trunk/include/llvm/Target/GlobalISel/SelectionDAGCompat.td =================================================================== --- llvm/trunk/include/llvm/Target/GlobalISel/SelectionDAGCompat.td +++ llvm/trunk/include/llvm/Target/GlobalISel/SelectionDAGCompat.td @@ -25,6 +25,8 @@ SDNode Node = node; } +def : GINodeEquiv; +def : GINodeEquiv; def : GINodeEquiv; def : GINodeEquiv; def : GINodeEquiv; Index: llvm/trunk/lib/CodeGen/GlobalISel/InstructionSelector.cpp =================================================================== --- llvm/trunk/lib/CodeGen/GlobalISel/InstructionSelector.cpp +++ llvm/trunk/lib/CodeGen/GlobalISel/InstructionSelector.cpp @@ -94,3 +94,8 @@ return *VRegVal == Value; return false; } + +bool InstructionSelector::isObviouslySafeToFold(MachineInstr &MI) const { + return !MI.mayLoadOrStore() && !MI.hasUnmodeledSideEffects() && + MI.implicit_operands().begin() == MI.implicit_operands().end(); +} Index: llvm/trunk/lib/Target/AArch64/AArch64InstructionSelector.cpp =================================================================== --- llvm/trunk/lib/Target/AArch64/AArch64InstructionSelector.cpp +++ llvm/trunk/lib/Target/AArch64/AArch64InstructionSelector.cpp @@ -840,42 +840,6 @@ // operands to use appropriate classes. return constrainSelectedInstRegOperands(I, TII, TRI, RBI); } - case TargetOpcode::G_MUL: { - // Reject the various things we don't support yet. - if (unsupportedBinOp(I, RBI, MRI, TRI)) - return false; - - const unsigned DefReg = I.getOperand(0).getReg(); - const RegisterBank &RB = *RBI.getRegBank(DefReg, MRI, TRI); - - if (RB.getID() != AArch64::GPRRegBankID) { - DEBUG(dbgs() << "G_MUL on bank: " << RB << ", expected: GPR\n"); - return false; - } - - unsigned ZeroReg; - unsigned NewOpc; - if (Ty.isScalar() && Ty.getSizeInBits() <= 32) { - NewOpc = AArch64::MADDWrrr; - ZeroReg = AArch64::WZR; - } else if (Ty == LLT::scalar(64)) { - NewOpc = AArch64::MADDXrrr; - ZeroReg = AArch64::XZR; - } else { - DEBUG(dbgs() << "G_MUL has type: " << Ty << ", expected: " - << LLT::scalar(32) << " or " << LLT::scalar(64) << '\n'); - return false; - } - - I.setDesc(TII.get(NewOpc)); - - I.addOperand(MachineOperand::CreateReg(ZeroReg, /*isDef=*/false)); - - // Now that we selected an opcode, we need to constrain the register - // operands to use appropriate classes. - return constrainSelectedInstRegOperands(I, TII, TRI, RBI); - } - case TargetOpcode::G_FADD: case TargetOpcode::G_FSUB: case TargetOpcode::G_FMUL: Index: llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-int-ext.mir =================================================================== --- llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-int-ext.mir +++ llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-int-ext.mir @@ -3,21 +3,23 @@ --- | target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" - define void @anyext_s64_s32() { ret void } - define void @anyext_s32_s8() { ret void } + define void @anyext_s64_from_s32() { ret void } + define void @anyext_s32_from_s8() { ret void } - define void @zext_s64_s32() { ret void } - define void @zext_s32_s8() { ret void } - define void @zext_s16_s8() { ret void } - - define void @sext_s64_s32() { ret void } - define void @sext_s32_s8() { ret void } - define void @sext_s16_s8() { ret void } + define void @zext_s64_from_s32() { ret void } + define void @zext_s32_from_s16() { ret void } + define void @zext_s32_from_s8() { ret void } + define void @zext_s16_from_s8() { ret void } + + define void @sext_s64_from_s32() { ret void } + define void @sext_s32_from_s16() { ret void } + define void @sext_s32_from_s8() { ret void } + define void @sext_s16_from_s8() { ret void } ... --- -# CHECK-LABEL: name: anyext_s64_s32 -name: anyext_s64_s32 +# CHECK-LABEL: name: anyext_s64_from_s32 +name: anyext_s64_from_s32 legalized: true regBankSelected: true @@ -43,8 +45,8 @@ ... --- -# CHECK-LABEL: name: anyext_s32_s8 -name: anyext_s32_s8 +# CHECK-LABEL: name: anyext_s32_from_s8 +name: anyext_s32_from_s8 legalized: true regBankSelected: true @@ -68,8 +70,8 @@ ... --- -# CHECK-LABEL: name: zext_s64_s32 -name: zext_s64_s32 +# CHECK-LABEL: name: zext_s64_from_s32 +name: zext_s64_from_s32 legalized: true regBankSelected: true @@ -95,8 +97,33 @@ ... --- -# CHECK-LABEL: name: zext_s32_s8 -name: zext_s32_s8 +# CHECK-LABEL: name: zext_s32_from_s16 +name: zext_s32_from_s16 +legalized: true +regBankSelected: true + +# CHECK: registers: +# CHECK-NEXT: - { id: 0, class: gpr32 } +# CHECK-NEXT: - { id: 1, class: gpr32 } +registers: + - { id: 0, class: gpr } + - { id: 1, class: gpr } + +# CHECK: body: +# CHECK: %0 = COPY %w0 +# CHECK: %1 = UBFMWri %0, 0, 15 +body: | + bb.0: + liveins: %w0 + + %0(s16) = COPY %w0 + %1(s32) = G_ZEXT %0 + %w0 = COPY %1 +... + +--- +# CHECK-LABEL: name: zext_s32_from_s8 +name: zext_s32_from_s8 legalized: true regBankSelected: true @@ -120,8 +147,8 @@ ... --- -# CHECK-LABEL: name: zext_s16_s8 -name: zext_s16_s8 +# CHECK-LABEL: name: zext_s16_from_s8 +name: zext_s16_from_s8 legalized: true regBankSelected: true @@ -145,8 +172,8 @@ ... --- -# CHECK-LABEL: name: sext_s64_s32 -name: sext_s64_s32 +# CHECK-LABEL: name: sext_s64_from_s32 +name: sext_s64_from_s32 legalized: true regBankSelected: true @@ -172,8 +199,33 @@ ... --- -# CHECK-LABEL: name: sext_s32_s8 -name: sext_s32_s8 +# CHECK-LABEL: name: sext_s32_from_s16 +name: sext_s32_from_s16 +legalized: true +regBankSelected: true + +# CHECK: registers: +# CHECK-NEXT: - { id: 0, class: gpr32 } +# CHECK-NEXT: - { id: 1, class: gpr32 } +registers: + - { id: 0, class: gpr } + - { id: 1, class: gpr } + +# CHECK: body: +# CHECK: %0 = COPY %w0 +# CHECK: %1 = SBFMWri %0, 0, 15 +body: | + bb.0: + liveins: %w0 + + %0(s16) = COPY %w0 + %1(s32) = G_SEXT %0 + %w0 = COPY %1 +... + +--- +# CHECK-LABEL: name: sext_s32_from_s8 +name: sext_s32_from_s8 legalized: true regBankSelected: true @@ -197,8 +249,8 @@ ... --- -# CHECK-LABEL: name: sext_s16_s8 -name: sext_s16_s8 +# CHECK-LABEL: name: sext_s16_from_s8 +name: sext_s16_from_s8 legalized: true regBankSelected: true Index: llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-muladd.mir =================================================================== --- llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-muladd.mir +++ llvm/trunk/test/CodeGen/AArch64/GlobalISel/select-muladd.mir @@ -0,0 +1,50 @@ +# RUN: llc -O0 -mtriple=aarch64-- -run-pass=instruction-select -verify-machineinstrs -global-isel %s -o - | FileCheck %s + +--- | + target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128" + + define void @SMADDLrrr_gpr() { ret void } +... + +--- +# CHECK-LABEL: name: SMADDLrrr_gpr +name: SMADDLrrr_gpr +legalized: true +regBankSelected: true + +# CHECK: registers: +# CHECK-NEXT: - { id: 0, class: gpr64 } +# CHECK-NEXT: - { id: 1, class: gpr32 } +# CHECK-NEXT: - { id: 2, class: gpr32 } +# CHECK-NEXT: - { id: 3, class: gpr } +# CHECK-NEXT: - { id: 4, class: gpr } +# CHECK-NEXT: - { id: 5, class: gpr } +# CHECK-NEXT: - { id: 6, class: gpr64 } +registers: + - { id: 0, class: gpr } + - { id: 1, class: gpr } + - { id: 2, class: gpr } + - { id: 3, class: gpr } + - { id: 4, class: gpr } + - { id: 5, class: gpr } + - { id: 6, class: gpr } + +# CHECK: body: +# CHECK: %0 = COPY %x0 +# CHECK: %1 = COPY %w1 +# CHECK: %2 = COPY %w2 +# CHECK: %6 = SMADDLrrr %1, %2, %0 +body: | + bb.0: + liveins: %x0, %w1, %w2 + + %0(s64) = COPY %x0 + %1(s32) = COPY %w1 + %2(s32) = COPY %w2 + %3(s64) = G_SEXT %1 + %4(s64) = G_SEXT %2 + %5(s64) = G_MUL %3, %4 + %6(s64) = G_ADD %0, %5 + %x0 = COPY %6 +... + Index: llvm/trunk/test/TableGen/GlobalISelEmitter.td =================================================================== --- llvm/trunk/test/TableGen/GlobalISelEmitter.td +++ llvm/trunk/test/TableGen/GlobalISelEmitter.td @@ -51,6 +51,91 @@ def ADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2), [(set GPR32:$dst, (add GPR32:$src1, GPR32:$src2))]>; +//===- Test a nested instruction match. -----------------------------------===// + +// CHECK-LABEL: if ([&]() { +// CHECK-NEXT: MachineInstr &MI0 = I; +// CHECK-NEXT: if (MI0.getNumOperands() < 3) +// CHECK-NEXT: return false; +// CHECK-NEXT: if (!MI0.getOperand(1).isReg()) +// CHECK-NEXT: return false; +// CHECK-NEXT: MachineInstr &MI1 = *MRI.getVRegDef(MI0.getOperand(1).getReg()); +// CHECK-NEXT: if (MI1.getNumOperands() < 3) +// CHECK-NEXT: return false; +// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_MUL) && +// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* Operand 1 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: (((MI1.getOpcode() == TargetOpcode::G_ADD) && +// CHECK-NEXT: ((/* Operand 0 */ (MRI.getType(MI1.getOperand(0).getReg()) == (LLT::scalar(32))))) && +// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI1.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(1).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI1.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(2).getReg(), MRI, TRI)))))) +// CHECK-NEXT: ))) && +// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(2).getReg(), MRI, TRI)))))) { +// CHECK-NEXT: if (!isObviouslySafeToFold(MI1)) return false; +// CHECK-NEXT: // (mul:i32 (add:i32 GPR32:i32:$src1, GPR32:i32:$src2), GPR32:i32:$src3) => (MULADD:i32 GPR32:i32:$src1, GPR32:i32:$src2, GPR32:i32:$src3) +// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::MULADD)); +// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); +// CHECK-NEXT: MIB.add(MI1.getOperand(1)/*src1*/); +// CHECK-NEXT: MIB.add(MI1.getOperand(2)/*src2*/); +// CHECK-NEXT: MIB.add(MI0.getOperand(2)/*src3*/); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, &MI1, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); +// CHECK-NEXT: I.eraseFromParent(); +// CHECK-NEXT: MachineInstr &NewI = *MIB; +// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); +// CHECK-NEXT: return true; +// CHECK-NEXT: } + +// We also get a second rule by commutativity. +// CHECK-LABEL: if ([&]() { +// CHECK-NEXT: MachineInstr &MI0 = I; +// CHECK-NEXT: if (MI0.getNumOperands() < 3) +// CHECK-NEXT: return false; +// CHECK-NEXT: if (!MI0.getOperand(2).isReg()) +// CHECK-NEXT: return false; +// CHECK-NEXT: MachineInstr &MI1 = *MRI.getVRegDef(MI0.getOperand(2).getReg()); +// CHECK-NEXT: if (MI1.getNumOperands() < 3) +// CHECK-NEXT: return false; +// CHECK-NEXT: if ((MI0.getOpcode() == TargetOpcode::G_MUL) && +// CHECK-NEXT: ((/* dst */ (MRI.getType(MI0.getOperand(0).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(0).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src3 */ (MRI.getType(MI0.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI0.getOperand(1).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* Operand 2 */ (MRI.getType(MI0.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: (((MI1.getOpcode() == TargetOpcode::G_ADD) && +// CHECK-NEXT: ((/* Operand 0 */ (MRI.getType(MI1.getOperand(0).getReg()) == (LLT::scalar(32))))) && +// CHECK-NEXT: ((/* src1 */ (MRI.getType(MI1.getOperand(1).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(1).getReg(), MRI, TRI))))) && +// CHECK-NEXT: ((/* src2 */ (MRI.getType(MI1.getOperand(2).getReg()) == (LLT::scalar(32))) && +// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(MI1.getOperand(2).getReg(), MRI, TRI)))))) +// CHECK-NEXT: )))) { +// CHECK-NEXT: if (!isObviouslySafeToFold(MI1)) return false; +// CHECK-NEXT: // (mul:i32 GPR32:i32:$src3, (add:i32 GPR32:i32:$src1, GPR32:i32:$src2)) => (MULADD:i32 GPR32:i32:$src1, GPR32:i32:$src2, GPR32:i32:$src3) +// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::MULADD)); +// CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); +// CHECK-NEXT: MIB.add(MI1.getOperand(1)/*src1*/); +// CHECK-NEXT: MIB.add(MI1.getOperand(2)/*src2*/); +// CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src3*/); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, &MI1, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); +// CHECK-NEXT: I.eraseFromParent(); +// CHECK-NEXT: MachineInstr &NewI = *MIB; +// CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); +// CHECK-NEXT: return true; +// CHECK-NEXT: } + +def MULADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2, GPR32:$src3), + [(set GPR32:$dst, + (mul (add GPR32:$src1, GPR32:$src2), GPR32:$src3))]>; + +//===- Test another simple pattern with regclass operands. ----------------===// + // CHECK-LABEL: if ([&]() { // CHECK-NEXT: MachineInstr &MI0 = I; // CHECK-NEXT: if (MI0.getNumOperands() < 3) @@ -67,7 +152,9 @@ // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.add(MI0.getOperand(2)/*src2*/); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*src1*/); -// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end()); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); // CHECK-NEXT: I.eraseFromParent(); // CHECK-NEXT: MachineInstr &NewI = *MIB; // CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); @@ -100,7 +187,9 @@ // CHECK-NEXT: MIB.add(MI0.getOperand(0)/*dst*/); // CHECK-NEXT: MIB.addReg(MyTarget::R0); // CHECK-NEXT: MIB.add(MI0.getOperand(1)/*Wm*/); -// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end()); +// CHECK-NEXT: for (const auto *FromMI : {&MI0, }) +// CHECK-NEXT: for (const auto &MMO : FromMI->memoperands()) +// CHECK-NEXT: MIB.addMemOperand(MMO); // CHECK-NEXT: I.eraseFromParent(); // CHECK-NEXT: MachineInstr &NewI = *MIB; // CHECK-NEXT: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI); Index: llvm/trunk/utils/TableGen/GlobalISelEmitter.cpp =================================================================== --- llvm/trunk/utils/TableGen/GlobalISelEmitter.cpp +++ llvm/trunk/utils/TableGen/GlobalISelEmitter.cpp @@ -152,6 +152,7 @@ //===- Matchers -----------------------------------------------------------===// +class OperandMatcher; class MatchAction; /// Generates code to check that a match rule matches. @@ -187,6 +188,7 @@ StringRef Value); StringRef getInsnVarName(const InstructionMatcher &InsnMatcher) const; + void emitCxxCapturedInsnList(raw_ostream &OS); void emitCxxCaptureStmts(raw_ostream &OS, StringRef Expr); void emit(raw_ostream &OS); @@ -257,6 +259,7 @@ /// are represented by a virtual register defined by a G_CONSTANT instruction. enum PredicateKind { OPM_ComplexPattern, + OPM_Instruction, OPM_Int, OPM_LLT, OPM_RegBank, @@ -272,6 +275,23 @@ PredicateKind getKind() const { return Kind; } + /// Return the OperandMatcher for the specified operand or nullptr if there + /// isn't one by that name in this operand predicate matcher. + /// + /// InstructionOperandMatcher is the only subclass that can return non-null + /// for this. + virtual Optional + getOptionalOperand(const StringRef SymbolicName) const { + assert(!SymbolicName.empty() && "Cannot lookup unnamed operand"); + return None; + } + + /// Emit C++ statements to capture instructions into local variables. + /// + /// Only InstructionOperandMatcher needs to do anything for this method. + virtual void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule, + StringRef Expr) const {} + /// Emit a C++ expression that checks the predicate for the given operand. virtual void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, StringRef OperandExpr) const = 0; @@ -422,8 +442,28 @@ return (InsnVarName + ".getOperand(" + llvm::to_string(OpIdx) + ")").str(); } + Optional + getOptionalOperand(StringRef DesiredSymbolicName) const { + assert(!DesiredSymbolicName.empty() && "Cannot lookup unnamed operand"); + if (DesiredSymbolicName == SymbolicName) + return this; + for (const auto &OP : predicates()) { + const auto &MaybeOperand = OP->getOptionalOperand(DesiredSymbolicName); + if (MaybeOperand.hasValue()) + return MaybeOperand.getValue(); + } + return None; + } + InstructionMatcher &getInstructionMatcher() const { return Insn; } + /// Emit C++ statements to capture instructions into local variables. + void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule, + StringRef OperandExpr) const { + for (const auto &Predicate : predicates()) + Predicate->emitCxxCaptureStmts(OS, Rule, OperandExpr); + } + /// Emit a C++ expression that tests whether the instruction named in /// InsnVarName matches all the predicate and all the operands. void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, @@ -581,14 +621,14 @@ llvm_unreachable("Failed to lookup operand"); } - Optional getOptionalOperand(StringRef SymbolicName) const { + Optional + getOptionalOperand(StringRef SymbolicName) const { assert(!SymbolicName.empty() && "Cannot lookup unnamed operand"); - const auto &I = std::find_if(Operands.begin(), Operands.end(), - [&SymbolicName](const OperandMatcher &X) { - return X.getSymbolicName() == SymbolicName; - }); - if (I != Operands.end()) - return &*I; + for (const auto &Operand : Operands) { + const auto &OM = Operand.getOptionalOperand(SymbolicName); + if (OM.hasValue()) + return OM.getValue(); + } return None; } @@ -600,6 +640,11 @@ } unsigned getNumOperands() const { return Operands.size(); } + OperandVec::iterator operands_begin() { return Operands.begin(); } + OperandVec::iterator operands_end() { return Operands.end(); } + iterator_range operands() { + return make_range(operands_begin(), operands_end()); + } OperandVec::const_iterator operands_begin() const { return Operands.begin(); } OperandVec::const_iterator operands_end() const { return Operands.end(); } iterator_range operands() const { @@ -608,12 +653,12 @@ /// Emit C++ statements to check the shape of the match and capture /// instructions into local variables. - /// - /// TODO: When nested instruction matching is implemented, this function will - /// descend into the operands and capture variables. void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule, StringRef Expr) { OS << "if (" << Expr << ".getNumOperands() < " << getNumOperands() << ")\n" << " return false;\n"; + for (const auto &Operand : Operands) { + Operand.emitCxxCaptureStmts(OS, Rule, Operand.getOperandExpr(Expr)); + } } /// Emit a C++ expression that tests whether the instruction named in @@ -671,6 +716,55 @@ } }; +/// Generates code to check that the operand is a register defined by an +/// instruction that matches the given instruction matcher. +/// +/// For example, the pattern: +/// (set $dst, (G_MUL (G_ADD $src1, $src2), $src3)) +/// would use an InstructionOperandMatcher for operand 1 of the G_MUL to match +/// the: +/// (G_ADD $src1, $src2) +/// subpattern. +class InstructionOperandMatcher : public OperandPredicateMatcher { +protected: + std::unique_ptr InsnMatcher; + +public: + InstructionOperandMatcher() + : OperandPredicateMatcher(OPM_Instruction), + InsnMatcher(new InstructionMatcher()) {} + + static bool classof(const OperandPredicateMatcher *P) { + return P->getKind() == OPM_Instruction; + } + + InstructionMatcher &getInsnMatcher() const { return *InsnMatcher; } + + Optional + getOptionalOperand(StringRef SymbolicName) const override { + assert(!SymbolicName.empty() && "Cannot lookup unnamed operand"); + return InsnMatcher->getOptionalOperand(SymbolicName); + } + + void emitCxxCaptureStmts(raw_ostream &OS, RuleMatcher &Rule, + StringRef OperandExpr) const override { + OS << "if (!" << OperandExpr + ".isReg())\n" + << " return false;\n"; + std::string InsnVarName = Rule.defineInsnVar( + OS, *InsnMatcher, + ("*MRI.getVRegDef(" + OperandExpr + ".getReg())").str()); + InsnMatcher->emitCxxCaptureStmts(OS, Rule, InsnVarName); + } + + void emitCxxPredicateExpr(raw_ostream &OS, RuleMatcher &Rule, + StringRef OperandExpr) const override { + OperandExpr = Rule.getInsnVarName(*InsnMatcher); + OS << "("; + InsnMatcher->emitCxxPredicateExpr(OS, Rule, OperandExpr); + OS << ")\n"; + } +}; + //===- Actions ------------------------------------------------------------===// void OperandPlaceholder::emitCxxValueExpr(raw_ostream &OS) const { switch (Kind) { @@ -878,7 +972,11 @@ << I->Namespace << "::" << I->TheDef->getName() << "));\n"; for (const auto &Renderer : OperandRenderers) Renderer->emitCxxRenderStmts(OS, Rule); - OS << " MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());\n"; + OS << " for (const auto *FromMI : "; + Rule.emitCxxCapturedInsnList(OS); + OS << ")\n"; + OS << " for (const auto &MMO : FromMI->memoperands())\n"; + OS << " MIB.addMemOperand(MMO);\n"; OS << " " << RecycleVarName << ".eraseFromParent();\n"; OS << " MachineInstr &NewI = *MIB;\n"; } @@ -911,6 +1009,14 @@ llvm_unreachable("Matched Insn was not captured in a local variable"); } +/// Emit a C++ initializer_list containing references to every matched instruction. +void RuleMatcher::emitCxxCapturedInsnList(raw_ostream &OS) { + OS << "{"; + for (const auto &Pair : InsnVariableNames) + OS << "&" << Pair.second << ", "; + OS << "}"; +} + /// Emit C++ statements to check the shape of the match and capture /// instructions into local variables. void RuleMatcher::emitCxxCaptureStmts(raw_ostream &OS, StringRef Expr) { @@ -942,6 +1048,55 @@ getInsnVarName(*Matchers.front())); OS << ") {\n"; + // We must also check if it's safe to fold the matched instructions. + if (InsnVariableNames.size() >= 2) { + for (const auto &Pair : InsnVariableNames) { + // Skip the root node since it isn't moving anywhere. Everything else is + // sinking to meet it. + if (Pair.first == Matchers.front().get()) + continue; + + // Reject the difficult cases until we have a more accurate check. + OS << " if (!isObviouslySafeToFold(" << Pair.second + << ")) return false;\n"; + + // FIXME: Emit checks to determine it's _actually_ safe to fold and/or + // account for unsafe cases. + // + // Example: + // MI1--> %0 = ... + // %1 = ... %0 + // MI0--> %2 = ... %0 + // It's not safe to erase MI1. We currently handle this by not + // erasing %0 (even when it's dead). + // + // Example: + // MI1--> %0 = load volatile @a + // %1 = load volatile @a + // MI0--> %2 = ... %0 + // It's not safe to sink %0's def past %1. We currently handle + // this by rejecting all loads. + // + // Example: + // MI1--> %0 = load @a + // %1 = store @a + // MI0--> %2 = ... %0 + // It's not safe to sink %0's def past %1. We currently handle + // this by rejecting all loads. + // + // Example: + // G_CONDBR %cond, @BB1 + // BB0: + // MI1--> %0 = load @a + // G_BR @BB1 + // BB1: + // MI0--> %2 = ... %0 + // It's not always safe to sink %0 across control flow. In this + // case it may introduce a memory fault. We currentl handle this + // by rejecting all loads. + } + } + for (const auto &MA : Actions) { MA->emitCxxActionStmts(OS, *this, "I"); } @@ -1123,8 +1278,6 @@ return Error::success(); } } - - return failedImport("Src child operand is an unsupported type"); } auto OpTyOrNone = MVTToLLT(ChildTypes.front().getConcrete()); @@ -1132,6 +1285,19 @@ return failedImport("Src operand has an unsupported type"); OM.addPredicate(*OpTyOrNone); + // Check for nested instructions. + if (!SrcChild->isLeaf()) { + // Map the node to a gMIR instruction. + InstructionOperandMatcher &InsnOperand = + OM.addPredicate(); + auto InsnMatcherOrError = + createAndImportSelDAGMatcher(InsnOperand.getInsnMatcher(), SrcChild); + if (auto Error = InsnMatcherOrError.takeError()) + return Error; + + return Error::success(); + } + // Check for constant immediates. if (auto *ChildInt = dyn_cast(SrcChild->getLeafValue())) { OM.addPredicate(ChildInt->getValue()); @@ -1290,6 +1456,7 @@ if (!isTrivialOperatorNode(Src)) return failedImport("Src pattern root isn't a trivial operator"); + // Start with the defined operands (i.e., the results of the root operator). Record *DstOp = Dst->getOperator(); if (!DstOp->isSubClassOf("Instruction")) return failedImport("Pattern operator isn't an instruction");