diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/CMakeLists.txt
@@ -8,3 +8,10 @@
 add_public_tablegen_target(MLIRMLProgramAttributesIncGen)
 add_dependencies(mlir-headers MLIRMLProgramAttributesIncGen)
 add_mlir_doc(MLProgramAttributes MLProgramAttributes Dialects/ -gen-attrdef-doc)
+
+set(LLVM_TARGET_DEFINITIONS MLProgramTypes.td)
+mlir_tablegen(MLProgramTypes.h.inc -gen-typedef-decls)
+mlir_tablegen(MLProgramTypes.cpp.inc -gen-typedef-defs)
+add_public_tablegen_target(MLIRMLProgramTypesIncGen)
+add_dependencies(mlir-headers MLIRMLProgramTypesIncGen)
+add_mlir_doc(MLProgramTypes MLProgramTypes Dialects/ -gen-typedef-doc)
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgram.h
@@ -9,6 +9,7 @@
 #define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAM_H_
 
 #include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/FunctionInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramBase.td
@@ -28,6 +28,7 @@
   }];
 
   let useDefaultAttributePrinterParser = 1;
+  let useDefaultTypePrinterParser = 1;
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
 }
 
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -10,6 +10,7 @@
 #define MLPROGRAM_OPS
 
 include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+include "mlir/Dialect/MLProgram/IR/MLProgramTypes.td"
 include "mlir/Interfaces/CallInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -152,6 +153,51 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Direct load of a mutable value from a global";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized load from a global
+    that may be mutable.
+
+    It is fully expected that these constraints are not suitable for
+    all situations, and alternative ops should be defined and used for more
+    advanced cases.
+
+    This op is side effecting and may not be valid to use in graph regions
+    without additional consideration to evaluation order constraints.
+
+    Example:
+
+    ```mlir
+    %0 = ml_program.global_load @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<SymbolRefAttr, "", [MemRead]>:$global,
+    Variadic<MLProgram_TokenType>:$consumeTokens
+  );
+  let results = (outs
+    AnyType:$result,
+    Optional<MLProgram_TokenType>:$produceToken
+  );
+
+  let assemblyFormat = [{
+    $global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalLoadConstOp
 //===----------------------------------------------------------------------===//
@@ -175,14 +221,59 @@
   }];
 
   let arguments = (ins
-    FlatSymbolRefAttr:$global
+    SymbolRefAttr:$global
   );
   let results = (outs
     AnyType:$result
   );
 
   let assemblyFormat = [{
-    $global attr-dict `:` type($result)
+    $global `:` type($result) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// GlobalStoreOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>
+  ]> {
+  let summary = "Direct store of a value into a mutable global";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized store to a mutable
+    global.
+
+    It is fully expected that these constraints are not suitable for
+    all situations, and alternative ops should be defined and used for more
+    advanced cases.
+
+    This op is side effecting and may not be valid to use in graph regions
+    without additional consideration to evaluation order constraints.
+
+    Example:
+
+    ```mlir
+    ml_program.global_store @foobar = %0 : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<SymbolRefAttr, "", [MemWrite]>:$global,
+    AnyType:$value,
+    Variadic<MLProgram_TokenType>:$consumeTokens
+  );
+  let results = (outs
+    Optional<MLProgram_TokenType>:$produceToken
+  );
+
+  let assemblyFormat = [{
+    $global `=` $value `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($value) attr-dict
   }];
 
   let extraClassDeclaration = [{
@@ -310,4 +401,24 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// TokenOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_TokenOp : MLProgram_Op<"token", [
+    NoSideEffect
+  ]> {
+  let summary = "Produces a new token value";
+  let description = [{
+    Token values are used to chain side effecting ops in a graph so as to
+    establish an execution order. This op produces a token.
+  }];
+
+  let results = (outs
+    MLProgram_TokenType:$token
+  );
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // MLPROGRAM_OPS
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.h b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.h
@@ -0,0 +1,21 @@
+//===- MLProgramTypes.h - Type Classes --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
+#define MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
+
+#include "mlir/IR/Types.h"
+
+//===----------------------------------------------------------------------===//
+// Tablegen Type Declarations
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.h.inc"
+
+#endif // MLIR_DIALECT_MLPROGRAM_IR_MLPROGRAMTYPES_H_
diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramTypes.td
@@ -0,0 +1,24 @@
+//===- MLProgramTypes.td - Type definitions ----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLPROGRAM_TYPES
+#define MLPROGRAM_TYPES
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/Dialect/MLProgram/IR/MLProgramBase.td"
+
+class MLProgram_Type<string name, list<Trait> traits = [],
+                     string baseCppClass = "::mlir::Type">
+    : TypeDef<MLProgram_Dialect, name, traits, baseCppClass> {}
+
+def MLProgram_TokenType : MLProgram_Type<"Token"> {
+  let summary = "Token for establishing execution ordering in a graph";
+  let mnemonic = "token";
+}
+
+#endif // MLPROGRAM_TYPES
diff --git a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/IR/CMakeLists.txt
@@ -8,6 +8,7 @@
   DEPENDS
   MLIRMLProgramOpsIncGen
   MLIRMLProgramAttributesIncGen
+  MLIRMLProgramTypesIncGen
 
   LINK_LIBS PUBLIC
   MLIRDialect
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramDialect.cpp
@@ -20,6 +20,8 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgramOpsDialect.cpp.inc"
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc"
 
 namespace {
 struct MLProgramOpAsmDialectInterface : public OpAsmDialectInterface {
@@ -40,9 +42,16 @@
   addAttributes<
 #include "mlir/Dialect/MLProgram/IR/MLProgramAttributes.cpp.inc"
       >();
+
+#define GET_TYPEDEF_LIST
+  addTypes<
+#include "mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc"
+      >();
+
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/MLProgram/IR/MLProgramOps.cpp.inc"
       >();
+
   addInterfaces<MLProgramOpAsmDialectInterface>();
 }
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -17,6 +17,66 @@
 // Custom asm helpers
 //===----------------------------------------------------------------------===//
 
+/// Parse and print an ordering clause for a variadic of consuming tokens
+/// and an optional producing token.
+///
+/// Syntax:
+///   ordering(%0, %1 -> !ml_program.token)
+///   ordering(() -> !ml_program.token)
+///   ordering(%0, %1)
+///
+/// The clause is optional: it is printed only when the op has at least one
+/// consuming or producing token, and parses as empty when it is absent.
+static ParseResult parseTokenOrdering(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &consumeTokens,
+    Type &produceTokenType) {
+  if (failed(parser.parseOptionalKeyword("ordering")))
+    return success();
+  // Once 'ordering' is present, the parenthesized clause is mandatory.
+  if (failed(parser.parseLParen()))
+    return failure();
+
+  // Parse consuming token list. If there are no consuming tokens, the
+  // '()' null list represents this.
+  if (succeeded(parser.parseOptionalLParen())) {
+    if (failed(parser.parseRParen()))
+      return failure();
+  } else {
+    if (failed(parser.parseOperandList(consumeTokens,
+                                       /*requiredOperandCount=*/-1)))
+      return failure();
+  }
+
+  // Parse optional producer token.
+  if (succeeded(parser.parseOptionalArrow()))
+    if (failed(parser.parseType(produceTokenType)))
+      return failure();
+
+  if (failed(parser.parseRParen()))
+    return failure();
+
+  return success();
+}
+
+static void printTokenOrdering(OpAsmPrinter &p, Operation *op,
+                               OperandRange consumeTokens,
+                               Type produceTokenType) {
+  if (consumeTokens.empty() && !produceTokenType)
+    return;
+
+  p << " ordering(";
+  if (consumeTokens.empty())
+    p << "()";
+  else
+    p.printOperands(consumeTokens);
+  if (produceTokenType) {
+    p << " -> ";
+    p.printType(produceTokenType);
+  }
+  p << ")";
+}
+
 /// some.op custom<TypeOrAttr>($type, $attr)
 ///
 /// Uninitialized:
@@ -111,6 +171,30 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+LogicalResult
+GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referrent = getGlobalOp(symbolTable);
+  if (!referrent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (referrent.getType() != getResult().getType()) {
+    return emitOpError() << "cannot load from global typed "
+                         << referrent.getType() << " as "
+                         << getResult().getType();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalLoadConstOp
 //===----------------------------------------------------------------------===//
@@ -138,6 +222,35 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalStoreOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+LogicalResult
+GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referrent = getGlobalOp(symbolTable);
+  if (!referrent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (!referrent.getIsMutable()) {
+    return emitOpError() << "cannot store to an immutable global "
+                         << getGlobal();
+  }
+
+  if (referrent.getType() != getValue().getType()) {
+    return emitOpError() << "cannot store to a global typed "
+                         << referrent.getType() << " from "
+                         << getValue().getType();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -38,7 +38,7 @@
 
 // -----
 ml_program.func @undef_global() -> i32 {
-  // expected-error @+1 {{undefined global: nothere}}
+  // expected-error @+1 {{undefined global: @nothere}}
   %0 = ml_program.global_load_const @nothere : i32
   ml_program.return %0 : i32
 }
@@ -46,7 +46,7 @@
 // -----
 ml_program.global private mutable @var : i32
 ml_program.func @mutable_const_load() -> i32 {
-  // expected-error @+1 {{op cannot load as const from mutable global var}}
+  // expected-error @+1 {{op cannot load as const from mutable global @var}}
   %0 = ml_program.global_load_const @var : i32
   ml_program.return %0 : i32
 }
@@ -58,3 +58,41 @@
   %0 = ml_program.global_load_const @var : i32
   ml_program.return %0 : i32
 }
+
+// -----
+ml_program.func @load_undef() -> i32 {
+  // expected-error @+1 {{undefined global: @nothere}}
+  %0 = ml_program.global_load @nothere : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private mutable @var(42 : i64) : i64
+ml_program.func @load_type_mismatch() -> i32 {
+  // expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}}
+  %0 = ml_program.global_load @var : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.func @store_undef(%arg0: i32) {
+  // expected-error @+1 {{undefined global: @nothere}}
+  ml_program.global_store @nothere = %arg0 : i32
+  ml_program.return
+}
+
+// -----
+ml_program.global private mutable @var(42 : i64) : i64
+ml_program.func @store_type_mismatch(%arg0: i32) {
+  // expected-error @+1 {{cannot store to a global typed 'i64' from 'i32'}}
+  ml_program.global_store @var = %arg0 : i32
+  ml_program.return
+}
+
+// -----
+ml_program.global private @var(42 : i64) : i64
+ml_program.func @store_immutable(%arg0: i64) {
+  // expected-error @+1 {{cannot store to an immutable global @var}}
+  ml_program.global_store @var = %arg0 : i64
+  ml_program.return
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -14,7 +14,8 @@
 
 // CHECK-LABEL: ml_program.subgraph @compute_subgraph
 ml_program.subgraph @compute_subgraph(%arg0 : i32) -> i32 {
-  %1 = "unregistered.dummy"(%0) : (i32) -> i32
+  %token = ml_program.token
+  %1 = "unregistered.dummy"(%0, %token) : (i32, !ml_program.token) -> i32
   %0 = "unregistered.dummy"(%arg0) : (i32) -> i32
   ml_program.output %0 : i32
 }
@@ -27,3 +28,29 @@
 
 // CHECK: ml_program.global private mutable @global_extern(#extern) : tensor<?xi32>
 ml_program.global private mutable @global_extern(#ml_program.extern : tensor<4xi32>) : tensor<?xi32>
+
+// CHECK-LABEL: @global_load_const
+ml_program.func @global_load_const() -> tensor<4xi32> {
+  %0 = ml_program.global_load_const @global_same_type : tensor<4xi32>
+  ml_program.return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @global_load_store
+ml_program.func @global_load_store() {
+  %0 = ml_program.global_load @global_mutable_undef : tensor<?xi32>
+  ml_program.global_store @global_mutable_undef = %0 : tensor<?xi32>
+  ml_program.return
+}
+
+// CHECK-LABEL: @global_load_store_tokens
+ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
+  %token1 = ml_program.token
+  %0, %token2 = ml_program.global_load @global_mutable_undef
+      ordering(() -> !ml_program.token) : tensor<?xi32>
+  %token3 = ml_program.global_store @global_mutable_undef = %0
+      ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
+  ml_program.global_store @global_mutable_undef = %0
+      ordering(%token3) : tensor<?xi32>
+
+  ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8621,6 +8621,7 @@
         "include/mlir/Dialect/MLProgram/IR/MLProgramAttributes.td",
         "include/mlir/Dialect/MLProgram/IR/MLProgramBase.td",
         "include/mlir/Dialect/MLProgram/IR/MLProgramOps.td",
+        "include/mlir/Dialect/MLProgram/IR/MLProgramTypes.td",
     ],
     includes = ["include"],
     deps = [
@@ -8677,6 +8678,24 @@
     deps = [":MLProgramOpsTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "MLProgramTypesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-typedef-decls"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramTypes.h.inc",
+        ),
+        (
+            ["-gen-typedef-defs"],
+            "include/mlir/Dialect/MLProgram/IR/MLProgramTypes.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/MLProgram/IR/MLProgramTypes.td",
+    deps = [":MLProgramOpsTdFiles"],
+)
+
 cc_library(
     name = "MLProgramDialect",
     srcs = glob([
@@ -8692,6 +8711,7 @@
         ":IR",
         ":MLProgramAttributesIncGen",
         ":MLProgramOpsIncGen",
+        ":MLProgramTypesIncGen",
         ":Pass",
         ":Support",
         "//llvm:Support",