diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt @@ -1,3 +1,10 @@ set(LLVM_TARGET_DEFINITIONS MLProgramOps.td) add_mlir_dialect(MLProgramOps ml_program) add_mlir_doc(MLProgramOps MLProgramOps Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS MLProgramAttributes.td) +mlir_tablegen(MLProgramAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(MLProgramAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRMLProgramAttributesIncGen) +add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen) +add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc) diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h @@ -8,6 +8,7 @@ #ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_ +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.h @@ -0,0 +1,21 @@ +//===- MLProgramAttributes.h - Attribute Classes ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ +#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ + +#include "mlir/IR/Attributes.h" + +//===----------------------------------------------------------------------===// +// Tablegen Attribute Declarations +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h.inc" + +#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMATTRIBUTES_H_ diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td @@ -0,0 +1,45 @@ +//===- MLProgramAttributes.td - Attr definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLPROGRAM_ATTRIBUTES +#define MLPROGRAM_ATTRIBUTES + +include "mlir/IR/AttrTypeBase.td" +include "mlir/Dialect/MLProgram/IR/MLProgramBase.td" + +// Base class for MLProgram dialect attributes.
+class MLProgram_Attr<string name, list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Attribute">
+    : AttrDef<MLProgram_Dialect, name, traits, baseCppClass> {
+  let mnemonic = ?;
+}
+
+//===----------------------------------------------------------------------===//
+// ExternAttr
+//===----------------------------------------------------------------------===//
+
+def MLProgram_ExternAttr : MLProgram_Attr<"Extern", []> {
+  let summary = "Value used for a global signalling external resolution";
+  let description = [{
+    When used as the value for a GlobalOp, this indicates that the actual
+    value should be resolved externally in an implementation defined manner.
+    The `sym_name` of the global is the key for locating the value.
+
+    Example, as the initial value of an `ml_program.global` (the standalone
+    attribute form is `#ml_program.extern<tensor<4xi32>>`):
+    ```mlir
+    ml_program.global private @foobar = extern : tensor<4xi32>
+    ```
+  }];
+
+  let parameters = (ins AttributeSelfTypeParameter<"">:$type);
+  let mnemonic = "extern";
+  let assemblyFormat = "`<` $type `>`";
+}
+
+#endif // MLPROGRAM_ATTRIBUTES
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
@@ -27,6 +27,10 @@
     it is recommended to inquire further prior to using this dialect.
   }];
 
+  // Let ODS declare and define parseAttribute/printAttribute, dispatching to
+  // the attributes defined in MLProgramAttributes.td.
+  let useDefaultAttributePrinterParser = 1;
+
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 }
 
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -96,6 +96,110 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalOp : MLProgram_Op<"global", [
+    Symbol
+  ]> {
+  let summary = "Module level declaration of a global variable";
+  let description = [{
+    Declares a named global variable (or constant).
+
+    A global contains a value of a specified type which can be accessed at
+    runtime via appropriate load/store operations. It can be mutable or
+    constant, optionally taking an initial value or declared as
+    extern (in which case, the initial value is found in external storage
+    by symbol name).
+
+    Generally, the type of the global and the type of the initial value
+    will be the same. However, for type hierarchies which can have a more
+    generalized bounding type that can be assigned from a narrow type, this
+    is allowed (but not verified).
+
+    The symbol visibility (`public`, `private` or `nested`) must always be
+    spelled out in the custom assembly form.
+
+    Examples:
+
+    ```mlir
+    // Constant global that stores the same type as its initial value.
+    ml_program.global private @foobar = dense<4> : tensor<4xi32>
+
+    // Constant global with an inline initial value of a more specific type.
+    ml_program.global private @foobar : tensor<?xi32> = dense<4> : tensor<4xi32>
+
+    // Mutable global with an external initial value of a more specific type.
+    ml_program.global private mutable @foobar : tensor<?xi32> = extern : tensor<4xi32>
+
+    // Mutable global with an undefined initial value.
+    ml_program.global private mutable @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    SymbolNameAttr:$sym_name,
+    TypeAttr:$type,
+    UnitAttr:$is_mutable,
+    OptionalAttr<AnyAttr>:$value,
+    OptionalAttr<StrAttr>:$sym_visibility
+  );
+
+  // TODO: We really want a declarative assembly directive for
+  // parseOptionalVisibilityKeyword vs the custom code.
+  // See: https://github.com/llvm/llvm-project/issues/55052
+  let assemblyFormat = [{
+    custom<SymbolVisibility>($sym_visibility)
+    (`mutable` $is_mutable^)?
+    $sym_name
+    custom<TypedInitialValue>($type, $value)
+    attr-dict
+  }];
+
+  let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalLoadConstOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
+    NoSideEffect
+  ]> {
+  let summary = "Direct load a constant value from a global";
+  let description = [{
+    Loads a constant (immutable) value from a global directly by symbol.
+
+    This op is only legal for globals that are not mutable and exists because
+    such a load can be considered to have no side effects.
+
+    Example:
+
+    ```mlir
+    %0 = ml_program.global_load_const @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$global
+  );
+  let results = (outs
+    AnyType:$result
+  );
+
+  let assemblyFormat = [{
+    $global attr-dict `:` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp();
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
@@ -7,6 +7,7 @@
 
   DEPENDS
   MLIRMLProgramOpsIncGen
+  MLIRMLProgramAttributesIncGen
 
   LINK_LIBS PUBLIC
   MLIRDialect
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -7,15 +7,41 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::ml_program;
 
+//===----------------------------------------------------------------------===//
+// Tablegen Definitions
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
 
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
+
+namespace {
+/// Suggests `extern` as the printed alias name for ExternAttr.
+struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
+  using OpAsmDialectInterface::OpAsmDialectInterface;
+
+  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
+    if (attr.isa<ExternAttr>()) {
+      os << "extern";
+      return AliasResult::OverridableAlias;
+    }
+    return AliasResult::NoAlias;
+  }
+};
+} // namespace
+
 void ml_program::MLProgramDialect::initialize() {
+  addAttributes<ExternAttr>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
       >();
+  addInterfaces<MLProgramOpAsmDialectInterface>();
 }
diff --git
a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -13,6 +13,105 @@
 using namespace mlir;
 using namespace mlir::ml_program;
 
+//===----------------------------------------------------------------------===//
+// Custom asm helpers
+//===----------------------------------------------------------------------===//
+
+/// some.op custom<TypedInitialValue>($type, $attr)
+/// ->
+/// some.op : i32
+/// some.op = 42 : i32
+/// some.op : i32 = 42 : index
+/// some.op = extern : i32
+/// some.op : i32 = extern : index
+///
+/// With only `= value`, the declared type is taken from the value. Failures of
+/// the underlying parser calls are propagated as-is since they already emit
+/// a diagnostic.
+static ParseResult parseTypedInitialValue(OpAsmParser &parser,
+                                          TypeAttr &typeAttr, Attribute &attr) {
+  auto parseAttr = [&]() -> LogicalResult {
+    // Parse `extern` bareword as an ExternAttr.
+    if (succeeded(parser.parseOptionalKeyword("extern"))) {
+      Type externType;
+      if (failed(parser.parseColonType(externType)))
+        return failure();
+      attr = ExternAttr::get(parser.getContext(), externType);
+      return success();
+    }
+    return parser.parseAttribute(attr);
+  };
+
+  // Parse `=` value
+  if (succeeded(parser.parseOptionalEqual())) {
+    if (failed(parseAttr()))
+      return failure();
+    typeAttr = TypeAttr::get(attr.getType());
+    return success();
+  }
+
+  // Parse `:` type [`=` value]
+  Type type;
+  if (failed(parser.parseColonType(type)))
+    return failure();
+  typeAttr = TypeAttr::get(type);
+  if (succeeded(parser.parseOptionalEqual()) && failed(parseAttr()))
+    return failure();
+  return success();
+}
+
+/// Inverse of parseTypedInitialValue. The `: type` prefix is only printed when
+/// it cannot be taken from the initial value.
+static void printTypedInitialValue(OpAsmPrinter &p, Operation *op,
+                                   TypeAttr type, Attribute attr) {
+  bool printType = !attr || attr.getType() != type.getValue();
+  if (printType) {
+    p << ": ";
+    p.printAttribute(type);
+  }
+  if (!attr)
+    return;
+  // Only separate type and value when both are printed, so that a global
+  // without an initial value is not printed with trailing whitespace.
+  if (printType)
+    p << " ";
+  p << "= ";
+  if (attr.isa<ExternAttr>())
+    p << "extern : " << attr.getType();
+  else
+    p.printAttribute(attr);
+}
+
+/// some.op custom<SymbolVisibility>($sym_visibility) $sym_name
+/// ->
+/// some.op public @foo
+/// some.op private @foo
+/// some.op nested @foo
+///
+/// The visibility keyword is mandatory in this syntax.
+static ParseResult parseSymbolVisibility(OpAsmParser &parser,
+                                         StringAttr &symVisibilityAttr) {
+  StringRef symVisibility;
+  // A missing keyword leaves `symVisibility` empty and is diagnosed below.
+  (void)parser.parseOptionalKeyword(&symVisibility,
+                                    {"public", "private", "nested"});
+  if (symVisibility.empty())
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 'public', 'private', or 'nested'";
+  symVisibilityAttr = parser.getBuilder().getStringAttr(symVisibility);
+  return success();
+}
+
+/// Prints the symbol visibility keyword; an absent attribute is printed as
+/// `public`.
+static void printSymbolVisibility(OpAsmPrinter &p, Operation *op,
+                                  StringAttr symVisibilityAttr) {
+  if (!symVisibilityAttr)
+    p << "public";
+  else
+    p << symVisibilityAttr.getValue();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
@@ -38,6 +137,53 @@
   function_interface_impl::printFunctionOp(p, *this, /*isVariadic=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult GlobalOp::verify() {
+  // Only mutable globals may be declared without an initial value.
+  if (!getIsMutable() && !getValue()) {
+    return emitOpError() << "immutable global must have an initial value";
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalLoadConstOp
+//===----------------------------------------------------------------------===//
+
+/// Resolves `global` in the closest symbol table enclosing this op; returns
+/// null if there is none or it holds no GlobalOp with that name.
+GlobalOp GlobalLoadConstOp::getGlobalOp() {
+  return SymbolTable::lookupNearestSymbolFrom<GlobalOp>(getOperation(),
+                                                        getGlobalAttr());
+}
+
+// NOTE(review): these checks inspect another op (the referenced global);
+// consider moving them to SymbolUserOpInterface::verifySymbolUses, which can
+// reuse a cached SymbolTableCollection.
+LogicalResult GlobalLoadConstOp::verify() {
+  GlobalOp referent = getGlobalOp();
+  if (!referent) {
+    return emitOpError() << "undefined global: " << getGlobal();
+  }
+
+  // The op is declared NoSideEffect, which only holds for immutable globals.
+  if (referent.getIsMutable()) {
+    return emitOpError() << "cannot load as const from mutable global "
+                         << getGlobal();
+  }
+
+  if (referent.getType() != getResult().getType()) {
+    return emitOpError() << "cannot load from global typed "
+                         << referent.getType() << " as "
+                         << getResult().getType();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MLProgram/attrs.mlir b/mlir/test/Dialect/MLProgram/attrs.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/attrs.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s --allow-unregistered-dialect | mlir-opt --allow-unregistered-dialect | FileCheck %s
+
+// CHECK: #ml_program.extern<i32>
+"unregistered.attributes"() {
+  value = #ml_program.extern<i32>
+} : () -> ()
+
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -31,3 +31,34 @@
   // expected-error @+1 {{doesn't match function result}}
   ml_program.output %arg0 : i64
 }
+
+// -----
+// expected-error @+1 {{immutable global must have an initial value}}
+ml_program.global private @const : i32
+
+// -----
+// expected-error @+1 {{expected 'public', 'private', or 'nested'}}
+ml_program.global @no_visibility = 42 : i64
+
+// -----
+ml_program.func @undef_global() -> i32 {
+  // expected-error @+1 {{undefined global: nothere}}
+  %0 = ml_program.global_load_const @nothere : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private mutable @var : i32
+ml_program.func @mutable_const_load() -> i32 {
+  // expected-error @+1 {{op cannot load as const from mutable global var}}
+  %0 = ml_program.global_load_const @var : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private @var = 42 : i64
+ml_program.func @const_load_type_mismatch() -> i32 {
+  // expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}}
+  %0 = ml_program.global_load_const @var : i32
+  ml_program.return %0 : i32
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -18,3 +18,22 @@
   %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
   ml_program.output %0 : i32
 }
+
+// CHECK: ml_program.global private @global_same_type = dense<4> : tensor<4xi32>
+ml_program.global private @global_same_type = dense<4> : tensor<4xi32>
+// CHECK: ml_program.global public @global_bounding_type : tensor<?xi32> = dense<4> : tensor<4xi32>
+ml_program.global public @global_bounding_type : tensor<?xi32> = dense<4> : tensor<4xi32>
+// CHECK: ml_program.global private mutable @foobar_mutable_undef : tensor<?xi32>
+ml_program.global private mutable @foobar_mutable_undef : tensor<?xi32>
+// CHECK: ml_program.global private mutable @global_extern = extern : tensor<4xi32>
+ml_program.global private mutable @global_extern = extern : tensor<4xi32>
+// CHECK: ml_program.global private mutable @global_extern_bounding_type : tensor<?xi32> = extern : tensor<4xi32>
+ml_program.global private mutable
+  @global_extern_bounding_type : tensor<?xi32> = extern : tensor<4xi32>
+
+// CHECK-LABEL: ml_program.func @global_load_const
+ml_program.func @global_load_const() -> tensor<4xi32> {
+  // CHECK: ml_program.global_load_const @global_same_type : tensor<4xi32>
+  %0 = ml_program.global_load_const @global_same_type : tensor<4xi32>
+  ml_program.return %0 : tensor<4xi32>
+}