diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1119,6 +1119,9 @@ // Declaration of an external variable. memref.global "private" @y : memref<4xi32> + // Declaration of an internal variable. + memref.global "internal" @y : memref<4xi32> + // Uninitialized externally visible variable. memref.global @z : memref<3xf16> = uninitialized diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -58,6 +58,12 @@ return getVisibility() == mlir::SymbolTable::Visibility::Nested; }] >, + InterfaceMethod<"Returns true if this symbol has internal visibility.", + "bool", "isInternal", (ins), [{}], + /*defaultImplementation=*/[{ + return getVisibility() == mlir::SymbolTable::Visibility::Internal; + }] + >, InterfaceMethod<"Returns true if this symbol has private visibility.", "bool", "isPrivate", (ins), [{}], /*defaultImplementation=*/[{ diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -84,6 +84,10 @@ /// visibility allows for referencing a symbol outside of its current symbol /// table, while retaining the ability to observe all uses. Nested, + + /// The symbol is internal (not public): it may only be referenced by uses + /// that are visible in the IR (cf. LLVM's `internal` linkage). + Internal, }; /// Returns the name of the given symbol operation, aborting if no symbol is diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -379,7 +379,7 @@ }; /// This enum describes C++ inheritance visibility.
-enum class Visibility { Public, Protected, Private }; +enum class Visibility { Public, Protected, Private, Internal }; /// Write "public", "protected", or "private". llvm::raw_ostream &operator<<(llvm::raw_ostream &os, diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -664,8 +664,9 @@ Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); - LLVM::Linkage linkage = - global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private; + LLVM::Linkage linkage = global.isPublic() ? LLVM::Linkage::External + : global.isInternal() ? LLVM::Linkage::Internal + : LLVM::Linkage::Private; Attribute initialValue = nullptr; if (!global.isExternal() && !global.isUninitialized()) { diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -241,7 +241,8 @@ return StringSwitch(vis.getValue()) .Case("private", Visibility::Private) .Case("nested", Visibility::Nested) - .Case("public", Visibility::Public); + .Case("public", Visibility::Public) + .Case("internal", Visibility::Internal); } /// Sets the visibility of the given symbol operation. void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { @@ -255,10 +256,13 @@ } // Otherwise, update the attribute. - assert((vis == Visibility::Private || vis == Visibility::Nested) && + assert((vis == Visibility::Private || vis == Visibility::Nested || + vis == Visibility::Internal) && "unknown symbol visibility kind"); - StringRef visName = vis == Visibility::Private ? "private" : "nested"; + StringRef visName = vis == Visibility::Private ? "private" + : vis == Visibility::Nested ? 
"nested" + : "internal"; symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); } @@ -402,6 +406,8 @@ return os << "private"; case SymbolTable::Visibility::Nested: return os << "nested"; + case SymbolTable::Visibility::Internal: + return os << "internal"; } llvm_unreachable("Unexpected visibility"); } @@ -465,11 +471,12 @@ << mlir::SymbolTable::getVisibilityAttrName() << "' to be a string attribute, but got " << vis; - if (!llvm::is_contained(ArrayRef{"public", "private", "nested"}, - visStrAttr.getValue())) + if (!llvm::is_contained( + ArrayRef{"public", "private", "nested", "internal"}, + visStrAttr.getValue())) return op->emitOpError() << "visibility expected to be one of [\"public\", \"private\", " - "\"nested\"], but got " + "\"nested\", \"internal\"], but got " << visStrAttr; } return success(); diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp --- a/mlir/lib/TableGen/Class.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -209,6 +209,8 @@ return os << "protected"; case Visibility::Private: return os << "private"; + case Visibility::Internal: + return os << "internal"; } return os; } diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -598,3 +598,16 @@ memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>> func.return } + +// ----- + +// CHECK: llvm.mlir.global internal @gv1() {addr_space = 3 : i32} : !llvm.array<1 x i64> +memref.global "internal" @gv1 : memref<1xi64,3> + +func.func @get_gv0_memref() { + %0 = memref.get_global @gv1 : memref<1xi64,3> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : i64 + memref.store %c4, %0[%c0] : memref<1xi64,3> + return +} diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir --- a/mlir/test/Dialect/MemRef/invalid.mlir 
+++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -338,7 +338,7 @@ // ----- -// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested"], but got "priate"}} +// expected-error @+1 {{op visibility expected to be one of ["public", "private", "nested", "internal"], but got "priate"}} memref.global "priate" constant @memref5 : memref<2xf32> = uninitialized // ----- diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -350,7 +350,7 @@ // ----- -// expected-error@+1 {{visibility expected to be one of ["public", "private", "nested"]}} +// expected-error@+1 {{visibility expected to be one of ["public", "private", "nested", "internal"]}} "test.symbol"() {sym_name = "foo_2", sym_visibility = "foo"} : () -> () // -----