diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -204,6 +204,11 @@
   iterator begin() const { return nodes.begin(); }
   iterator end() const { return nodes.end(); }
 
+  /// Returns true if the call graph contains a cycle of call edges through at
+  /// least two distinct callables (i.e. indirect or mutual recursion). Edges
+  /// that are not call edges, and direct self-recursive calls, are ignored.
+  bool isRecursive() const;
+
   /// Dump the graph in a human readable format.
   void dump() const;
   void print(raw_ostream &os) const;
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -166,6 +166,60 @@
   nodes.erase(node->getCallableRegion());
 }
 
+/// Returns true if a cycle of call edges between distinct callables is
+/// reachable from `root`. This is an iterative depth-first search using the
+/// usual coloring scheme: `onStack` holds the nodes on the current DFS path
+/// and `finished` holds the nodes whose reachable subgraph is fully explored.
+/// Reaching a node that is still on the path closes a cycle. An explicit
+/// worklist is used so that long call chains cannot overflow the native stack.
+/// On a false result every node this call added to `onStack` is removed again.
+static bool
+hasCallCycleFrom(const CallGraphNode *root,
+                 llvm::DenseSet<const CallGraphNode *> &onStack,
+                 llvm::DenseSet<const CallGraphNode *> &finished) {
+  // Each entry pairs a node on the current DFS path with the next outgoing
+  // edge of that node that still has to be examined.
+  SmallVector<std::pair<const CallGraphNode *, CallGraphNode::iterator>>
+      worklist;
+  onStack.insert(root);
+  worklist.emplace_back(root, root->begin());
+  while (!worklist.empty()) {
+    const CallGraphNode *node = worklist.back().first;
+    CallGraphNode::iterator &nextEdge = worklist.back().second;
+    if (nextEdge == node->end()) {
+      // Every successor has been explored; the node leaves the DFS path.
+      onStack.erase(node);
+      finished.insert(node);
+      worklist.pop_back();
+      continue;
+    }
+    const CallGraphNode::Edge &edge = *nextEdge++;
+    const CallGraphNode *target = edge.getTarget();
+    // Only call edges can form recursion. Direct self-recursion is
+    // intentionally excluded for now.
+    if (!edge.isCall() || target == node)
+      continue;
+    if (onStack.contains(target))
+      return true;
+    if (!finished.contains(target)) {
+      onStack.insert(target);
+      worklist.emplace_back(target, target->begin());
+    }
+  }
+  return false;
+}
+
+bool CallGraph::isRecursive() const {
+  llvm::DenseSet<const CallGraphNode *> onStack, finished;
+  for (const auto &it : nodes) {
+    const CallGraphNode *node = it.second.get();
+    // Nodes already finished were fully explored by an earlier search.
+    if (!finished.contains(node) && hasCallCycleFrom(node, onStack, finished))
+      return true;
+  }
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // Printing
 
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -607,12 +607,20 @@
   SymbolTableCollection symbolTable;
   Inliner inliner(context, cg, symbolTable);
   CGUseList useList(getOperation(), cg, symbolTable);
-  LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
-    return inlineSCC(inliner, useList, scc, context);
-  });
-  if (failed(result))
-    return signalPassFailure();
-
+  // Inlining through indirect (mutual) recursion is the scenario reported in
+  // https://github.com/llvm/llvm-project/issues/53492, so the SCC walk is
+  // skipped entirely when the call graph contains such a cycle.
+  // TODO: This bails out for the whole module. Restricting the bailout to the
+  // SCCs that actually contain the cycle would keep inlining unrelated
+  // callables.
+  if (!cg.isRecursive()) {
+    LogicalResult result = runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
+      return inlineSCC(inliner, useList, scc, context);
+    });
+    if (failed(result))
+      return signalPassFailure();
+  }
+
   // After inlining, make sure to erase any callables proven to be dead.
   inliner.eraseDeadCallables();
 }
diff --git a/mlir/test/Transforms/bugfix53492_inlining.mlir b/mlir/test/Transforms/bugfix53492_inlining.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/bugfix53492_inlining.mlir
@@ -0,0 +1,32 @@
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
+// Regression test for https://github.com/llvm/llvm-project/issues/53492:
+// calls between mutually recursive functions must be left in place.
+
+// CHECK-LABEL: func.func @test_indirect_recursion
+func.func @test_indirect_recursion(%arg0 : i32) -> i32 {
+// CHECK-NEXT: arith.constant 0 : i32
+// CHECK-NEXT: arith.cmpi
+// CHECK-NEXT: cf.cond_br
+// CHECK: call @foo_with_recursive_call
+
+  %0 = arith.constant 0 : i32
+  %1 = arith.cmpi eq, %arg0, %0 : i32
+  cf.cond_br %1, ^exit, ^tail
+^exit:
+  return %0 : i32
+^tail:
+  %3 = call @foo_with_recursive_call(%arg0) : (i32) -> i32
+  return %3 : i32
+}
+// CHECK-LABEL: func.func @foo_with_recursive_call
+func.func @foo_with_recursive_call(%arg0 : i32) -> i32 {
+// CHECK-NEXT: arith.constant
+// CHECK-NEXT: arith.subi
+// CHECK-NEXT: call @test_indirect_recursion
+
+  %0 = arith.constant 1 : i32
+  %1 = arith.subi %arg0, %0 : i32
+  %2 = call @test_indirect_recursion(%1) : (i32) -> i32
+  return %2 : i32
+}