diff --git a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
--- a/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
+++ b/mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h
@@ -234,6 +234,12 @@
   /// any of the operand lattices are uninitialized.
   Optional<SmallVector<Attribute>> getOperandValues(Operation *op);
 
+  /// The top-level operation the analysis is running on. This is used to detect
+  /// if a callable is outside the scope of the analysis and thus must be
+  /// considered an external callable. Set by `initializeSymbolCallables`; null
+  /// until the analysis has been initialized.
+  Operation *analysisScope = nullptr;
+
   /// A symbol table used for O(1) symbol lookups during simplification.
   SymbolTableCollection symbolTable;
 };
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -118,6 +118,7 @@
 }
 
 void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) {
+  analysisScope = top;
   auto walkFn = [&](Operation *symTable, bool allUsesVisible) {
     Region &symbolTableRegion = symTable->getRegion(0);
     Block *symbolTableBlock = &symbolTableRegion.front();
@@ -278,14 +279,14 @@
 }
 
 void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) {
-  Operation *callableOp = nullptr;
-  if (Value callableValue = call.getCallableForCallee().dyn_cast<Value>())
-    callableOp = callableValue.getDefiningOp();
-  else
-    callableOp = call.resolveCallable(&symbolTable);
+  Operation *callableOp = call.resolveCallable(&symbolTable);
 
   // A call to a externally-defined callable has unknown predecessors.
-  const auto isExternalCallable = [](Operation *op) {
+  const auto isExternalCallable = [this](Operation *op) {
+    // A callable outside the analysis scope is an external callable.
+    if (!analysisScope->isAncestor(op))
+      return true;
+    // Otherwise, check if the callable region is defined.
     if (auto callable = dyn_cast<CallableOpInterface>(op))
       return !callable.getCallableRegion();
     return false;
diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir
--- a/mlir/test/Transforms/sccp-callgraph.mlir
+++ b/mlir/test/Transforms/sccp-callgraph.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -sccp -split-input-file | FileCheck %s
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="builtin.module(sccp)" -split-input-file | FileCheck %s --check-prefix=NESTED
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func.func(sccp)" -split-input-file | FileCheck %s --check-prefix=FUNC
 
 /// Check that a constant is properly propagated through the arguments and
 /// results of a private function.
@@ -270,3 +271,38 @@
   %result = arith.select %true, %cst0, %cst1 : i32
   return %result : i32
 }
+
+// -----
+
+/// Check that callables outside the analysis scope are marked as external.
+
+func.func private @foo() -> index {
+  %0 = arith.constant 10 : index
+  return %0 : index
+}
+
+// CHECK-LABEL: func @bar
+// FUNC-LABEL: func @bar
+func.func @bar(%arg0: index) -> index {
+  // CHECK: %[[C10:.*]] = arith.constant 10
+  %c0 = arith.constant 0 : index
+  %1 = arith.constant 420 : index
+  %7 = arith.cmpi eq, %arg0, %c0 : index
+  cf.cond_br %7, ^bb1(%1 : index), ^bb2
+
+// CHECK: ^bb1(%[[ARG:.*]]: index):
+// FUNC: ^bb1(%[[ARG:.*]]: index):
+^bb1(%8: index): // 2 preds: ^bb0, ^bb2
+  // CHECK-NEXT: return %[[ARG]]
+  // FUNC-NEXT: return %[[ARG]]
+  return %8 : index
+
+// CHECK: ^bb2
+// FUNC: ^bb2
+^bb2:
+  // FUNC-NEXT: %[[FOO:.*]] = call @foo
+  %13 = call @foo() : () -> index
+  // CHECK: cf.br ^bb1(%[[C10]]
+  // FUNC: cf.br ^bb1(%[[FOO]]
+  cf.br ^bb1(%13 : index)
+}