diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2005,6 +2005,113 @@
   let hasFolder = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
+  let summary = "declare or define a global memref variable";
+  let description = [{
+    The `global_memref` operation declares or defines a named global variable.
+    The backing memory for the variable is allocated statically and is described
+    by the type of the variable (which should be a statically shaped memref
+    type). The operation is a declaration if no `initial_value` is specified,
+    else it is a definition. The `initial_value` can either be a unit attribute
+    to represent a definition of an uninitialized global variable, or an
+    elements attribute to represent the definition of a global variable with an
+    initial value. The global variable can also be marked constant using the
+    `constant` unit attribute. Writing to such constant global variables is
+    undefined.
+
+    The global variable can be accessed by using the `get_global_memref` to
+    retrieve the memref for the global variable. Note that the memref
+    for such global variable itself is immutable (i.e., get_global_memref for a
+    given global variable will always return the same memref descriptor).
+
+    Syntax:
+
+    ```
+    operation ::= `global_memref` (`private` | `nested` | `public`)? `constant`?
+                  symbol-ref-id `:` memref-type
+                  (`=` (`uninitialized` | elements-literal))? attribute-dict?
+    ```
+
+    The visibility defaults to `public` when omitted. An elements-literal
+    initial value is parsed with the tensor type that has the same shape and
+    element type as the memref type, so no explicit type is written after it.
+    The `sym_name`, `sym_visibility`, `type`, `constant`, and `initial_value`
+    attributes are expressed by this syntax and are not accepted in the
+    trailing attribute dictionary.
+
+    Example:
+
+    ```mlir
+    // Private variable with an initial value.
+    global_memref private @x : memref<2xf32> = dense<[0.0, 2.0]>
+
+    // Private constant variable with an initial value.
+    global_memref private constant @c : memref<2xf32> = dense<[1.0, 2.0]>
+
+    // External variable.
+    global_memref @y : memref<4xi32>
+
+    // Uninitialized externally visible variable.
+    global_memref @z : memref<3xf16> = uninitialized
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttr:$type,
+    OptionalAttr<AnyAttr>:$initial_value,
+    UnitAttr:$constant
+  );
+
+  let extraClassDeclaration = [{
+    /// Returns true if this op is a declaration (has no `initial_value`).
+    bool isExternal() { return !initial_value().hasValue(); }
+    /// Returns true if this op is a definition without an initializer, i.e.
+    /// its `initial_value` is a unit attribute.
+    bool isUninitialized() {
+      return !isExternal() && initial_value().getValue().isa<UnitAttr>();
+    }
+    /// Misspelled alias of `isUninitialized`, kept for existing callers.
+    bool isUnitialized() { return isUninitialized(); }
+    /// Returns the names of the attributes expressed by the custom syntax,
+    /// which are elided from the printed attribute dictionary.
+    static SmallVector<StringRef, 4> getElidedAttributes();
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GetGlobalMemrefOp : Std_Op<"get_global_memref",
+    [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "get the memref pointing to a global variable";
+  let description = [{
+    The `get_global_memref` operation retrieves the memref pointing to a
+    named global variable. If the global variable is marked constant, writing
+    to the result memref (such as through a `std.store` operation) is
+    undefined.
+
+    Example:
+
+    ```mlir
+    %x = get_global_memref @foo : memref<2xf32>
+    ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$name);
+  let results = (outs AnyStaticShapeMemRef:$result);
+  let assemblyFormat = "$name `:` type($result) attr-dict";
+
+  // No custom verifier: the result type is constrained by ODS and the
+  // referenced symbol is checked in `verifySymbolUses`.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -395,7 +395,7 @@
 
     // Parse any kind of attribute.
     Attribute attr;
-    if (parseAttribute(attr))
+    if (parseAttribute(attr, type))
       return failure();
 
     // Check for the right kind of attribute.
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -72,6 +72,14 @@
     Nested,
   };
 
+  /// Returns the name of the given visibility type.
+  static StringRef getVisibilityName(Visibility visibility);
+  /// Returns the visibility given the visibility name. `name` must be a valid
+  /// visibility name (see `isVisibilityName`).
+  static Visibility getVisibilityFromName(StringRef name);
+  /// Returns if the given name is a valid visibility name.
+  static bool isVisibilityName(StringRef name);
+
   /// Returns the name of the given symbol operation.
   static StringRef getSymbolName(Operation *symbol);
   /// Sets the name of the given symbol operation.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -245,6 +245,20 @@
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the tensor type with the same shape and element type as the given
+/// ranked or unranked memref type, or NoneType if `type` is not a memref type.
+static Type getTensorTypeFromMemRefType(Type type) {
+  if (auto memref = type.dyn_cast<MemRefType>())
+    return RankedTensorType::get(memref.getShape(), memref.getElementType());
+  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedTensorType::get(memref.getElementType());
+  return NoneType::get(type.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2140,6 +2154,185 @@
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Returns the attributes that the custom assembly syntax expresses directly.
+SmallVector<StringRef, 4> GlobalMemrefOp::getElidedAttributes() {
+  return {SymbolTable::getSymbolAttrName(),
+          SymbolTable::getVisibilityAttrName(), "type", "constant",
+          "initial_value"};
+}
+
+static void print(OpAsmPrinter &p, GlobalMemrefOp op) {
+  // Syntax for global memref op:
+  //   global_memref visibility? constant? @name : type [= initial_value]?
+  // where: visibility = private | nested (default visibility is public)
+  //        initial_value = elements attribute | uninitialized
+  // The op name is printed without the `std.` dialect prefix.
+  size_t stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op.getOperationName().drop_front(stdDotLen) << ' ';
+
+  // Public is the default visibility and is never printed.
+  SymbolTable::Visibility visibility = op.getVisibility();
+  if (visibility != SymbolTable::Visibility::Public)
+    p << SymbolTable::getVisibilityName(visibility) << ' ';
+
+  if (op.constant())
+    p << "constant ";
+
+  p.printSymbolName(op.sym_name());
+  p << " : " << op.type();
+
+  if (!op.isExternal()) {
+    p << " = ";
+    if (op.isUnitialized()) {
+      p << "uninitialized";
+    } else {
+      // The parser infers the initial value's type from the memref type, so
+      // the value is printed without its type.
+      p.printAttributeWithoutType(op.initial_value().getValue());
+    }
+  }
+
+  p.printOptionalAttrDict(op.getAttrs(), op.getElidedAttributes());
+}
+
+static ParseResult parseGlobalMemrefOp(OpAsmParser &parser,
+                                       OperationState &result) {
+  NamedAttrList attributes;
+
+  // Parse optional visibility. Public is the default visibility and is never
+  // printed, so it is not materialized as an attribute either; this keeps
+  // parse/print round trips stable.
+  llvm::SMLoc keywordLoc = parser.getCurrentLocation();
+  StringRef keyword;
+  bool foundKeyword = succeeded(parser.parseOptionalKeyword(&keyword));
+  if (foundKeyword && SymbolTable::isVisibilityName(keyword)) {
+    if (SymbolTable::getVisibilityFromName(keyword) !=
+        SymbolTable::Visibility::Public)
+      attributes.append(SymbolTable::getVisibilityAttrName(),
+                        StringAttr::get(keyword, result.getContext()));
+    keywordLoc = parser.getCurrentLocation();
+    foundKeyword = succeeded(parser.parseOptionalKeyword(&keyword));
+  }
+
+  // Parse optional constant. Any other keyword at this position is an error,
+  // reported at the offending keyword.
+  if (foundKeyword) {
+    if (keyword != "constant")
+      return parser.emitError(keywordLoc) << "expected `constant` keyword";
+    attributes.append("constant", UnitAttr::get(result.getContext()));
+  }
+
+  // Parse @name : type
+  StringAttr nameAttr;
+  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
+                             attributes) ||
+      parser.parseColon())
+    return failure();
+
+  llvm::SMLoc typeLoc = parser.getCurrentLocation();
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  auto memrefType = type.dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return parser.emitError(typeLoc)
+           << "type should be static shaped memref, but got " << type;
+  attributes.append("type", TypeAttr::get(type));
+
+  // Parse optional initial value.
+  if (succeeded(parser.parseOptionalEqual())) {
+    Attribute initValueAttr;
+    if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
+      // A unit attribute marks a definition without an initial value.
+      initValueAttr = UnitAttr::get(result.getContext());
+    } else {
+      // Parse the initial value with the tensor type corresponding to the
+      // memref type, so that no explicit type needs to be written.
+      llvm::SMLoc initValueLoc = parser.getCurrentLocation();
+      Type tensorType = getTensorTypeFromMemRefType(memrefType);
+      if (parser.parseAttribute(initValueAttr, tensorType))
+        return failure();
+      if (!initValueAttr.isa<ElementsAttr>())
+        return parser.emitError(initValueLoc)
+               << "initial value should be a unit or elements attribute";
+    }
+    attributes.append("initial_value", initValueAttr);
+  }
+
+  // Parse the optional attribute dictionary.
+  llvm::SMLoc attrDictLoc = parser.getCurrentLocation();
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // The elided attributes are expressed by the custom syntax above, so they
+  // cannot also appear in the parsed attribute dictionary.
+  for (StringRef attrName : GlobalMemrefOp::getElidedAttributes()) {
+    if (result.attributes.get(attrName))
+      return parser.emitError(attrDictLoc)
+             << "attribute '" << attrName << "' not allowed in attributes";
+  }
+
+  result.addAttributes(attributes.getAttrs());
+  return success();
+}
+
+static LogicalResult verify(GlobalMemrefOp op) {
+  auto memrefType = op.type().dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return op.emitOpError("type should be static shaped memref, but got ")
+           << op.type();
+
+  // Verify that the initial value, if present, is either a unit attribute or
+  // an elements attribute.
+  if (op.initial_value().hasValue()) {
+    Attribute initValue = op.initial_value().getValue();
+    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
+      return op.emitOpError("initial value should be a unit or elements "
+                            "attribute, but got ")
+             << initValue;
+
+    // Check that the type of the initial value is compatible with the type of
+    // the global variable.
+    if (initValue.isa<ElementsAttr>()) {
+      Type initType = initValue.getType();
+      Type tensorType = getTensorTypeFromMemRefType(memrefType);
+      if (initType != tensorType)
+        return op.emitOpError("initial value expected to be of type ")
+               << tensorType << ", but was of type " << initType;
+    }
+  }
+
+  // TODO: verify visibility for declarations.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Verify that the result type is same as the type of the referenced
+  // global_memref op.
+  auto global =
+      symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
+  if (!global)
+    return emitOpError("'")
+           << name() << "' does not reference a valid global memref";
+
+  Type resultType = result().getType();
+  if (global.type() != resultType)
+    return emitOpError("result type ")
+           << resultType << " does not match type " << global.type()
+           << " of the global memref @" << name();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
@@ -3891,18 +4084,6 @@
   results.insert<ExtractElementFromTensorFromElements>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
-//===----------------------------------------------------------------------===//
-
-static Type getTensorTypeFromMemRefType(Type type) {
-  if (auto memref = type.dyn_cast<MemRefType>())
-    return RankedTensorType::get(memref.getShape(), memref.getElementType());
-  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
-    return UnrankedTensorType::get(memref.getElementType());
-  return NoneType::get(type.getContext());
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -180,12 +180,43 @@
   setSymbolName(symbol, nameBuffer);
 }
 
+/// Returns the name of the given visibility type.
+StringRef SymbolTable::getVisibilityName(Visibility visibility) {
+  switch (visibility) {
+  case Visibility::Private:
+    return "private";
+  case Visibility::Nested:
+    return "nested";
+  case Visibility::Public:
+    return "public";
+  }
+  llvm_unreachable("unknown visibility");
+}
+
+/// Returns the visibility for a valid visibility name (see isVisibilityName).
+SymbolTable::Visibility SymbolTable::getVisibilityFromName(StringRef name) {
+  return StringSwitch<Visibility>(name)
+      .Case("private", Visibility::Private)
+      .Case("nested", Visibility::Nested)
+      .Case("public", Visibility::Public);
+}
+
+/// Returns true if `name` is one of the recognized visibility names.
+bool SymbolTable::isVisibilityName(StringRef name) {
+  return StringSwitch<bool>(name)
+      .Case("private", true)
+      .Case("nested", true)
+      .Case("public", true)
+      .Default(false);
+}
+
 /// Returns the name of the given symbol operation.
 StringRef SymbolTable::getSymbolName(Operation *symbol) {
   Optional<StringRef> name = getNameIfSymbol(symbol);
   assert(name && "expected valid symbol name");
   return *name;
 }
+
 /// Sets the name of the given symbol operation.
 void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
   symbol->setAttr(getSymbolAttrName(),
@@ -199,11 +230,8 @@
   if (!vis)
     return Visibility::Public;
 
-  // Otherwise, switch on the string value.
-  return StringSwitch<Visibility>(vis.getValue())
-      .Case("private", Visibility::Private)
-      .Case("nested", Visibility::Nested)
-      .Case("public", Visibility::Public);
+  // Otherwise, map the attribute's string value to a visibility.
+  return getVisibilityFromName(vis.getValue());
 }
 /// Sets the visibility of the given symbol operation.
 void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
@@ -220,7 +248,7 @@
   assert((vis == Visibility::Private || vis == Visibility::Nested) &&
          "unknown symbol visibility kind");
 
-  StringRef visName = vis == Visibility::Private ? "private" : "nested";
+  StringRef visName = getVisibilityName(vis);
   symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
 }
 
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -807,6 +807,7 @@
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
+  auto attrLoc = getToken().getLoc();
   consumeToken(Token::kw_dense);
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
 
@@ -819,11 +820,14 @@
     return nullptr;
   }
 
-  auto typeLoc = getToken().getLoc();
+  // If the type is specified, `parseElementsLiteralType` will not parse a
+  // type. Use the location of the attribute itself for error reporting in
+  // that case.
+  auto loc = attrType ? attrLoc : getToken().getLoc();
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
-  return literalParser.getAttr(typeLoc, type);
+  return literalParser.getAttr(loc, type);
 }
 
 /// Parse an opaque elements attribute.
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -231,3 +231,89 @@
   memref_reshape %buf(%shape)
     : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
 }
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32 = 5
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<*xf32>
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<?x?xf32>
+
+// -----
+
+// expected-error @+1 {{initial value should be a unit or elements attribute}}
+global_memref @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
+global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
+
+// -----
+
+// expected-error @+1 {{expected `constant` keyword}}
+global_memref private public @foo : memref<2x2xf32> = uninitialized
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref private constant external @foo : memref<2x2xf32> = uninitialized
+
+// -----
+
+// constant qualifier must be after visibility.
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant private @foo : memref<2x2xf32> = uninitialized
+
+// -----
+
+// expected-error @+1 {{attribute 'constant' not allowed in attributes}}
+global_memref @foo : memref<2xf32> = uninitialized {constant}
+
+// -----
+
+// The custom syntax infers the initial value's type, so use the generic form.
+// expected-error @+1 {{initial value expected to be of type 'tensor<2xf32>', but was of type 'tensor<3xf32>'}}
+"std.global_memref"() {sym_name = "foo", type = memref<2xf32>, initial_value = dense<[1.0, 2.0, 3.0]> : tensor<3xf32>} : () -> ()
+
+// -----
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'gv' does not reference a valid global memref}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+
+// -----
+
+func @foo()
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'foo' does not reference a valid global memref}}
+  %0 = get_global_memref @foo : memref<3xf32>
+  return
+}
+
+// -----
+
+global_memref @gv : memref<3xi32>
+
+func @mismatched_types() {
+  // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -77,3 +77,41 @@
                : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
   return %new_unranked : memref<*xf32>
 }
+
+// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
+global_memref @memref0 : memref<2xf32>
+
+// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32>
+// CHECK-SAME: = dense<[0.000000e+00, 1.000000e+00]>
+global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
+
+// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
+global_memref @memref2 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref private @memref3 : memref<2xf32> = uninitialized
+global_memref private @memref3 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref private constant @memref4 : memref<2xf32> = uninitialized
+global_memref private constant @memref4 : memref<2xf32> = uninitialized
+
+// Explicit `public` visibility is the default and is not printed.
+// CHECK-LABEL: global_memref @memref5 : memref<2xf32>
+global_memref public @memref5 : memref<2xf32>
+
+// CHECK-LABEL: func @write_global_memref
+func @write_global_memref() {
+  // CHECK: get_global_memref @memref0 : memref<2xf32>
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+  tensor_store %1, %0 : memref<2xf32>
+  return
+}
+
+// CHECK-LABEL: func @read_global_memref
+func @read_global_memref() {
+  // CHECK: get_global_memref @memref0 : memref<2xf32>
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = tensor_load %0 : memref<2xf32>
+  return
+}
+