diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2005,6 +2005,88 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// GlobalMemrefOp +//===----------------------------------------------------------------------===// + +def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> { + let summary = "declare or define a global memref variable"; + let description = [{ + The `global_memref` operation declares or defines a named global variable. + The backing memory for the variable is allocated statically and is described + by the type of the variable (which should be a statically shaped memref + type). The operation is a declaration if no `initial_value` is specified, + else it is a definition. The `initial_value` can either be a unit attribute + to represent a definition of an uninitialized global variable, or an + elements attribute to represent the definition of a global variable with an + initial value. The global variable can also be marked constant using the + `constant` unit attribute. + + The global variable can be accessed using the `get_global_memref` operation + to retrieve the memref pointing to it. Note that the memref for such a global + variable is itself immutable (i.e., `get_global_memref` for a given global + variable will always return the same memref descriptor). + + Example: + + ```mlir + // Private variable with an initial value. + global_memref @x : memref<2xf32> { sym_visibility = "private", + initial_value = dense<[0.0, 2.0]> : tensor<2xf32> } + + // External variable. + global_memref @y : memref<4xi32> { sym_visibility = "public" } + + // Uninitialized externally visible variable.
+ global_memref @z : memref<3xf16> { sym_visibility = "public", + initial_value } + ``` + }]; + + let arguments = (ins + StrAttr:$sym_name, + TypeAttr:$type, + OptionalAttr<AnyAttr>:$initial_value, + UnitAttr:$constant + ); + + let extraClassDeclaration = [{ + bool isExternal() { return !initial_value().hasValue(); } + bool isUnitialized() { + return initial_value().hasValue() && + initial_value().getValue().isa<UnitAttr>(); + } + }]; +} + +//===----------------------------------------------------------------------===// +// GetGlobalMemrefOp +//===----------------------------------------------------------------------===// + +def GetGlobalMemrefOp : Std_Op<"get_global_memref", + [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> { + let summary = "get the memref pointing to a global variable"; + let description = [{ + The `get_global_memref` operation retrieves the memref pointing to a + named global variable. If the global variable is marked constant, writing + to the result memref (such as through a `std.store` operation) is + undefined. + + Example: + + ```mlir + %x = get_global_memref @foo : memref<2xf32> + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$name); + let results = (outs AnyType:$result); + let assemblyFormat = "$name `:` type($result) attr-dict"; + + // The result type is checked against the global in `verifySymbolUses`.
+ let verifier = ?; +} + //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -239,6 +239,18 @@ return false; } +//===----------------------------------------------------------------------===// +// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp +//===----------------------------------------------------------------------===// + +static Type getTensorTypeFromMemRefType(Type type) { + if (auto memref = type.dyn_cast<MemRefType>()) + return RankedTensorType::get(memref.getShape(), memref.getElementType()); + if (auto memref = type.dyn_cast<UnrankedMemRefType>()) + return UnrankedTensorType::get(memref.getElementType()); + return NoneType::get(type.getContext()); +} + //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -2134,6 +2146,97 @@ return areVectorCastSimpleCompatible(a, b, areCastCompatible); } +//===----------------------------------------------------------------------===// +// GlobalMemrefOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, GlobalMemrefOp op) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperationName().drop_front(stdDotLen) << ' '; + p.printSymbolName(op.sym_name()); + p << " : " << op.type(); + p.printOptionalAttrDict(op.getAttrs(), + {SymbolTable::getSymbolAttrName(), "type"}); +} + +static ParseResult parseGlobalMemrefOp(OpAsmParser &parser, + OperationState &result) { + StringAttr nameAttr; + Type type; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes) ||
parser.parseColonType(type) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (result.attributes.get("type")) { + return parser.emitError(parser.getNameLoc()) + << "type should not be specified in the attributes"; + } + + // Add the type to the list of attributes. + result.addAttribute("type", TypeAttr::get(type)); + return success(); +} + +static LogicalResult verify(GlobalMemrefOp op) { + auto memrefType = op.type().dyn_cast<MemRefType>(); + if (!memrefType || !memrefType.hasStaticShape()) + return op.emitOpError("type should be static shaped memref"); + + // Verify that the initial value, if present, is either a unit attribute or + // an elements attribute. + if (op.initial_value().hasValue()) { + Attribute initValue = op.initial_value().getValue(); + if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) { + return op.emitOpError( + "initial value should be a unit or elements attribute"); + } + + // Check that the type of the initial value is compatible with the type of + // the global variable. + if (initValue.isa<ElementsAttr>()) { + Type initType = initValue.getType(); + Type tensorType = getTensorTypeFromMemRefType(memrefType); + if (initType != tensorType) { + return op.emitOpError("initial value expected to be of type ") + << tensorType << ", but was of type " << initType; + } + } + } + + // External variables should have public visibility. + if (op.isExternal() && !op.isPublic()) { + return op.emitOpError( + "external global variables should have public visibility"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalMemrefOp +//===----------------------------------------------------------------------===// + +LogicalResult +GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Verify that the result type is the same as the type of the referenced + // global_memref op.
+ auto global = + symbolTable.lookupNearestSymbolFrom<GlobalMemrefOp>(*this, nameAttr()); + if (!global) { + return emitOpError("'") + << name() << "' does not reference a valid global memref"; + } + + Type resultType = result().getType(); + if (global.type() != resultType) { + return emitOpError("result type ") + << resultType << " does not match type " << global.type() + << " of the global memref @" << name(); + } + return success(); +} + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -3853,18 +3956,6 @@ results.insert(context); } -//===----------------------------------------------------------------------===// -// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp -//===----------------------------------------------------------------------===// - -static Type getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast<MemRefType>()) - return RankedTensorType::get(memref.getShape(), memref.getElementType()); - if (auto memref = type.dyn_cast<UnrankedMemRefType>()) - return UnrankedTensorType::get(memref.getElementType()); - return NoneType::get(type.getContext()); -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -231,3 +231,62 @@ memref_reshape %buf(%shape) : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]> } + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo : i32 + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo : memref<*xf32> + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo :
memref<?x?xf32> + + // ----- + +// expected-error @+1 {{initial value should be a unit or elements attribute}} +global_memref @foo : memref<2x2xf32> { initial_value = "foo" } + +// ----- + +// expected-error @+1 {{initial value expected to be of type 'tensor<2x2xf32>', but was of type 'tensor<2xf32>'}} +global_memref @foo : memref<2x2xf32> { initial_value = dense<[0.0, 1.0]> : tensor<2xf32> } + +// ----- + +// expected-error @+1 {{external global variables should have public visibility}} +global_memref @foo : memref<2x2xf32> { sym_visibility = "private" } + +// ----- + +func @nonexistent_global_memref() { + // expected-error @+1 {{'gv' does not reference a valid global memref}} + %0 = get_global_memref @gv : memref<3xf32> + return +} + +// ----- + +func @foo() + +func @nonexistent_global_memref() { + // expected-error @+1 {{'foo' does not reference a valid global memref}} + %0 = get_global_memref @foo : memref<3xf32> + return +} + +// ----- + +global_memref @gv : memref<3xi32> + +func @mismatched_types() { + // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}} + %0 = get_global_memref @gv : memref<3xf32> + return +} + diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -77,3 +77,37 @@ : (memref, memref) -> memref<*xf32> return %new_unranked : memref<*xf32> } + +// CHECK-LABEL: global_memref @memref0 : memref<2xf32> +global_memref @memref0 : memref<2xf32> + +// CHECK-LABEL: global_memref @memref1 : memref<2xf32> +// CHECK-SAME: constant +// CHECK-SAME: initial_value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> +global_memref @memref1 : memref<2xf32> + { initial_value = dense<[0.0, 1.0]> : tensor<2xf32>, constant } + +// CHECK-LABEL: global_memref @memref2 : memref<2xf32> +// CHECK-SAME: initial_value +global_memref @memref2 : memref<2xf32> { initial_value } + +// CHECK-LABEL:
global_memref @memref3 : memref<2xf32> +// CHECK-SAME: initial_value +// CHECK-SAME: sym_visibility = "private" +global_memref @memref3 : memref<2xf32> { initial_value, sym_visibility = "private" } + +// CHECK-LABEL: func @write_global_memref +func @write_global_memref() { + %0 = get_global_memref @memref0 : memref<2xf32> + %1 = constant dense<[1.0, 2.0]> : tensor<2xf32> + tensor_store %1, %0 : memref<2xf32> + return +} + +// CHECK-LABEL: func @read_global_memref +func @read_global_memref() { + %0 = get_global_memref @memref0 : memref<2xf32> + %1 = tensor_load %0 : memref<2xf32> + return +} +