diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp --- a/mlir/lib/Transforms/Inliner.cpp +++ b/mlir/lib/Transforms/Inliner.cpp @@ -242,19 +242,47 @@ // CallGraph traversal //===----------------------------------------------------------------------===// +namespace { +/// This class represents a mutable copy of a specific callgraph SCC. +class CallGraphSCC { +public: + CallGraphSCC(llvm::scc_iterator &parentIterator) + : parentIterator(parentIterator) {} + /// Return a range over the nodes within this SCC. + std::vector::iterator begin() { return nodes.begin(); } + std::vector::iterator end() { return nodes.end(); } + + /// Reset the nodes of this SCC with a copy of those provided. + void reset(const std::vector &newNodes) { nodes = newNodes; } + + /// Remove the given node, if present, from this SCC and the parent iterator. + void remove(CallGraphNode *node) { + auto it = llvm::find(nodes, node); + if (it != nodes.end()) { + nodes.erase(it); + parentIterator.ReplaceNode(node, nullptr); + } + } + +private: + std::vector nodes; + llvm::scc_iterator &parentIterator; +}; +} // end anonymous namespace + /// Run a given transformation over the SCCs of the callgraph in a bottom up /// traversal. -static void runTransformOnCGSCCs( - const CallGraph &cg, - function_ref)> sccTransformer) { - std::vector currentSCCVec; - auto cgi = llvm::scc_begin(&cg); +static void +runTransformOnCGSCCs(const CallGraph &cg, + function_ref sccTransformer) { + llvm::scc_iterator cgi = llvm::scc_begin(&cg); + CallGraphSCC currentSCC(cgi); while (!cgi.isAtEnd()) { // Copy the current SCC and increment so that the transformer can modify the // SCC without invalidating our iterator. - currentSCCVec = *cgi; + currentSCC.reset(*cgi); ++cgi; - sccTransformer(currentSCCVec); + sccTransformer(currentSCC); } } @@ -343,6 +371,19 @@ /*traverseNestedCGNodes=*/true); } + /// Mark the given callgraph node for deletion at the end of inlining.
+ void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); } + + /// This method properly disposes of callables that became dead during + /// inlining. Call this once, after the SCC traversal has finished. + void eraseDeadCallables() { + for (CallGraphNode *node : deadNodes) + node->getCallableRegion()->getParentOp()->erase(); + } + + /// The set of callgraph nodes whose callables are known to be dead. + SmallPtrSet deadNodes; + /// The current set of call instructions to consider for inlining. SmallVector calls; @@ -368,27 +409,16 @@ return true; } -/// Delete the given node and remove it from the current scc and the callgraph. -static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg, - MutableArrayRef currentSCC) { - // Erase the parent operation and remove it from the various lists. - node->getCallableRegion()->getParentOp()->erase(); - cg.eraseNode(node); - - // Replace this node in the currentSCC with the external node. - auto it = llvm::find(currentSCC, node); - if (it != currentSCC.end()) - *it = cg.getExternalNode(); -} - /// Attempt to inline calls within the given scc. This function returns /// success if any calls were inlined, failure otherwise. -static LogicalResult -inlineCallsInSCC(Inliner &inliner, CGUseList &useList, - MutableArrayRef currentSCC) { +static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList, + CallGraphSCC ¤tSCC) { CallGraph &cg = inliner.cg; auto &calls = inliner.calls; + // A list of dead nodes to remove from this SCC after inlining. + SmallVector deadNodes; + // Collect all of the direct calls within the nodes of the current SCC. We // don't traverse nested callgraph nodes, because they are handled separately // likely within a different SCC. @@ -396,18 +426,13 @@ if (node->isExternal()) continue; - // If this node is dead, just delete it now. + // Don't collect calls if the node is already dead; defer its removal.
+ if (useList.isDead(node)) - deleteNode(node, useList, cg, currentSCC); + deadNodes.push_back(node); else collectCallOps(*node->getCallableRegion(), node, cg, calls, /*traverseNestedCGNodes=*/false); } - if (calls.empty()) - return failure(); - - // A set of dead nodes to remove after inlining. - SmallVector deadNodes; // Try to inline each of the call operations. Don't cache the end iterator // here as more calls may be added during inlining. @@ -453,8 +478,10 @@ } } - for (CallGraphNode *node : deadNodes) - deleteNode(node, useList, cg, currentSCC); + for (CallGraphNode *node : deadNodes) { + currentSCC.remove(node); + inliner.markForDeletion(node); + } calls.clear(); return success(inlinedAnyCalls); } @@ -462,8 +489,7 @@ /// Canonicalize the nodes within the given SCC with the given set of /// canonicalization patterns. static void canonicalizeSCC(CallGraph &cg, CGUseList &useList, - MutableArrayRef currentSCC, - MLIRContext *context, + CallGraphSCC ¤tSCC, MLIRContext *context, const OwningRewritePatternList &canonPatterns) { // Collect the sets of nodes to canonicalize. SmallVector nodesToCanonicalize; @@ -533,8 +559,7 @@ /// Attempt to inline calls within the given scc, and run canonicalizations /// with the given patterns, until a fixed point is reached. This allows for /// the inlining of newly devirtualized calls. - void inlineSCC(Inliner &inliner, CGUseList &useList, - MutableArrayRef currentSCC, + void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC ¤tSCC, MLIRContext *context, const OwningRewritePatternList &canonPatterns); }; @@ -562,14 +587,16 @@ // Run the inline transform in post-order over the SCCs in the callgraph. Inliner inliner(context, cg); CGUseList useList(getOperation(), cg); - runTransformOnCGSCCs(cg, [&](MutableArrayRef scc) { + runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) { inlineSCC(inliner, useList, scc, context, canonPatterns); }); + + // After inlining, make sure to erase any callables proven to be dead.
+ inliner.eraseDeadCallables(); } void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList, - MutableArrayRef currentSCC, - MLIRContext *context, + CallGraphSCC ¤tSCC, MLIRContext *context, const OwningRewritePatternList &canonPatterns) { // If we successfully inlined any calls, run some simplifications on the // nodes of the scc. Continue attempting to inline until we reach a fixed diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir --- a/mlir/test/Transforms/inlining-dce.mlir +++ b/mlir/test/Transforms/inlining-dce.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s // This file tests the callgraph dead code elimination performed by the inliner. @@ -51,3 +51,23 @@ } "live.user"() {use = @live_function_d} : () -> () + +// ----- + +// This test checks that the inliner can properly handle the deletion of +// functions in different SCCs that are referenced by calls materialized during +// canonicalization.
+// CHECK: func @live_function_e +func @live_function_e() { + call @dead_function_e() : () -> () + return +} +// CHECK-NOT: func @dead_function_e +func @dead_function_e() -> () attributes {sym_visibility = "private"} { + "test.fold_to_call_op"() {callee=@dead_function_f} : () -> () + return +} +// CHECK-NOT: func @dead_function_f +func @dead_function_f() attributes {sym_visibility = "private"} { + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -173,6 +173,28 @@ return targetOperandsMutable(); } +//===----------------------------------------------------------------------===// +// TestFoldToCallOp - canonicalize into a call to the `callee` symbol. +//===----------------------------------------------------------------------===// + +namespace { +struct FoldToCallOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FoldToCallOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, ArrayRef(), op.calleeAttr(), + ValueRange()); + return success(); + } +}; +} // end anonymous namespace + +void FoldToCallOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -321,6 +321,12 @@ }]; } + +def FoldToCallOp : TEST_Op<"fold_to_call_op"> { + let arguments = (ins FlatSymbolRefAttr:$callee); + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // Test Traits //===----------------------------------------------------------------------===//