diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -87,6 +87,11 @@
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
 
+/// Creates a pass which marks top-level symbol operations as `private` unless
+/// listed in `excludeSymbols`.
+std::unique_ptr<Pass>
+createSymbolPrivatizePass(ArrayRef<std::string> excludeSymbols = {});
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -214,6 +214,20 @@
   let constructor = "mlir::createSymbolDCEPass()";
 }
 
+def SymbolPrivatize : Pass<"symbol-privatize"> {
+  let summary = "Mark symbols private";
+  let description = [{
+    This pass marks all top-level symbols of the operation it runs on as
+    `private`, except for those listed in the `exclude` pass option.
+  }];
+  let options = [
+    ListOption<"exclude", "exclude", "std::string",
+               "Comma separated list of symbols that should not be marked private",
+               "llvm::cl::MiscFlags::CommaSeparated">
+  ];
+  let constructor = "mlir::createSymbolPrivatizePass()";
+}
+
 def ViewOpGraph : Pass<"view-op-graph"> {
   let summary = "Print Graphviz visualization of an operation";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@
   SCCP.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
+  SymbolPrivatize.cpp
   ViewOpGraph.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Transforms/SymbolPrivatize.cpp b/mlir/lib/Transforms/SymbolPrivatize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/SymbolPrivatize.cpp
@@ -0,0 +1,58 @@
+//===- SymbolPrivatize.cpp - Pass to mark symbols private -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass that marks all top-level symbols as private
+// unless they are explicitly excluded.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+struct SymbolPrivatize : public SymbolPrivatizeBase<SymbolPrivatize> {
+  explicit SymbolPrivatize(ArrayRef<std::string> excludeSymbols);
+  LogicalResult initialize(MLIRContext *context) override;
+  void runOnOperation() override;
+
+  /// Symbols whose visibility won't be changed.
+  DenseSet<StringAttr> excludedSymbols;
+};
+} // namespace
+
+SymbolPrivatize::SymbolPrivatize(ArrayRef<std::string> excludeSymbols) {
+  exclude = excludeSymbols;
+}
+
+LogicalResult SymbolPrivatize::initialize(MLIRContext *context) {
+  for (const std::string &symbol : exclude)
+    excludedSymbols.insert(StringAttr::get(context, symbol));
+  return success();
+}
+
+void SymbolPrivatize::runOnOperation() {
+  for (Region &region : getOperation()->getRegions()) {
+    for (Block &block : region) {
+      for (Operation &op : block) {
+        auto symbol = dyn_cast<SymbolOpInterface>(op);
+        if (!symbol)
+          continue;
+        if (!excludedSymbols.contains(symbol.getNameAttr()))
+          symbol.setVisibility(SymbolTable::Visibility::Private);
+      }
+    }
+  }
+}
+
+std::unique_ptr<Pass>
+mlir::createSymbolPrivatizePass(ArrayRef<std::string> exclude) {
+  return std::make_unique<SymbolPrivatize>(exclude);
+}
diff --git a/mlir/test/Transforms/test-symbol-privatize.mlir b/mlir/test/Transforms/test-symbol-privatize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-symbol-privatize.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -symbol-privatize=exclude="aap" | FileCheck %s
+
+// CHECK-LABEL: module attributes {test.simple}
+module attributes {test.simple} {
+  // CHECK: func @aap
+  func @aap() { return }
+
+  // CHECK: func private @kat
+  func @kat() { return }
+}
+