Use a cached SymbolTableCollection when resolving call targets.

CallOpInterface::resolveCallable() moves from an InterfaceMethod to an
extraClassDeclaration that takes an optional SymbolTableCollection*; when one
is provided, symbol references are resolved through its cached tables rather
than SymbolTable::lookupNearestSymbolFrom. CallGraph construction, the Inliner
(CGUseList, collectCallOps, Inliner) and SCCP each create or thread a single
collection through their symbol lookups.

NOTE(review): SymbolTableCollection caches symbol tables; if a pass erases a
symbol operation while a collection is live (e.g. dead callables dropped by
the inliner), confirm the erased symbol is never looked up again through the
stale cache.

diff --git a/mlir/include/mlir/Analysis/CallGraph.h b/mlir/include/mlir/Analysis/CallGraph.h
--- a/mlir/include/mlir/Analysis/CallGraph.h
+++ b/mlir/include/mlir/Analysis/CallGraph.h
@@ -27,6 +27,7 @@
 struct CallInterfaceCallable;
 class Operation;
 class Region;
+class SymbolTableCollection;
 
 //===----------------------------------------------------------------------===//
 // CallGraphNode
@@ -189,8 +190,11 @@
   }
 
   /// Resolve the callable for given callee to a node in the callgraph, or the
-  /// external node if a valid node was not resolved.
-  CallGraphNode *resolveCallable(CallOpInterface call) const;
+  /// external node if a valid node was not resolved. The provided symbol table
+  /// is used when resolving calls that reference callables via a symbol
+  /// reference.
+  CallGraphNode *resolveCallable(CallOpInterface call,
+                                 SymbolTableCollection &symbolTable) const;
 
   /// Erase the given node from the callgraph.
   void eraseNode(CallGraphNode *node);
diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td
--- a/mlir/include/mlir/Interfaces/CallInterfaces.td
+++ b/mlir/include/mlir/Interfaces/CallInterfaces.td
@@ -46,19 +46,16 @@
       }],
       "Operation::operand_range", "getArgOperands"
     >,
-    InterfaceMethod<[{
-        Resolve the callable operation for given callee to a
-        CallableOpInterface, or nullptr if a valid callable was not resolved.
-      }],
-      "Operation *", "resolveCallable", (ins), [{
-        // If the callable isn't a value, lookup the symbol reference.
-        CallInterfaceCallable callable = $_op.getCallableForCallee();
-        if (auto symbolRef = callable.dyn_cast<SymbolRefAttr>())
-          return SymbolTable::lookupNearestSymbolFrom($_op, symbolRef);
-        return callable.get<Value>().getDefiningOp();
-      }]
-    >,
   ];
+
+  let extraClassDeclaration = [{
+    /// Resolve the callable operation for given callee to a
+    /// CallableOpInterface, or nullptr if a valid callable was not resolved.
+    /// `symbolTable` is an optional parameter that will allow for using a
+    /// cached symbol table for symbol lookups instead of performing an O(N)
+    /// scan.
+    Operation *resolveCallable(SymbolTableCollection *symbolTable = nullptr);
+  }];
 }
 
 /// Interface for callable operations.
diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp
--- a/mlir/lib/Analysis/CallGraph.cpp
+++ b/mlir/lib/Analysis/CallGraph.cpp
@@ -68,13 +68,14 @@
 /// Recursively compute the callgraph edges for the given operation. Computed
 /// edges are placed into the given callgraph object.
 static void computeCallGraph(Operation *op, CallGraph &cg,
+                             SymbolTableCollection &symbolTable,
                              CallGraphNode *parentNode, bool resolveCalls) {
   if (CallOpInterface call = dyn_cast<CallOpInterface>(op)) {
     // If there is no parent node, we ignore this operation. Even if this
     // operation was a call, there would be no callgraph node to attribute it
     // to.
     if (resolveCalls && parentNode)
-      parentNode->addCallEdge(cg.resolveCallable(call));
+      parentNode->addCallEdge(cg.resolveCallable(call, symbolTable));
     return;
   }
 
@@ -88,15 +89,18 @@
 
   for (Region &region : op->getRegions())
     for (Operation &nested : region.getOps())
-      computeCallGraph(&nested, cg, parentNode, resolveCalls);
+      computeCallGraph(&nested, cg, symbolTable, parentNode, resolveCalls);
 }
 
 CallGraph::CallGraph(Operation *op) : externalNode(/*callableRegion=*/nullptr) {
   // Make two passes over the graph, one to compute the callables and one to
   // resolve the calls. We split these up as we may have nested callable objects
   // that need to be reserved before the calls.
-  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/false);
-  computeCallGraph(op, *this, /*parentNode=*/nullptr, /*resolveCalls=*/true);
+  SymbolTableCollection symbolTable;
+  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+                   /*resolveCalls=*/false);
+  computeCallGraph(op, *this, symbolTable, /*parentNode=*/nullptr,
+                   /*resolveCalls=*/true);
 }
 
 /// Get or add a call graph node for the given region.
@@ -109,16 +113,17 @@
     node.reset(new CallGraphNode(region));
 
     // Add this node to the given parent node if necessary.
-    if (parentNode)
+    if (parentNode) {
       parentNode->addChildEdge(node.get());
-    else
+    } else {
       // Otherwise, connect all callable nodes to the external node, this allows
       // for conservatively including all callable nodes within the graph.
-      // FIXME(riverriddle) This isn't correct, this is only necessary for
-      // callable nodes that *could* be called from external sources. This
-      // requires extending the interface for callables to check if they may be
-      // referenced externally.
+      // FIXME This isn't correct, this is only necessary for callable nodes
+      // that *could* be called from external sources. This requires extending
+      // the interface for callables to check if they may be referenced
+      // externally.
       externalNode.addAbstractEdge(node.get());
+    }
   }
   return node.get();
 }
@@ -132,8 +137,10 @@
 
 /// Resolve the callable for given callee to a node in the callgraph, or the
 /// external node if a valid node was not resolved.
-CallGraphNode *CallGraph::resolveCallable(CallOpInterface call) const {
-  Operation *callable = call.resolveCallable();
+CallGraphNode *
+CallGraph::resolveCallable(CallOpInterface call,
+                           SymbolTableCollection &symbolTable) const {
+  Operation *callable = call.resolveCallable(&symbolTable);
   if (auto callableOp = dyn_cast_or_null<CallableOpInterface>(callable))
     if (auto *node = lookupNode(callableOp.getCallableRegion()))
       return node;
diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp
--- a/mlir/lib/Interfaces/CallInterfaces.cpp
+++ b/mlir/lib/Interfaces/CallInterfaces.cpp
@@ -10,6 +10,27 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// CallOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Resolve the callable operation for given callee to a CallableOpInterface, or
+/// nullptr if a valid callable was not resolved. `symbolTable` is an optional
+/// parameter that will allow for using a cached symbol table for symbol lookups
+/// instead of performing an O(N) scan.
+Operation *
+CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) {
+  CallInterfaceCallable callable = getCallableForCallee();
+  if (auto symbolVal = callable.dyn_cast<Value>())
+    return symbolVal.getDefiningOp();
+
+  // If the callable isn't a value, lookup the symbol reference.
+  auto symbolRef = callable.get<SymbolRefAttr>();
+  if (symbolTable)
+    return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef);
+  return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef);
+}
+
 //===----------------------------------------------------------------------===//
 // CallInterfaces
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -33,7 +33,7 @@
 
 /// Walk all of the used symbol callgraph nodes referenced with the given op.
 static void walkReferencedSymbolNodes(
-    Operation *op, CallGraph &cg,
+    Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable,
     DenseMap<Attribute, CallGraphNode *> &resolvedRefs,
     function_ref<void(CallGraphNode *, Operation *)> callback) {
   auto symbolUses = SymbolTable::getSymbolUses(op);
@@ -47,8 +47,8 @@
     // If this is the first instance of this reference, try to resolve a
     // callgraph node for it.
     if (refIt.second) {
-      auto *symbolOp = SymbolTable::lookupNearestSymbolFrom(symbolTableOp,
-                                                            use.getSymbolRef());
+      auto *symbolOp = symbolTable.lookupNearestSymbolFrom(symbolTableOp,
+                                                           use.getSymbolRef());
       auto callableOp = dyn_cast_or_null<CallableOpInterface>(symbolOp);
       if (!callableOp)
         continue;
@@ -80,7 +80,7 @@
     DenseMap<CallGraphNode *, int> innerUses;
   };
 
-  CGUseList(Operation *op, CallGraph &cg);
+  CGUseList(Operation *op, CallGraph &cg, SymbolTableCollection &symbolTable);
 
   /// Drop uses of nodes referred to by the given call operation that resides
   /// within 'userNode'.
@@ -110,13 +110,19 @@
   /// A mapping between a discardable callgraph node (that is a symbol) and the
   /// number of uses for this node.
   DenseMap<CallGraphNode *, int> discardableSymNodeUses;
+
   /// A mapping between a callgraph node and the symbol callgraph nodes that it
   /// uses.
   DenseMap<CallGraphNode *, CGUser> nodeUses;
+
+  /// A symbol table to use when resolving call lookups.
+  SymbolTableCollection &symbolTable;
 };
 } // end anonymous namespace
 
-CGUseList::CGUseList(Operation *op, CallGraph &cg) {
+CGUseList::CGUseList(Operation *op, CallGraph &cg,
+                     SymbolTableCollection &symbolTable)
+    : symbolTable(symbolTable) {
   /// A set of callgraph nodes that are always known to be live during inlining.
   DenseMap<Attribute, CallGraphNode *> alwaysLiveNodes;
 
@@ -135,7 +141,7 @@
         }
       }
       // Otherwise, check for any referenced nodes. These will be always-live.
-      walkReferencedSymbolNodes(&op, cg, alwaysLiveNodes,
+      walkReferencedSymbolNodes(&op, cg, symbolTable, alwaysLiveNodes,
                                 [](CallGraphNode *, Operation *) {});
     }
   };
@@ -162,7 +168,7 @@
     --discardableSymNodeUses[node];
   };
   DenseMap<Attribute, CallGraphNode *> resolvedRefs;
-  walkReferencedSymbolNodes(callOp, cg, resolvedRefs, walkFn);
+  walkReferencedSymbolNodes(callOp, cg, symbolTable, resolvedRefs, walkFn);
 }
 
 void CGUseList::eraseNode(CallGraphNode *node) {
@@ -220,7 +226,7 @@
       return;
     ++discardSymIt->second;
   };
-  walkReferencedSymbolNodes(parentOp, cg, resolvedRefs, walkFn);
+  walkReferencedSymbolNodes(parentOp, cg, symbolTable, resolvedRefs, walkFn);
 }
 
 void CGUseList::mergeUsesAfterInlining(CallGraphNode *lhs, CallGraphNode *rhs) {
@@ -305,6 +311,7 @@
 /// inside of nested callgraph nodes.
 static void collectCallOps(iterator_range<Region::iterator> blocks,
                            CallGraphNode *sourceNode, CallGraph &cg,
+                           SymbolTableCollection &symbolTable,
                            SmallVectorImpl<ResolvedCall> &calls,
                            bool traverseNestedCGNodes) {
   SmallVector<std::pair<Block *, CallGraphNode *>, 8> worklist;
@@ -328,7 +335,7 @@
             continue;
         }
 
-        CallGraphNode *targetNode = cg.resolveCallable(call);
+        CallGraphNode *targetNode = cg.resolveCallable(call, symbolTable);
         if (!targetNode->isExternal())
           calls.emplace_back(call, sourceNode, targetNode);
         continue;
@@ -352,8 +359,9 @@
 namespace {
 /// This class provides a specialization of the main inlining interface.
 struct Inliner : public InlinerInterface {
-  Inliner(MLIRContext *context, CallGraph &cg)
-      : InlinerInterface(context), cg(cg) {}
+  Inliner(MLIRContext *context, CallGraph &cg,
+          SymbolTableCollection &symbolTable)
+      : InlinerInterface(context), cg(cg), symbolTable(symbolTable) {}
 
   /// Process a set of blocks that have been inlined. This callback is invoked
   /// *before* inlined terminator operations have been processed.
@@ -367,7 +375,7 @@
       assert(region && "expected valid parent node");
     }
 
-    collectCallOps(inlinedBlocks, node, cg, calls,
+    collectCallOps(inlinedBlocks, node, cg, symbolTable, calls,
                    /*traverseNestedCGNodes=*/true);
   }
 
@@ -389,6 +397,9 @@
 
   /// The callgraph being operated on.
   CallGraph &cg;
+
+  /// A symbol table to use when resolving call lookups.
+  SymbolTableCollection &symbolTable;
 };
 } // namespace
 
@@ -427,11 +438,12 @@
       continue;
 
     // Don't collect calls if the node is already dead.
-    if (useList.isDead(node))
+    if (useList.isDead(node)) {
       deadNodes.push_back(node);
-    else
-      collectCallOps(*node->getCallableRegion(), node, cg, calls,
-                     /*traverseNestedCGNodes=*/false);
+    } else {
+      collectCallOps(*node->getCallableRegion(), node, cg, inliner.symbolTable,
+                     calls, /*traverseNestedCGNodes=*/false);
+    }
   }
 
   // Try to inline each of the call operations. Don't cache the end iterator
@@ -585,8 +597,9 @@
     op->getCanonicalizationPatterns(canonPatterns, context);
 
   // Run the inline transform in post-order over the SCCs in the callgraph.
-  Inliner inliner(context, cg);
-  CGUseList useList(getOperation(), cg);
+  SymbolTableCollection symbolTable;
+  Inliner inliner(context, cg, symbolTable);
+  CGUseList useList(getOperation(), cg, symbolTable);
   runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
     inlineSCC(inliner, useList, scc, context, canonPatterns);
   });
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -304,6 +304,9 @@
   /// avoids re-resolving symbol references during propagation. Value based
   /// callables are trivial to resolve, so they can be done in-place.
   DenseMap<Operation *, Operation *> callToSymbolCallable;
+
+  /// A symbol table used for O(1) symbol lookups during simplification.
+  SymbolTableCollection symbolTable;
 };
 } // end anonymous namespace
 
@@ -425,7 +428,7 @@
       // If the use is a call, track it to avoid the need to recompute the
       // reference later.
       if (auto callOp = dyn_cast<CallOpInterface>(use.getUser())) {
-        Operation *symCallable = callOp.resolveCallable();
+        Operation *symCallable = callOp.resolveCallable(&symbolTable);
         auto callableLatticeIt = callableLatticeState.find(symCallable);
         if (callableLatticeIt != callableLatticeState.end()) {
           callToSymbolCallable.try_emplace(callOp, symCallable);
@@ -438,7 +441,7 @@
         continue;
       }
       // This use isn't a call, so don't we know all of the callers.
-      auto *symbol = SymbolTable::lookupSymbolIn(op, use.getSymbolRef());
+      auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef());
       auto it = callableLatticeState.find(symbol);
       if (it != callableLatticeState.end())
         markAllOverdefined(it->second.getCallableArguments());