diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -365,6 +365,32 @@
 // Inliner
 //===----------------------------------------------------------------------===//
 namespace {
+
+#ifndef NDEBUG
+/// Return the symbol stored in the `callee` attribute of `op` (debug only).
+static std::string getNodeName(CallOpInterface op) {
+  if (auto fnAttr = op->getAttrOfType<FlatSymbolRefAttr>("callee"))
+    return fnAttr.getValue().str();
+  return "_unnamed_callee_";
+}
+#endif
+
+/// Return true if the specified `inlineHistoryID` indicates
+/// an inline history that already includes `N`. Each history entry records
+/// the inlined callee and the ID of its parent entry; -1 ends the chain.
+static bool inlineHistoryIncludes(
+    CallGraphNode *N, int inlineHistoryID,
+    const SmallVectorImpl<std::pair<CallGraphNode *, int>> &inlineHistory) {
+  while (inlineHistoryID != -1) {
+    assert(unsigned(inlineHistoryID) < inlineHistory.size() &&
+           "Invalid inline history ID");
+    if (inlineHistory[inlineHistoryID].first == N)
+      return true;
+    inlineHistoryID = inlineHistory[inlineHistoryID].second;
+  }
+  return false;
+}
+
 /// This class provides a specialization of the main inlining interface.
 struct Inliner : public InlinerInterface {
   Inliner(MLIRContext *context, CallGraph &cg,
@@ -454,6 +480,19 @@
     }
   }
 
+  // When inlining a callee produces new call sites, we want to keep track of
+  // the fact that they were inlined from the callee. This allows us to avoid
+  // infinite inlining.
+  SmallVector<std::pair<CallGraphNode *, int>, 8> inlineHistory;
+  std::vector<int> callHistory(calls.size(), -1); // -1: no inline history
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+    for (unsigned i = 0; i < calls.size(); ++i)
+      llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
+    llvm::dbgs() << "}\n";
+  });
+
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
   bool inlinedAnyCalls = false;
@@ -461,16 +500,22 @@
     if (deadNodes.contains(calls[i].sourceNode))
       continue;
     ResolvedCall it = calls[i];
-    bool doInline = shouldInline(it);
+
+    int inlineHistoryID = callHistory[i];
+    bool inHistory =
+        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
+    bool doInline = !inHistory && shouldInline(it);
     CallOpInterface call = it.call;
     LLVM_DEBUG({
       if (doInline)
-        llvm::dbgs() << "* Inlining call: " << call << "\n";
+        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
       else
-        llvm::dbgs() << "* Not inlining call: " << call << "\n";
+        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
     });
     if (!doInline)
       continue;
+
+    unsigned prevSize = calls.size();
     Region *targetRegion = it.targetNode->getCallableRegion();
 
     // If this is the last call to the target node and the node is discardable,
@@ -486,6 +531,23 @@
     }
     inlinedAnyCalls = true;
 
+    // Create an inline history entry for this inlined call, so that we
+    // remember that new call sites came about due to inlining the callee.
+    int newInlineHistoryID = static_cast<int>(inlineHistory.size());
+    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
+    LLVM_DEBUG(llvm::dbgs()
+               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+               << getNodeName(call) << ", " << inlineHistoryID << "]\n");
+
+    for (unsigned k = prevSize; k != calls.size(); ++k) {
+      callHistory.push_back(newInlineHistoryID);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "* new call " << k << " {" << calls[k].call
+                 << "}\n with historyID = " << newInlineHistoryID
+                 << ", added due to inlining of\n call {" << call
+                 << "}\n with historyID = " << inlineHistoryID << "\n");
+    }
+
     // If the inlining was successful, Merge the new uses into the source node.
     useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
     useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
diff --git a/mlir/test/Transforms/bugfix53492_inlining.mlir b/mlir/test/Transforms/bugfix53492_inlining.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/bugfix53492_inlining.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
+// Test BugFix: https://github.com/llvm/llvm-project/issues/53492
+// Mutual recursion (foo0 <-> foo1) must not be inlined without bound.
+
+// CHECK-LABEL: func.func @foo0
+func.func @foo0(%arg0 : i32) -> i32 {
+// CHECK: call @foo1
+// CHECK: }
+  %0 = arith.constant 0 : i32
+  %1 = arith.cmpi eq, %arg0, %0 : i32
+  cf.cond_br %1, ^exit, ^tail
+^exit:
+  return %0 : i32
+^tail:
+  %3 = call @foo1(%arg0) : (i32) -> i32
+  return %3 : i32
+}
+
+// CHECK-LABEL: func.func @foo1
+func.func @foo1(%arg0 : i32) -> i32 {
+// CHECK: call @foo1
+  %0 = arith.constant 1 : i32
+  %1 = arith.subi %arg0, %0 : i32
+  %2 = call @foo0(%1) : (i32) -> i32
+  return %2 : i32
+}