diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -364,7 +364,35 @@
 //===----------------------------------------------------------------------===//
 // Inliner
 //===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+/// Return a printable name for the callee of `op`, for use in debug output.
+static std::string getNodeName(CallOpInterface op) {
+  auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  if (sym && sym.isa<FlatSymbolRefAttr>())
+    return sym.cast<FlatSymbolRefAttr>().getValue().str();
+  return "_unnamed_callee_";
+}
+#endif
+
+/// Return true if the specified `inlineHistoryID` indicates an inline history
+/// that already includes `node`.
+static bool inlineHistoryIncludes(
+    CallGraphNode *node, Optional<size_t> inlineHistoryID,
+    const SmallVectorImpl<std::pair<CallGraphNode *, Optional<size_t>>>
+        &inlineHistory) {
+  while (inlineHistoryID.has_value()) {
+    assert(inlineHistoryID.value() < inlineHistory.size() &&
+           "Invalid inline history ID");
+    if (inlineHistory[inlineHistoryID.value()].first == node)
+      return true;
+    inlineHistoryID = inlineHistory[inlineHistoryID.value()].second;
+  }
+  return false;
+}
+
 namespace {
+
 /// This class provides a specialization of the main inlining interface.
 struct Inliner : public InlinerInterface {
   Inliner(MLIRContext *context, CallGraph &cg,
@@ -454,6 +482,20 @@
     }
   }
 
+  // When inlining a callee produces new call sites, we want to keep track of
+  // the fact that they were inlined from the callee. This allows us to avoid
+  // infinite inlining.
+  using InlineHistoryT = Optional<size_t>;
+  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
+  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+    for (unsigned i = 0; i < calls.size(); ++i)
+      llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
+    llvm::dbgs() << "}\n";
+  });
+
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
   bool inlinedAnyCalls = false;
@@ -461,16 +503,22 @@
     if (deadNodes.contains(calls[i].sourceNode))
       continue;
     ResolvedCall it = calls[i];
-    bool doInline = shouldInline(it);
+
+    InlineHistoryT inlineHistoryID = callHistory[i];
+    bool inHistory =
+        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
+    bool doInline = !inHistory && shouldInline(it);
     CallOpInterface call = it.call;
     LLVM_DEBUG({
       if (doInline)
-        llvm::dbgs() << "* Inlining call: " << call << "\n";
+        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
       else
-        llvm::dbgs() << "* Not inlining call: " << call << "\n";
+        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
     });
     if (!doInline)
       continue;
+
+    unsigned prevSize = calls.size();
     Region *targetRegion = it.targetNode->getCallableRegion();
 
     // If this is the last call to the target node and the node is discardable,
@@ -486,6 +534,28 @@
     }
     inlinedAnyCalls = true;
 
+    // Create an inline history entry for this inlined call, so that we
+    // remember that new callsites came about due to inlining Callee.
+    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
+    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
+
+    auto hisStr = [](InlineHistoryT h) {
+      return h.has_value() ? std::to_string(h.value()) : "root";
+    };
+    (void)hisStr; // Only used inside LLVM_DEBUG; avoid NDEBUG warnings.
+    LLVM_DEBUG(llvm::dbgs() << "* new inlineHistory entry: "
+                            << newInlineHistoryID << ". [" << getNodeName(call)
+                            << ", " << hisStr(inlineHistoryID) << "]\n");
+
+    for (unsigned k = prevSize; k != calls.size(); ++k) {
+      callHistory.push_back(newInlineHistoryID);
+      LLVM_DEBUG(llvm::dbgs()
+                 << "* new call " << k << " {" << calls[k].call
+                 << "}\n with historyID = " << newInlineHistoryID
+                 << ", added due to inlining of\n call {" << call
+                 << "}\n with historyID = " << hisStr(inlineHistoryID) << "\n");
+    }
+
     // If the inlining was successful, Merge the new uses into the source node.
     useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
     useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
diff --git a/mlir/test/Transforms/bugfix53492_inlining.mlir b/mlir/test/Transforms/bugfix53492_inlining.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/bugfix53492_inlining.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
+// Test BugFix: https://github.com/llvm/llvm-project/issues/53492
+
+// CHECK-LABEL: func.func @foo0
+func.func @foo0(%arg0 : i32) -> i32 {
+// CHECK: call @foo1
+// CHECK: }
+  %0 = arith.constant 0 : i32
+  %1 = arith.cmpi eq, %arg0, %0 : i32
+  cf.cond_br %1, ^exit, ^tail
+^exit:
+  return %0 : i32
+^tail:
+  %3 = call @foo1(%arg0) : (i32) -> i32
+  return %3 : i32
+}
+
+// CHECK-LABEL: func.func @foo1
+func.func @foo1(%arg0 : i32) -> i32 {
+// CHECK: call @foo1
+  %0 = arith.constant 1 : i32
+  %1 = arith.subi %arg0, %0 : i32
+  %2 = call @foo0(%1) : (i32) -> i32
+  return %2 : i32
+}