diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -2005,6 +2005,98 @@
   let hasFolder = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> {
+  let summary = "declare or define a global memref variable";
+  let description = [{
+    The `global_memref` operation declares or defines a named global variable.
+    The backing memory for the variable is allocated statically and is described
+    by the type of the variable (which should be a statically shaped memref
+    type). The operation is a declaration if no `initial_value` is specified,
+    else it is a definition. The `initial_value` can either be a unit attribute
+    to represent a definition of an uninitialized global variable (printed as
+    `= uninitialized`), or an elements attribute to represent the definition of
+    a global variable with an initial value. The global variable can also be
+    marked constant using the `constant` unit attribute. Writing to such
+    constant global variables is undefined.
+
+    The global variable can be accessed by using the `get_global_memref`
+    operation to retrieve the memref for the global variable. Note that the
+    memref for such a global variable is itself immutable (i.e.,
+    `get_global_memref` always returns the same memref descriptor for it).
+
+    Example:
+
+    ```mlir
+    // Private variable with an initial value.
+    global_memref private @x : memref<2xf32> = dense<[0.0, 2.0]>
+
+    // Declaration of an external variable.
+    global_memref @y : memref<4xi32>
+
+    // Uninitialized externally visible variable, and an externally visible
+    // constant variable.
+    global_memref @z : memref<3xf16> = uninitialized
+    global_memref constant @c : memref<2xi32> = dense<[1, 4]>
+    ```
+  }];
+
+  let arguments = (ins
+      SymbolNameAttr:$sym_name,
+      OptionalAttr<StrAttr>:$sym_visibility,
+      TypeAttr:$type,
+      OptionalAttr<AnyAttr>:$initial_value,
+      UnitAttr:$constant
+  );
+
+  // global_memref visibility? constant? @name : type [= initial value]? {attributes}
+  let assemblyFormat = [{
+       custom<GlobalMemrefOpVisibility>($sym_visibility)
+       custom<GlobalMemrefOpConstant>($constant)
+       $sym_name `:`
+       custom<GlobalMemrefOpTypeAndInitialValue>($type, $initial_value)
+       attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+     bool isExternal() { return !initial_value(); }
+     bool isUnitialized() {
+       return !isExternal() && initial_value().getValue().isa<UnitAttr>();
+     }
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+def GetGlobalMemrefOp : Std_Op<"get_global_memref",
+    [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+  let summary = "get the memref pointing to a global variable";
+  let description = [{
+    The `get_global_memref` operation retrieves the memref pointing to a
+    named global variable. If the global variable is marked constant, writing
+    to the result memref (such as through a `std.store` operation) is
+    undefined.
+
+    Example:
+
+    ```mlir
+    %x = get_global_memref @foo : memref<2xf32>
+    ```
+  }];
+
+  let arguments = (ins FlatSymbolRefAttr:$name);
+  let results = (outs AnyStaticShapeMemRef:$result);
+  let assemblyFormat = "$name `:` type($result) attr-dict";
+
+  // `GetGlobalMemrefOp` is fully verified by its traits.
+  let verifier = ?;
+}
+
 //===----------------------------------------------------------------------===//
 // ImOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -395,7 +395,7 @@
 
     // Parse any kind of attribute.
     Attribute attr;
-    if (parseAttribute(attr))
+    if (parseAttribute(attr, type))
       return failure();
 
     // Check for the right kind of attribute.
diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h
--- a/mlir/include/mlir/IR/SymbolTable.h
+++ b/mlir/include/mlir/IR/SymbolTable.h
@@ -72,6 +72,9 @@
     Nested,
   };
 
+  /// Populates the given vector with all valid visibility names.
+  static void getVisibilityNames(SmallVectorImpl<StringRef> &names);
+
   /// Returns the name of the given symbol operation.
   static StringRef getSymbolName(Operation *symbol);
   /// Sets the name of the given symbol operation.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -245,6 +245,18 @@
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static Type getTensorTypeFromMemRefType(Type type) {
+  if (auto memref = type.dyn_cast<MemRefType>())
+    return RankedTensorType::get(memref.getShape(), memref.getElementType());
+  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
+    return UnrankedTensorType::get(memref.getElementType());
+  return NoneType::get(type.getContext());
+}
+
 //===----------------------------------------------------------------------===//
 // AddFOp
 //===----------------------------------------------------------------------===//
@@ -2140,6 +2152,140 @@
   return areVectorCastSimpleCompatible(a, b, areCastCompatible);
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+static void printGlobalMemrefOpVisibility(OpAsmPrinter &p, GlobalMemrefOp op,
+                                          StringAttr sym_visibility) {
+  if (sym_visibility)
+    p << sym_visibility.getValue();
+}
+
+static ParseResult parseGlobalMemrefOpVisibility(OpAsmParser &parser,
+                                                 StringAttr &sym_visibility) {
+  // Parse the next token if it's a visibility token.
+  SmallVector<StringRef, 3> visibilityNames;
+  SymbolTable::getVisibilityNames(visibilityNames);
+
+  for (StringRef visibility : visibilityNames) {
+    if (succeeded(parser.parseOptionalKeyword(visibility))) {
+      if (visibility != "public")
+        sym_visibility =
+            StringAttr::get(visibility, parser.getBuilder().getContext());
+      break;
+    }
+  }
+  return success();
+}
+
+static void printGlobalMemrefOpConstant(OpAsmPrinter &p, GlobalMemrefOp op,
+                                        UnitAttr constant) {
+  if (constant)
+    p << "constant";
+}
+
+static ParseResult parseGlobalMemrefOpConstant(OpAsmParser &parser,
+                                               UnitAttr &constant) {
+  if (succeeded(parser.parseOptionalKeyword("constant")))
+    constant = UnitAttr::get(parser.getBuilder().getContext());
+  return success();
+}
+
+static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p,
+                                                   GlobalMemrefOp op,
+                                                   TypeAttr type,
+                                                   Attribute initial_value) {
+  p << type;
+  if (!op.isExternal()) {
+    p << " = ";
+    if (op.isUnitialized())
+      p << "uninitialized";
+    else
+      p.printAttributeWithoutType(initial_value);
+  }
+}
+
+static ParseResult
+parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+                                       Attribute &initial_value) {
+  Type type;
+  if (parser.parseType(type))
+    return failure();
+
+  auto memrefType = type.dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return parser.emitError(parser.getNameLoc())
+           << "type should be static shaped memref, but got " << type;
+  typeAttr = TypeAttr::get(type);
+
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
+      initial_value = UnitAttr::get(parser.getBuilder().getContext());
+    } else {
+      Type tensorType = getTensorTypeFromMemRefType(memrefType);
+      if (parser.parseAttribute(initial_value, tensorType))
+        return failure();
+      if (!initial_value.isa<ElementsAttr>())
+        return parser.emitError(parser.getNameLoc())
+               << "initial value should be a unit or elements attribute";
+    }
+  }
+  return success();
+}
+
+static LogicalResult verify(GlobalMemrefOp op) {
+  auto memrefType = op.type().dyn_cast<MemRefType>();
+  if (!memrefType || !memrefType.hasStaticShape())
+    return op.emitOpError("type should be static shaped memref, but got ")
+           << op.type();
+
+  // Verify that the initial value, if present, is either a unit attribute or
+  // an elements attribute.
+  if (op.initial_value().hasValue()) {
+    Attribute initValue = op.initial_value().getValue();
+    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
+      return op.emitOpError("initial value should be a unit or elements "
+                            "attribute, but got ")
+             << initValue;
+
+    // Check that the type of the initial value is compatible with the type of
+    // the global variable.
+    if (initValue.isa<ElementsAttr>()) {
+      Type initType = initValue.getType();
+      Type tensorType = getTensorTypeFromMemRefType(memrefType);
+      if (initType != tensorType)
+        return op.emitOpError("initial value expected to be of type ")
+               << tensorType << ", but was of type " << initType;
+    }
+  }
+
+  // TODO: verify visibility for declarations.
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GetGlobalMemrefOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  // Verify that the result type is same as the type of the referenced
+  // global_memref op.
+  auto global =
+      symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr());
+  if (!global)
+    return emitOpError("'")
+           << name() << "' does not reference a valid global memref";
+
+  Type resultType = result().getType();
+  if (global.type() != resultType)
+    return emitOpError("result type ")
+           << resultType << " does not match type " << global.type()
+           << " of the global memref @" << name();
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // IndexCastOp
 //===----------------------------------------------------------------------===//
@@ -3891,18 +4037,6 @@
   results.insert(context);
 }
 
-//===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
-//===----------------------------------------------------------------------===//
-
-static Type getTensorTypeFromMemRefType(Type type) {
-  if (auto memref = type.dyn_cast<MemRefType>())
-    return RankedTensorType::get(memref.getShape(), memref.getElementType());
-  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
-    return UnrankedTensorType::get(memref.getElementType());
-  return NoneType::get(type.getContext());
-}
-
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -186,6 +186,12 @@
   assert(name && "expected valid symbol name");
   return *name;
 }
+
+/// Populates the given vector with all valid visibility names.
+void SymbolTable::getVisibilityNames(SmallVectorImpl<StringRef> &names) {
+  names.assign({"private", "nested", "public"});
+}
+
 /// Sets the name of the given symbol operation.
 void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
   symbol->setAttr(getSymbolAttrName(),
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -807,6 +807,7 @@
 
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
+  auto attribLoc = getToken().getLoc();
   consumeToken(Token::kw_dense);
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
@@ -819,11 +820,14 @@
       return nullptr;
   }
 
-  auto typeLoc = getToken().getLoc();
+  // If the type is specified, `parseElementsLiteralType` will not parse a type.
+  // Use the attribute location as the location for error reporting in that
+  // case.
+  auto loc = attrType ? attribLoc : getToken().getLoc();
   auto type = parseElementsLiteralType(attrType);
   if (!type)
     return nullptr;
-  return literalParser.getAttr(typeLoc, type);
+  return literalParser.getAttr(loc, type);
 }
 
 /// Parse an opaque elements attribute.
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -231,3 +231,78 @@
   memref_reshape %buf(%shape)
     : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]>
 }
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : i32 = 5
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<*xf32>
+
+// -----
+
+// expected-error @+1 {{type should be static shaped memref}}
+global_memref @foo : memref<?xf32>
+
+// -----
+
+// expected-error @+1 {{initial value should be a unit or elements attribute}}
+global_memref @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// expected-error @+1 {{inferred shape of elements literal ([2]) does not match type ([2, 2])}}
+global_memref @foo : memref<2x2xf32> = dense<[0.0, 1.0]>
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref private public @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref private constant external @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+// constant qualifier must be after visibility.
+// expected-error @+1 {{expected valid '@'-identifier for symbol name}}
+global_memref constant private @foo : memref<2x2xf32> = "foo"
+
+// -----
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'gv' does not reference a valid global memref}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+
+// -----
+
+func @foo()
+
+func @nonexistent_global_memref() {
+  // expected-error @+1 {{'foo' does not reference a valid global memref}}
+  %0 = get_global_memref @foo : memref<3xf32>
+  return
+}
+
+// -----
+
+global_memref @gv : memref<3xi32>
+
+func @mismatched_types() {
+  // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}}
+  %0 = get_global_memref @gv : memref<3xf32>
+  return
+}
+
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -77,3 +77,35 @@
                : (memref, memref) -> memref<*xf32>
   return %new_unranked : memref<*xf32>
 }
+
+// CHECK-LABEL: global_memref @memref0 : memref<2xf32>
+global_memref @memref0 : memref<2xf32>
+
+// CHECK-LABEL: global_memref constant @memref1 : memref<2xf32>
+// CHECK-SAME: = dense<[0.000000e+00, 1.000000e+00]>
+global_memref constant @memref1 : memref<2xf32> = dense<[0.0, 1.0]>
+
+// CHECK-LABEL: global_memref @memref2 : memref<2xf32> = uninitialized
+global_memref @memref2 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref private @memref3 : memref<2xf32> = uninitialized
+global_memref private @memref3 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: global_memref private constant @memref4 : memref<2xf32> = uninitialized
+global_memref private constant @memref4 : memref<2xf32> = uninitialized
+
+// CHECK-LABEL: func @write_global_memref
+func @write_global_memref() {
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = constant dense<[1.0, 2.0]> : tensor<2xf32>
+  tensor_store %1, %0 : memref<2xf32>
+  return
+}
+
+// CHECK-LABEL: func @read_global_memref
+func @read_global_memref() {
+  %0 = get_global_memref @memref0 : memref<2xf32>
+  %1 = tensor_load %0 : memref<2xf32>
+  return
+}
+