diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -236,6 +236,11 @@ /// state. void visitBlockArgument(Block *block, int i); + /// Mark the entry block of the given region as executable. Returns false if + /// the block was already marked executable. If `markArgsOverdefined` is true, + /// the arguments of the entry block are also set to overdefined. + bool markEntryBlockExecutable(Region *region, bool markArgsOverdefined); + /// Mark the given block as executable. Returns false if the block was already /// marked executable. bool markBlockExecutable(Block *block); @@ -313,16 +318,9 @@ SCCPSolver::SCCPSolver(Operation *op) { /// Initialize the solver with the regions within this operation. for (Region &region : op->getRegions()) { - if (region.empty()) - continue; - Block *entryBlock = &region.front(); - - // Mark the entry block as executable. - markBlockExecutable(entryBlock); - - // The values passed to these regions are invisible, so mark any arguments - // as overdefined. - markAllOverdefined(entryBlock->getArguments()); + // Mark the entry block as executable. The values passed to these regions + // are also invisible, so mark any arguments as overdefined. + markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true); } initializeSymbolCallables(op); } @@ -405,8 +403,10 @@ // If not all of the uses of this symbol are visible, we can't track the // state of the arguments. - if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) - markAllOverdefined(callableRegion->getArguments()); + if (symbol.isPublic() || (!allUsesVisible && symbol.isNested())) { + for (Region &region : callable->getRegions()) + markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true); + } } if (callableLatticeState.empty()) return; @@ -443,8 +443,10 @@ // This use isn't a call, so don't we know all of the callers. 
auto *symbol = symbolTable.lookupSymbolIn(op, use.getSymbolRef()); auto it = callableLatticeState.find(symbol); - if (it != callableLatticeState.end()) - markAllOverdefined(it->second.getCallableArguments()); + if (it != callableLatticeState.end()) { + for (Region &region : it->first->getRegions()) + markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true); + } } }; SymbolTable::walkSymbolTables(op, /*allSymUsesVisible=*/!op->getBlock(), @@ -495,8 +497,14 @@ // Process callable operations. These are specially handled region operations // that track dataflow via calls. - if (isa<CallableOpInterface>(op)) + if (isa<CallableOpInterface>(op)) { + // If this callable has a tracked lattice state, it will be visited by calls + // that reference it instead. This way, we don't assume that it is + // executable unless there is a proper reference to it. + if (callableLatticeState.count(op)) + return; return visitCallableOperation(op); + } // Process region holding operations. The region visitor processes result // values, so we can exit afterwards. @@ -551,19 +559,11 @@ } void SCCPSolver::visitCallableOperation(Operation *op) { - // Mark the regions as executable. + // Mark the regions as executable. If we aren't tracking lattice state for + // this callable, mark all of the region arguments as overdefined. bool isTrackingLatticeState = callableLatticeState.count(op); - for (Region &region : op->getRegions()) { - if (region.empty()) - continue; - Block *entryBlock = &region.front(); - markBlockExecutable(entryBlock); - - // If we aren't tracking lattice state for this callable, mark all of the - // region arguments as overdefined. - if (!isTrackingLatticeState) - markAllOverdefined(entryBlock->getArguments()); - } + for (Region &region : op->getRegions()) + markEntryBlockExecutable(&region, !isTrackingLatticeState); // TODO: Add support for non-symbol callables when necessary. 
If the callable // has non-call uses we would mark overdefined, otherwise allow for @@ -599,6 +599,9 @@ visitUsers(callableArg); } + // Visit the callable. + visitCallableOperation(callableOp); + // Merge in the lattice state for the callable results as well. auto callableResults = callableLatticeIt->second.getResultLatticeValues(); for (auto it : llvm::zip(callResults, callableResults)) @@ -613,13 +616,8 @@ auto regionInterface = dyn_cast<RegionBranchOpInterface>(op); if (!regionInterface) { // If we can't, conservatively mark all regions as executable. - for (Region &region : op->getRegions()) { - if (region.empty()) - continue; - Block *entryBlock = &region.front(); - markBlockExecutable(entryBlock); - markAllOverdefined(entryBlock->getArguments()); - } + for (Region &region : op->getRegions()) + markEntryBlockExecutable(&region, /*markArgsOverdefined=*/true); // Don't try to simulate the results of a region operation as we can't // guarantee that folding will be out-of-place. We don't allow in-place @@ -856,6 +854,16 @@ visitUsers(arg); } +bool SCCPSolver::markEntryBlockExecutable(Region *region, + bool markArgsOverdefined) { + if (!region->empty()) { + if (markArgsOverdefined) + markAllOverdefined(region->front().getArguments()); + return markBlockExecutable(&region->front()); + } + return false; +} + bool SCCPSolver::markBlockExecutable(Block *block) { bool marked = executableBlocks.insert(block).second; if (marked) diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir --- a/mlir/test/Transforms/sccp-callgraph.mlir +++ b/mlir/test/Transforms/sccp-callgraph.mlir @@ -140,7 +140,7 @@ /// Check that return values are overdefined when the constant conflicts. 
func private @callable(%arg0 : i32) -> i32 { - "unknown.return"(%arg0) : (i32) -> () + return %arg0 : i32 } // CHECK-LABEL: func @conflicting_constant( @@ -255,3 +255,18 @@ %res = call_indirect %fn() : () -> (i32) return %res : i32 } + +// ----- + +/// Check that private callables don't get processed if they have no uses. + +// CHECK-LABEL: func private @unreferenced_private_function +func private @unreferenced_private_function() -> i32 { + // CHECK: %[[RES:.*]] = select + // CHECK: return %[[RES]] : i32 + %true = constant true + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %result = select %true, %cst0, %cst1 : i32 + return %result : i32 +}