diff --git a/mlir/include/mlir/Analysis/CallInterfaces.td b/mlir/include/mlir/Analysis/CallInterfaces.td --- a/mlir/include/mlir/Analysis/CallInterfaces.td +++ b/mlir/include/mlir/Analysis/CallInterfaces.td @@ -54,29 +54,23 @@ be a target for a call-like operation (those providing the CallOpInterface above). These operations may be traditional functional operation `func @foo(...)`, as well as function producing operations - `%foo = dialect.create_function(...)`. These operations may produce multiple - callable regions, or subroutines. + `%foo = dialect.create_function(...)`. These operations may only contain a + single region, or subroutine. }]; let methods = [ InterfaceMethod<[{ - Returns a region on the current operation that the given callable refers - to. This may return null in the case of an external callable object, - e.g. an external function. + Returns the region on the current operation that is callable. This may + return null in the case of an external callable object, e.g. an external + function. }], - "Region *", "getCallableRegion", (ins "CallInterfaceCallable":$callable) + "Region *", "getCallableRegion" >, InterfaceMethod<[{ - Returns all of the callable regions of this operation. - }], - "void", "getCallableRegions", - (ins "SmallVectorImpl &":$callables) - >, - InterfaceMethod<[{ - Returns the results types that the given callable region produces when + Returns the results types that the callable region produces when executed. }], - "ArrayRef", "getCallableResults", (ins "Region *":$callable) + "ArrayRef", "getCallableResults" >, ]; } diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h --- a/mlir/include/mlir/IR/Function.h +++ b/mlir/include/mlir/IR/Function.h @@ -122,26 +122,13 @@ // CallableOpInterface //===--------------------------------------------------------------------===// - /// Returns a region on the current operation that the given callable refers - /// to. 
This may return null in the case of an external callable object, e.g. - /// an external function. - Region *getCallableRegion(CallInterfaceCallable callable) { - assert(callable.get().getLeafReference() == getName()); - return isExternal() ? nullptr : &getBody(); - } - - /// Returns all of the callable regions of this operation. - void getCallableRegions(SmallVectorImpl &callables) { - if (!isExternal()) - callables.push_back(&getBody()); - } + /// Returns the region on the current operation that is callable. This may + /// return null in the case of an external callable object, e.g. an external + /// function. + Region *getCallableRegion() { return isExternal() ? nullptr : &getBody(); } - /// Returns the results types that the given callable region produces when - /// executed. - ArrayRef getCallableResults(Region *region) { - assert(!isExternal() && region == &getBody() && "invalid callable"); - return getType().getResults(); - } + /// Returns the results types that the callable region produces when executed. + ArrayRef getCallableResults() { return getType().getResults(); } private: // This trait needs access to the hooks defined below. diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -74,67 +74,38 @@ /// Recursively compute the callgraph edges for the given operation. Computed /// edges are placed into the given callgraph object. static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode); - -/// Compute the set of callgraph nodes that are created by regions nested within -/// 'op'. -static void computeCallables(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { - if (op->getNumRegions() == 0) + CallGraphNode *parentNode, bool resolveCalls) { + if (CallOpInterface call = dyn_cast(op)) { + // If there is no parent node, we ignore this operation. 
Even if this + // operation was a call, there would be no callgraph node to attribute it + // to. + if (!resolveCalls || !parentNode) + return; + parentNode->addCallEdge( + cg.resolveCallable(call.getCallableForCallee(), op)); + return; - if (auto callableOp = dyn_cast(op)) { - SmallVector callables; - callableOp.getCallableRegions(callables); - for (auto *callableRegion : callables) - cg.getOrAddNode(callableRegion, parentNode); } -} -/// Recursively compute the callgraph edges within the given region. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Region &region, CallGraph &cg, - CallGraphNode *parentNode) { - // Iterate over the nested operations twice: - /// One to fully create nodes in the for each callable region of a nested - /// operation; - for (auto &block : region) - for (auto &nested : block) - computeCallables(&nested, cg, parentNode); - - /// And another to recursively compute the callgraph. - for (auto &block : region) - for (auto &nested : block) - computeCallGraph(&nested, cg, parentNode); -} - -/// Recursively compute the callgraph edges for the given operation. Computed -/// edges are placed into the given callgraph object. -static void computeCallGraph(Operation *op, CallGraph &cg, - CallGraphNode *parentNode) { // Compute the callgraph nodes and edges for each of the nested operations. - auto isCallable = isa(op); - for (auto &region : op->getRegions()) { - // Check to see if this region is a callable node, if so this is the parent - // node of the nested region. - CallGraphNode *nestedParentNode; - if (!isCallable || !(nestedParentNode = cg.lookupNode(&region))) - nestedParentNode = parentNode; - computeCallGraph(region, cg, nestedParentNode); + if (CallableOpInterface callable = dyn_cast(op)) { + if (auto *callableRegion = callable.getCallableRegion()) + parentNode = cg.getOrAddNode(callableRegion, parentNode); + else + return; } - // If there is no parent node, we ignore this operation.
Even if this - // operation was a call, there would be no callgraph node to attribute it to. - if (!parentNode) - return; - - // If this is a call operation, resolve the callee. - if (auto call = dyn_cast(op)) - parentNode->addCallEdge( - cg.resolveCallable(call.getCallableForCallee(), op)); + for (Region &region : op->getRegions()) + for (Block &block : region) + for (Operation &nested : block) + computeCallGraph(&nested, cg, parentNode, resolveCalls); } CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) { - computeCallGraph(op, *this, /*parentNode=*/nullptr); + // Make two passes over the graph, one to compute the callables and one to + // resolve the calls. We split these up as we may have nested callable objects + // that need to be reserved before the calls. + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false); + computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true); } /// Get or add a call graph node for the given region. @@ -175,9 +146,7 @@ // Get the callee operation from the callable. Operation *callee; if (auto symbolRef = callable.dyn_cast()) - // TODO(riverriddle) Support nested references. - callee = SymbolTable::lookupNearestSymbolFrom(from, - symbolRef.getRootReference()); + callee = SymbolTable::lookupNearestSymbolFrom(from, symbolRef); else callee = callable.get()->getDefiningOp(); @@ -185,7 +154,7 @@ // called region from it.
if (callee && callee->getNumRegions()) { if (auto callableOp = dyn_cast_or_null(callee)) { - if (auto *node = lookupNode(callableOp.getCallableRegion(callable))) + if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; } } diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -86,8 +86,15 @@ while (!worklist.empty()) { for (Operation &op : *worklist.pop_back_val()) { if (auto call = dyn_cast(op)) { - CallGraphNode *node = - cg.resolveCallable(call.getCallableForCallee(), &op); + CallInterfaceCallable callable = call.getCallableForCallee(); + + // TODO(riverriddle) Support inlining nested call references. + if (SymbolRefAttr symRef = callable.dyn_cast()) { + if (!symRef.isa()) + continue; + } + + CallGraphNode *node = cg.resolveCallable(callable, &op); if (!node->isExternal()) calls.emplace_back(call, node); continue; @@ -274,6 +281,15 @@ CallGraph &cg = getAnalysis(); auto *context = &getContext(); + // The inliner should only be run on operations that define a symbol table, + // as the callgraph will need to resolve references. + Operation *op = getOperation(); + if (!op->hasTrait()) { + op->emitOpError() << " was scheduled to run under the inliner, but does " + "not define a symbol table"; + return signalPassFailure(); + } + // Collect a set of canonicalization patterns to use when simplifying // callable regions within an SCC. 
OwningRewritePatternList canonPatterns; diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -284,7 +284,7 @@ if (src->empty()) return failure(); auto *entryBlock = &src->front(); - ArrayRef callableResultTypes = callable.getCallableResults(src); + ArrayRef callableResultTypes = callable.getCallableResults(); // Make sure that the number of arguments and results matchup between the call // and the region. diff --git a/mlir/test/Analysis/test-callgraph.mlir b/mlir/test/Analysis/test-callgraph.mlir --- a/mlir/test/Analysis/test-callgraph.mlir +++ b/mlir/test/Analysis/test-callgraph.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-print-callgraph 2>&1 | FileCheck %s --dump-input-on-failure +// RUN: mlir-opt %s -test-print-callgraph -split-input-file 2>&1 | FileCheck %s --dump-input-on-failure // CHECK-LABEL: Testing : "simple" module attributes {test.name = "simple"} { @@ -50,3 +50,22 @@ return } } + +// ----- + +// CHECK-LABEL: Testing : "nested" +module attributes {test.name = "nested"} { + module @nested_module { + // CHECK: Node{{.*}}func_a + func @func_a() { + return + } + } + + // CHECK: Node{{.*}}func_b + // CHECK: Call-Edge{{.*}}func_a + func @func_b() { + "test.conversion_call_op"() { callee = @nested_module::@func_a } : () -> () + return + } +} diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -230,7 +230,7 @@ def ConversionCallOp : TEST_Op<"conversion_call_op", [CallOpInterface]> { - let arguments = (ins Variadic:$inputs, FlatSymbolRefAttr:$callee); + let arguments = (ins Variadic:$inputs, SymbolRefAttr:$callee); let results = (outs Variadic); let extraClassDeclaration = [{ @@ -239,7 +239,7 @@ /// Return the callee of this operation. 
CallInterfaceCallable getCallableForCallee() { - return getAttrOfType("callee"); + return getAttrOfType("callee"); } }]; } @@ -250,11 +250,8 @@ let results = (outs FunctionType); let extraClassDeclaration = [{ - Region *getCallableRegion(CallInterfaceCallable) { return &body(); } - void getCallableRegions(SmallVectorImpl &callables) { - callables.push_back(&body()); - } - ArrayRef getCallableResults(Region *) { + Region *getCallableRegion() { return &body(); } + ArrayRef getCallableResults() { return getType().cast().getResults(); } }];