diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2005,6 +2005,62 @@ let hasFolder = 0; } +//===----------------------------------------------------------------------===// +// GlobalMemrefOp +//===----------------------------------------------------------------------===// + +def GlobalMemrefOp : Std_Op<"global_memref", [NoSideEffect, Symbol]> { + let summary = "standard global variable"; + let description = [{ + The `global_memref` operation declares or defines a named global variable. + The operation is a declaration if no `initial_value` is specified, else it's a + definition. The global variable can be marked constant using the + `constant` unit attribute. + + Also see `get_global_memref`, which can be used to retrieve the memref for + a named global variable. + + Example: + + ```mlir + global_memref @foo : memref<2xf32> { sym_visibility = "private", + initial_value = dense<[0.0, 2.0]> : tensor<2xf32> } + ``` + }]; + + let arguments = (ins + StrAttr: $sym_name, + TypeAttr: $type, + OptionalAttr<AnyAttr>: $initial_value, + UnitAttr: $constant + ); + let extraClassDeclaration = [{ + bool isExternal() { return !initial_value().hasValue(); } + }]; +} + +//===----------------------------------------------------------------------===// +// GetGlobalMemrefOp +//===----------------------------------------------------------------------===// + +def GetGlobalMemrefOp : Std_Op<"get_global_memref", [NoSideEffect]> { + let summary = "get the memref pointing to a global variable"; + let description = [{ + The `get_global_memref` operation retrieves the memref pointing to a + named global variable. If the global variable is marked constant, writing + to the result memref is undefined.
+ + Example: + + ```mlir + %x = get_global_memref @foo : memref<2xf32> + ``` + }]; + let arguments = (ins FlatSymbolRefAttr:$name); + let results = (outs AnyType:$result); + let assemblyFormat = "$name `:` type($result) attr-dict"; +} + //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -239,6 +239,18 @@ return false; } +//===----------------------------------------------------------------------===// +// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp +//===----------------------------------------------------------------------===// + +static Type getTensorTypeFromMemRefType(Type type) { + if (auto memref = type.dyn_cast<MemRefType>()) + return RankedTensorType::get(memref.getShape(), memref.getElementType()); + if (auto memref = type.dyn_cast<UnrankedMemRefType>()) + return UnrankedTensorType::get(memref.getElementType()); + return NoneType::get(type.getContext()); +} + //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// @@ -2134,6 +2146,87 @@ return areVectorCastSimpleCompatible(a, b, areCastCompatible); } +//===----------------------------------------------------------------------===// +// GlobalMemrefOp +//===----------------------------------------------------------------------===// +static void print(OpAsmPrinter &p, GlobalMemrefOp op) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperationName().drop_front(stdDotLen) << ' '; + p.printSymbolName(op.sym_name()); + p << " : " << op.type(); + p.printOptionalAttrDict(op.getAttrs(), + {::mlir::SymbolTable::getSymbolAttrName(), "type"}); +} + +static ParseResult
parseGlobalMemrefOp(OpAsmParser &parser, + OperationState &result) { + StringAttr nameAttr; + Type type; + if (parser.parseSymbolName(nameAttr, ::mlir::SymbolTable::getSymbolAttrName(), + result.attributes) || + parser.parseColonType(type) || + parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (result.attributes.get("type")) { + return parser.emitError(parser.getNameLoc()) + << "type should not be specified in the attributes"; + } + + // Add the type to the list of attributes. + result.addAttribute("type", TypeAttr::get(type)); + return success(); +} + +static LogicalResult verify(GlobalMemrefOp op) { + auto memrefType = op.type().dyn_cast<MemRefType>(); + if (!memrefType || !memrefType.hasStaticShape()) + return op.emitOpError("type should be static shaped memref"); + + // Verify that the initial value, if present, is either a unit attribute or + // an elements attribute. + if (op.initial_value().hasValue()) { + Attribute initValue = op.initial_value().getValue(); + if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) { + return op.emitOpError( + "initial value should be a unit or elements attribute"); + } + + // Check that the type of the initial value is compatible with the type of + // the global variable. + if (initValue.isa<ElementsAttr>()) { + Type initType = initValue.getType(); + Type tensorType = getTensorTypeFromMemRefType(memrefType); + if (initType != tensorType) { + return op.emitOpError("initial value expected to be of type ") + << tensorType << ", but was of type " << initType; + } + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalMemrefOp +//===----------------------------------------------------------------------===// +static LogicalResult verify(GetGlobalMemrefOp op) { + // Verify that the result type is the same as the type of the global_memref op.
+ auto global = + SymbolTable::lookupNearestSymbolFrom<GlobalMemrefOp>(op, op.name()); + if (!global) { + return op.emitOpError("cannot resolve @") + << op.name() << " to a global memref"; + } + + Type resultType = op.result().getType(); + if (global.type() != resultType) { + return op.emitOpError("result type ") + << resultType << " does not match type " << global.type() + << " of the global memref @" << op.name(); + } + return success(); +} + //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// @@ -3853,18 +3946,6 @@ results.insert(context); } -//===----------------------------------------------------------------------===// -// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp -//===----------------------------------------------------------------------===// - -static Type getTensorTypeFromMemRefType(Type type) { - if (auto memref = type.dyn_cast<MemRefType>()) - return RankedTensorType::get(memref.getShape(), memref.getElementType()); - if (auto memref = type.dyn_cast<UnrankedMemRefType>()) - return UnrankedTensorType::get(memref.getElementType()); - return NoneType::get(type.getContext()); -} - //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -231,3 +231,57 @@ memref_reshape %buf(%shape) : (memref<4x4xf32>, memref<1xi32>) -> memref<8xf32, offset: 0, strides: [2]> } + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo : i32 + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo : memref<*xf32> + +// ----- + +// expected-error @+1 {{type should be static shaped memref}} +global_memref @foo : 
memref<?xf32> + +// ----- + +// expected-error @+1 {{initial value should be a unit or elements attribute}} +global_memref @foo : memref<2x2xf32> { initial_value = "foo" } + +// ----- + +// expected-error @+1 {{initial value expected to be of type 'tensor<2x2xf32>', but was of type 'tensor<2xf32>'}} +global_memref @foo : memref<2x2xf32> { initial_value = dense<[0.0, 1.0]> : tensor<2xf32> } + +// ----- + +func @nonexistent_global_memref() { + // expected-error @+1 {{cannot resolve @gv to a global memref}} + %0 = get_global_memref @gv : memref<3xf32> + return +} + +// ----- + +func @foo() + +func @nonexistent_global_memref() { + // expected-error @+1 {{cannot resolve @foo to a global memref}} + %0 = get_global_memref @foo : memref<3xf32> + return +} + +// ----- + +global_memref @gv : memref<3xi32> + +func @mismatched_types() { + // expected-error @+1 {{result type 'memref<3xf32>' does not match type 'memref<3xi32>' of the global memref @gv}} + %0 = get_global_memref @gv : memref<3xf32> + return +} + diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir --- a/mlir/test/Dialect/Standard/ops.mlir +++ b/mlir/test/Dialect/Standard/ops.mlir @@ -77,3 +77,37 @@ : (memref, memref) -> memref<*xf32> return %new_unranked : memref<*xf32> } + +// CHECK-LABEL: global_memref @memref0 : memref<2xf32> +global_memref @memref0 : memref<2xf32> + +// CHECK-LABEL: global_memref @memref1 : memref<2xf32> +// CHECK-SAME: constant +// CHECK-SAME: initial_value = dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> +global_memref @memref1 : memref<2xf32> + { initial_value = dense<[0.0, 1.0]> : tensor<2xf32>, constant } + +// CHECK-LABEL: global_memref @memref2 : memref<2xf32> +// CHECK-SAME: initial_value +global_memref @memref2 : memref<2xf32> { initial_value } + +// CHECK-LABEL: global_memref @memref3 : memref<2xf32> +// CHECK-SAME: initial_value +// CHECK-SAME: sym_visibility = "private" +global_memref @memref3 : memref<2xf32> { initial_value, sym_visibility = 
"private" } + +// CHECK-LABEL: func @write_global_memref +func @write_global_memref() { + %0 = get_global_memref @memref0 : memref<2xf32> + %1 = constant dense<[1.0, 2.0]> : tensor<2xf32> + tensor_store %1, %0 : memref<2xf32> + return +} + +// CHECK-LABEL: func @read_global_memref +func @read_global_memref() { + %0 = get_global_memref @memref0 : memref<2xf32> + %1 = tensor_load %0 : memref<2xf32> + return +} +