diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td --- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td +++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td @@ -171,7 +171,8 @@ advanced cases. This op is side effecting and may not be valid to use in graph regions - without additional consideration to evaluation order constraints. + without additional consideration to evaluation order constraints. See + `global_load_graph` for an op which allows for explicit ordering constraints. Example: @@ -181,16 +182,14 @@ }]; let arguments = (ins - Arg:$global, - Variadic:$consumeTokens + Arg:$global ); let results = (outs - AnyType:$result, - Optional:$produceToken + AnyType:$result ); let assemblyFormat = [{ - $global `` custom($consumeTokens, type($produceToken)) `:` type($result) attr-dict + $global `:` type($result) attr-dict }]; let extraClassDeclaration = [{ @@ -238,6 +237,52 @@ }]; } +//===----------------------------------------------------------------------===// +// GlobalLoadGraphOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Direct load of a mutable value from a global in Graph region"; + let description = [{ + Performs a non-atomic, non-volatile, non-synchronized load from a global + that may be mutable. + + It is fully expected that these constraints are not suitable for all + situations, and alternative ops should be defined and used for more advanced + cases. + + This op is side effecting and may not be valid to use in graph regions + without additional consideration to evaluation order constraints. 
+ + Example: + + ```mlir + %0, %cstr = ml_program.global_load_graph @foobar + ordering (%token -> !ml_program.token) : tensor + ``` + }]; + + let arguments = (ins + Arg:$global, + Variadic:$consumeTokens + ); + let results = (outs + AnyType:$result, + MLProgram_TokenType:$produceToken + ); + + let assemblyFormat = [{ + $global `` custom($consumeTokens, type($produceToken)) `:` type($result) attr-dict + }]; + + let extraClassDeclaration = [{ + /// Gets the corresponding GlobalOp (or nullptr). + GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); + }]; +} + //===----------------------------------------------------------------------===// // GlobalStoreOp //===----------------------------------------------------------------------===// @@ -255,7 +300,8 @@ advanced cases. This op is side effecting and may not be valid to use in graph regions - without additional consideration to evaluation order constraints. + without additional consideration to evaluation order constraints. See + `global_store_graph` for an op which allows for explicit ordering constraints. Example: @@ -266,11 +312,53 @@ let arguments = (ins Arg:$global, + AnyType:$value + ); + + let assemblyFormat = [{ + $global `=` $value `:` type($value) attr-dict + }]; + + let extraClassDeclaration = [{ + /// Gets the corresponding GlobalOp (or nullptr). + GlobalOp getGlobalOp(SymbolTableCollection &symbolTable); + }]; +} + +//===----------------------------------------------------------------------===// +// GlobalStoreGraphOp +//===----------------------------------------------------------------------===// + +def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [ + DeclareOpInterfaceMethods + ]> { + let summary = "Direct store of a value into a mutable global"; + let description = [{ + Performs a non-atomic, non-volatile, non-synchronized store to a mutable + global. 
+ + It is fully expected that these constraints are not suitable for + all situations, and alternative ops should be defined and used for more + advanced cases. + + This op is side effecting and may not be valid to use in graph regions + without additional consideration to evaluation order constraints. + + Example: + + ```mlir + %token = ml_program.global_store_graph @foobar = %0 + ordering (%in_token -> !ml_program.token) : tensor + ``` + }]; + + let arguments = (ins + Arg:$global, AnyType:$value, Variadic:$consumeTokens ); let results = (outs - Optional:$produceToken + MLProgram_TokenType:$produceToken ); let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp --- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp +++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp @@ -18,12 +18,11 @@ //===----------------------------------------------------------------------===// /// Parse and print an ordering clause for a variadic of consuming tokens -/// and an optional producing token. +/// and a producing token. /// /// Syntax: /// ordering(%0, %1 -> !ml_program.token) /// ordering(() -> !ml_program.token) -/// ordering(%0, %1) /// /// If both the consuming and producing token are not present on the op, then /// the clause prints nothing. @@ -46,10 +45,11 @@ return failure(); } - // Parse optional producer token. - if (succeeded(parser.parseOptionalArrow())) - if (failed(parser.parseType(produceTokenType))) - return failure(); + // Parse producer token. 
+ if (failed(parser.parseArrow())) + return failure(); + if (failed(parser.parseType(produceTokenType))) + return failure(); if (failed(parser.parseRParen())) return failure(); @@ -220,6 +220,30 @@ return success(); } +//===----------------------------------------------------------------------===// +// GlobalLoadGraphOp +//===----------------------------------------------------------------------===// + +GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { + return symbolTable.lookupNearestSymbolFrom( + getOperation()->getParentOp(), getGlobalAttr()); +} + +LogicalResult +GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + GlobalOp referrent = getGlobalOp(symbolTable); + if (!referrent) + return emitOpError() << "undefined global: " << getGlobal(); + + if (referrent.getType() != getResult().getType()) { + return emitOpError() << "cannot load from global typed " + << referrent.getType() << " as " + << getResult().getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // GlobalStoreOp //===----------------------------------------------------------------------===// @@ -249,6 +273,35 @@ return success(); } +//===----------------------------------------------------------------------===// +// GlobalStoreGraphOp +//===----------------------------------------------------------------------===// + +GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) { + return symbolTable.lookupNearestSymbolFrom( + getOperation()->getParentOp(), getGlobalAttr()); +} + +LogicalResult +GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + GlobalOp referrent = getGlobalOp(symbolTable); + if (!referrent) + return emitOpError() << "undefined global: " << getGlobal(); + + if (!referrent.getIsMutable()) { + return emitOpError() << "cannot store to an immutable global " + << getGlobal(); + } + + if (referrent.getType() != 
getValue().getType()) { + return emitOpError() << "cannot store to a global typed " + << referrent.getType() << " from " + << getValue().getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // SubgraphOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir --- a/mlir/test/Dialect/MLProgram/invalid.mlir +++ b/mlir/test/Dialect/MLProgram/invalid.mlir @@ -96,3 +96,17 @@ ml_program.global_store @var = %arg0 : i64 ml_program.return } + +// ----- + +ml_program.global private mutable @global_mutable_undef : tensor +ml_program.subgraph @global_load_store_tokens() -> (tensor, !ml_program.token) { + %token1 = ml_program.token + %0, %token2 = ml_program.global_load_graph @global_mutable_undef + ordering(() -> !ml_program.token) : tensor + %token3 = ml_program.global_store_graph @global_mutable_undef = %0 + // expected-error @+1 {{expected '->'}} + ordering(%token1, %token2) : tensor + + ml_program.output %0, %token3 : tensor, !ml_program.token +} diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir --- a/mlir/test/Dialect/MLProgram/ops.mlir +++ b/mlir/test/Dialect/MLProgram/ops.mlir @@ -45,12 +45,12 @@ // CHECK-LABEL: @global_load_store_tokens ml_program.subgraph @global_load_store_tokens() -> (tensor, !ml_program.token) { %token1 = ml_program.token - %0, %token2 = ml_program.global_load @global_mutable_undef + %0, %token2 = ml_program.global_load_graph @global_mutable_undef ordering(() -> !ml_program.token) : tensor - %token3 = ml_program.global_store @global_mutable_undef = %0 + %token3 = ml_program.global_store_graph @global_mutable_undef = %0 ordering(%token1, %token2 -> !ml_program.token) : tensor - ml_program.global_store @global_mutable_undef = %0 - ordering(%token3) : tensor + %token4 = ml_program.global_store_graph @global_mutable_undef = 
%0 + ordering(%token3 -> !ml_program.token) : tensor ml_program.output %0, %token3 : tensor, !ml_program.token }