diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -46,10 +46,31 @@
   /// Returns the associated operation.
   Operation *getOp() const { return symbolTableOp; }
 
+  /// Return the name of the attribute used for symbol visibility.
+  static StringRef getVisibilityAttrName() { return "sym_visibility"; }
+
   //===--------------------------------------------------------------------===//
   // Symbol Utilities
   //===--------------------------------------------------------------------===//
 
+  /// An enumeration detailing the different visibility types that a symbol may
+  /// have.
+  enum class Visibility {
+    /// The symbol is public and may be referenced anywhere internal or external
+    /// to the visible references in the IR. Assumed when no attribute is set.
+    Public,
+
+    /// The symbol is private and may only be referenced by SymbolRefAttrs local
+    /// to the operations within the current symbol table.
+    Private,
+
+    /// The symbol is visible to the current IR, which may include operations in
+    /// symbol tables above the one that owns the current symbol. `Nested`
+    /// visibility allows for referencing a symbol outside of its current symbol
+    /// table, while retaining the ability to observe all uses.
+    Nested,
+  };
+
   /// Returns true if the given operation defines a symbol.
   static bool isSymbol(Operation *op);
 
@@ -58,6 +79,11 @@
   /// Sets the name of the given symbol operation.
   static void setSymbolName(Operation *symbol, StringRef name);
 
+  /// Returns the visibility of the given symbol operation (default: public).
+  static Visibility getSymbolVisibility(Operation *symbol);
+  /// Sets the visibility of the given symbol operation; public drops the attr.
+  static void setSymbolVisibility(Operation *symbol, Visibility vis);
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.
@@ -200,6 +226,8 @@
 template <typename ConcreteType>
 class Symbol : public TraitBase<ConcreteType, Symbol> {
 public:
+  using Visibility = mlir::SymbolTable::Visibility;
+
   static LogicalResult verifyTrait(Operation *op) {
     return impl::verifySymbol(op);
   }
@@ -219,6 +247,16 @@
         StringAttr::get(name, this->getOperation()->getContext()));
   }
 
+  /// Returns the visibility of the current symbol.
+  Visibility getVisibility() {
+    return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
+  }
+
+  /// Sets the visibility of the current symbol.
+  void setVisibility(Visibility vis) {
+    mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
+  }
+
   /// Get all of the uses of the current symbol that are nested within the given
   /// operation 'from'.
   /// Note: See mlir::SymbolTable::getSymbolUses for more details.
diff --git a/mlir/lib/IR/Module.cpp b/mlir/lib/IR/Module.cpp
--- a/mlir/lib/IR/Module.cpp
+++ b/mlir/lib/IR/Module.cpp
@@ -85,7 +85,10 @@
   // the symbol name attribute.
   for (auto attr : getOperation()->getAttrList().getAttrs()) {
     if (!attr.first.strref().contains('.') &&
-        attr.first.strref() != mlir::SymbolTable::getSymbolAttrName())
+        !llvm::is_contained(
+            ArrayRef<StringRef>{mlir::SymbolTable::getSymbolAttrName(),
+                                mlir::SymbolTable::getVisibilityAttrName()},
+            attr.first.strref()))
       return emitOpError(
                  "can only contain dialect-specific attributes, found: '")
              << attr.first << "'";
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/StringSwitch.h"
 
 using namespace mlir;
 
@@ -179,6 +180,38 @@
                   StringAttr::get(name, symbol->getContext()));
 }
 
+/// Returns the visibility of `symbol`, defaulting to public when unset.
+SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
+  // If the attribute doesn't exist, assume public.
+  StringAttr vis = symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
+  if (!vis)
+    return Visibility::Public;
+
+  // Otherwise, decode the string; verifySymbol restricts it to these values.
+  return llvm::StringSwitch<Visibility>(vis.getValue())
+      .Case("private", Visibility::Private)
+      .Case("nested", Visibility::Nested)
+      .Case("public", Visibility::Public);
+}
+/// Sets the visibility of `symbol`; public is encoded by dropping the attr.
+void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
+  MLIRContext *ctx = symbol->getContext();
+
+  // If the visibility is public, just drop the attribute as this is the
+  // default.
+  if (vis == Visibility::Public) {
+    symbol->removeAttr(Identifier::get(getVisibilityAttrName(), ctx));
+    return;
+  }
+
+  // Otherwise, update the attribute.
+  assert((vis == Visibility::Private || vis == Visibility::Nested) &&
+         "unknown symbol visibility kind");
+
+  StringRef visName = vis == Visibility::Private ? "private" : "nested";
+  symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
+}
+
 /// Returns the operation registered with the given symbol name with the
 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
@@ -272,9 +305,26 @@
 }
 
 LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
+  // Verify the name attribute.
   if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
     return op->emitOpError() << "requires string attribute '"
                              << mlir::SymbolTable::getSymbolAttrName() << "'";
+
+  // Verify the visibility attribute; it is optional, so absence is accepted.
+  if (Attribute vis = op->getAttr(mlir::SymbolTable::getVisibilityAttrName())) {
+    StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
+    if (!visStrAttr)
+      return op->emitOpError() << "requires visibility attribute '"
+                               << mlir::SymbolTable::getVisibilityAttrName()
+                               << "' to be a string attribute, but got " << vis;
+
+    if (!llvm::is_contained(ArrayRef<StringRef>{"public", "private", "nested"},
+                            visStrAttr.getValue()))
+      return op->emitOpError()
+             << "visibility expected to be one of [\"public\", \"private\", "
+                "\"nested\"], but got "
+             << visStrAttr;
+  }
   return success();
 }
 
diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir
--- a/mlir/test/IR/traits.mlir
+++ b/mlir/test/IR/traits.mlir
@@ -207,6 +207,30 @@
 
 // -----
 
+// Test the invariants of operations with the Symbol Trait.
+
+// expected-error@+1 {{requires string attribute 'sym_name'}}
+"test.symbol"() {} : () -> ()
+
+// -----
+
+// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}}
+"test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
+
+// -----
+
+// expected-error@+1 {{visibility expected to be one of ["public", "private", "nested"]}}
+"test.symbol"() {sym_name = "foo_2", sym_visibility = "foo"} : () -> ()
+
+// -----
+
+"test.symbol"() {sym_name = "foo_3", sym_visibility = "nested"} : () -> ()
+"test.symbol"() {sym_name = "foo_4", sym_visibility = "private"} : () -> ()
+"test.symbol"() {sym_name = "foo_5", sym_visibility = "public"} : () -> ()
+"test.symbol"() {sym_name = "foo_6"} : () -> ()
+
+// -----
+
 // Test that operation with the SymbolTable Trait define a new symbol scope.
 "test.symbol_scope"() ({
   func @foo() {
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -74,9 +74,15 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Test Operands
+// Test Symbols
 //===----------------------------------------------------------------------===//
 
+def SymbolOp : TEST_Op<"symbol", [Symbol]> {
+  let summary = "operation which defines a new symbol";
+  let arguments = (ins StrAttr:$sym_name,
+                       OptionalAttr<StrAttr>:$sym_visibility);
+}
+
 def SymbolScopeOp : TEST_Op<"symbol_scope",
     [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> {
   let summary = "operation which defines a new symbol table";