diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt @@ -1,3 +1,10 @@ set(LLVM_TARGET_DEFINITIONS MLProgramOps.td) add_mlir_dialect(MLProgramOps ml_program) add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS MLProgramAttributes.td) +mlir_tablegen(MLProgramAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRMLProgramAttributesIncGen) +add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen) +add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc) diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h @@ -8,6 +8,7 @@ #ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h @@ -0,0 +1,21 @@ +//===- MLProgramAttributes.h - Attribute Classes ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ +#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" + +//===----------------------------------------------------------------------===// +// Tablegen Attribute Declarations +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc" + +#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td @@ -0,0 +1,44 @@ +//===- MLProgramAttributes.td - Attr definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLPROGRAM_ATTRIBUTES +#define MLPROGRAM_ATTRIBUTES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/MLProgram/IR/MLProgramBase.td" + +// Base class for MLProgram dialect attributes. 
+class MLProgram_Attr traits = []> + : AttrDef { + let mnemonic = ?; +} + +//===----------------------------------------------------------------------===// +// ExternAttr +//===----------------------------------------------------------------------===// + +def MLProgram_ExternAttr : MLProgram_Attr<"Extern"> { + let summary = "Value used for a global signalling external resolution"; + let description = [{ + When used as the value for a GlobalOp, this indicates that the actual + value should be resolved externally in an implementation defined manner. + The `sym_name` of the global is the key for locating the value. + + Examples: + + ```mlir + extern : tensor<4xi32> + ``` + }]; + + let parameters = (ins AttributeSelfTypeParameter<"">:$type); + let mnemonic = "extern"; + let assemblyFormat = "`<` $type `>`"; +} + +#endif // MLPROGRAM_ATTRIBUTES diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td @@ -27,6 +27,15 @@ it is recommended to inquire further prior to using this dialect. 
}]; + let extraClassDeclaration = [{ + private: + Attribute parseAttribute(DialectAsmParser& parser, Type type) const override; + void printAttribute(Attribute attr, DialectAsmPrinter& p) const override; + + public: + }]; + + let useDefaultAttributePrinterParser = 1; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; } diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -96,6 +96,101 @@ let hasCustomAssemblyFormat = 1; } +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalOp : MLProgram_Op<"global", [ + Symbol + ]> { + let summary = "Module level declaration of a global variable"; + let description = [{ + Declares a named global variable (or constant). + + A global contains a value of a specified type which can be accessed at + runtime via appropriate load/store operations. It can be mutable or + constant, optionally taking an initial value or declared as + extern (in which case, the initial value is found in external storage + by symbol name). + + Generally, the type of the global and the type of the initial value + will be the same. However, for type hierarchies which can have a more + generalized bounding type that can be assigned from a narrow type, this + is allowed (but not verified). + + Examples: + + ```mlir + // Constant global. + ml_program.global private @foobar(dense<4> : tensor<4xi32>) : tensor + + // Mutable global with external linkage. + ml_program.global private mutable @foobar(#ml_program.extern>) + : tensor + + // Mutable global with an undefined initial value. 
+ ml_program.global private mutable @foobar : tensor + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttr:$type, + UnitAttr:$is_mutable, + OptionalAttr:$value, + OptionalAttr:$sym_visibility + ); + + let assemblyFormat = [{ + custom($sym_visibility) + (`mutable` $is_mutable^)? + $sym_name `` + custom($type, $value) + attr-dict + }]; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// GlobalLoadConstOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [ + NoSideEffect, + DeclareOpInterfaceMethods + ]> { + let summary = "Direct load a constant value from a global"; + let description = [{ + Loads a constant (immutable) value from a global directly by symbol. + + This op is only legal for globals that are not mutable and exists because + such a load can be considered to have no side effects. + + Example: + + ```mlir + %0 = ml_program.global_load_const @foobar : tensor + ``` + }]; + + let arguments = (ins + FlatSymbolRefAttr:$global + ); + let results = (outs + AnyType:$result + ); + + let assemblyFormat = [{ + $global attr-dict `:` type($result) + }]; + + let extraClassDeclaration = [{ + /// Gets the corresponding GlobalOp (or nullptr). 
+ GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); + }]; +} + //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt --- a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt @@ -7,6 +7,7 @@ DEPENDS MLIRMLProgramOpsIncGen + MLIRMLProgramAttributesIncGen LINK_LIBS PUBLIC MLIRDialect diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp @@ -7,15 +7,42 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::ml_program; +//===----------------------------------------------------------------------===// +/// Tablegen Definitions +//===----------------------------------------------------------------------===// + #include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc" + +namespace { +struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface { + using OpAsmDialectInterface::OpAsmDialectInterface; + + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { + if (attr.isa()) { + os << "extern"; + return AliasResult::OverridableAlias; + } + return AliasResult::NoAlias; + } +}; +} // namespace void ml_program::MLProgramDialect::initialize() { +#define GET_ATTRDEF_LIST + addAttributes< +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc" + >(); addOperations< #define GET_OP_LIST #include 
"mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc" >(); + addInterfaces(); } diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -13,6 +13,70 @@ using namespace mlir; using namespace mlir::ml_program; +//===----------------------------------------------------------------------===// +// Custom asm helpers +//===----------------------------------------------------------------------===// + +/// some.op custom($type, $attr) +/// +/// Uninitialized: +// some.op : tensor<3xi32> +/// Initialized to same type as global: +/// some.op (dense<0> : tensor) +/// Initialized to narrower type than global: +/// some.op (dense<0> : tensor<3xi32>) : tensor +static ParseResult parseTypedInitialValue(OpAsmParser &parser, + TypeAttr &typeAttr, Attribute &attr) { + if (succeeded(parser.parseOptionalLParen())) { + if (failed(parser.parseAttribute(attr))) + return failure(); + if (failed(parser.parseRParen())) + return failure(); + } + + Type type; + if (failed(parser.parseColonType(type))) + return failure(); + typeAttr = TypeAttr::get(type); + return success(); +} + +static void printTypedInitialValue(OpAsmPrinter &p, Operation *op, + TypeAttr type, Attribute attr) { + if (attr) { + p << "("; + p.printAttribute(attr); + p << ")"; + } + + p << " : "; + p.printAttribute(type); +} + +/// some.op custom($sym_visibility) $sym_name +/// -> +/// some.op public @foo +/// some.op private @foo +static ParseResult parseSymbolVisibility(OpAsmParser &parser, + StringAttr &symVisibilityAttr) { + StringRef symVisibility; + parser.parseOptionalKeyword(&symVisibility, {"public", "private", "nested"}); + if (symVisibility.empty()) + return parser.emitError(parser.getCurrentLocation()) + << "expected 'public', 'private', or 'nested'"; + if (!symVisibility.empty()) + symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility); + return 
success(); +} + +static void printSymbolVisibility(OpAsmPrinter &p, Operation *op, + StringAttr symVisibilityAttr) { + if (!symVisibilityAttr) + p << "public"; + else + p << symVisibilityAttr.getValue(); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -38,6 +102,43 @@ function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false); } +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +LogicalResult GlobalOp::verify() { + if (!getIsMutable() && !getValue()) + return emitOpError() << "immutable global must have an initial value"; + return success(); +} + +//===----------------------------------------------------------------------===// +// GlobalLoadConstOp +//===----------------------------------------------------------------------===// + +GlobalOp GlobalLoadConstOp::getGlobalOp(SymbolTableCollection &symbolTable) { + return symbolTable.lookupNearestSymbolFrom( + getOperation()->getParentOp(), getGlobalAttr()); +} + +LogicalResult +GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + GlobalOp referrent = getGlobalOp(symbolTable); + if (!referrent) + return emitOpError() << "undefined global: " << getGlobal(); + + if (referrent.getIsMutable()) + return emitOpError() << "cannot load as const from mutable global " + << getGlobal(); + + if (referrent.getType() != getResult().getType()) + return emitOpError() << "cannot load from global typed " + << referrent.getType() << " as " + << getResult().getType(); + + return success(); +} + //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MLProgram/attrs.mlir 
b/mlir/test/Dialect/MLProgram/attrs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/MLProgram/attrs.mlir @@ -0,0 +1,7 @@ +// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s + +// CHECK: #ml_program.extern +"unregistered.attributes"() { + value = #ml_program.extern +} : () -> () + diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir --- a/mlir/test/Dialect/MLProgram/invalid.mlir +++ b/mlir/test/Dialect/MLProgram/invalid.mlir @@ -31,3 +31,30 @@ // expected-error @+1 {{doesn't match function result}} ml_program.output %arg0 : i64 } + +// ----- +// expected-error @+1 {{immutable global must have an initial value}} +ml_program.global private @const : i32 + +// ----- +ml_program.func @undef_global() -> i32 { + // expected-error @+1 {{undefined global: nothere}} + %0 = ml_program.global_load_const @nothere : i32 + ml_program.return %0 : i32 +} + +// ----- +ml_program.global private mutable @var : i32 +ml_program.func @mutable_const_load() -> i32 { + // expected-error @+1 {{op cannot load as const from mutable global var}} + %0 = ml_program.global_load_const @var : i32 + ml_program.return %0 : i32 +} + +// ----- +ml_program.global private @var(42 : i64) : i64 +ml_program.func @const_load_type_mismatch() -> i32 { + // expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}} + %0 = ml_program.global_load_const @var : i32 + ml_program.return %0 : i32 +} diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir --- a/mlir/test/Dialect/MLProgram/ops.mlir +++ b/mlir/test/Dialect/MLProgram/ops.mlir @@ -18,3 +18,12 @@ %0 = "unregistered.dummy"(%arg0) : (i32) -> i32 ml_program.output %0 : i32 } + +// CHECK: ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32> +ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32> + +// CHECK: ml_program.global private mutable 
@global_mutable_undef : tensor +ml_program.global private mutable @global_mutable_undef : tensor + +// CHECK: ml_program.global private mutable @global_extern(#extern) : tensor +ml_program.global private mutable @global_extern(#ml_program.extern>) : tensor