diff --git a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
--- a/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
+++ b/mlir/include/mlir/Dialect/MLProgram/IR/MLProgramOps.td
@@ -152,6 +152,50 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    MemoryEffects<[MemRead]>
+  ]> {
+  let summary = "Direct load of a mutable value from a global";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized load from a global
+    that may be mutable. The result type must match the type of the global.
+
+    It is fully expected that these constraints are not suitable for
+    all situations, and alternative ops should be defined and used for more
+    advanced cases.
+
+    This op is side effecting and may not be valid to use in graph regions
+    without additional consideration to evaluation order constraints.
+
+    Example:
+
+    ```mlir
+    %0 = ml_program.global_load @foobar : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$global
+  );
+  let results = (outs
+    AnyType:$result
+  );
+
+  let assemblyFormat = [{
+    $global attr-dict `:` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalLoadConstOp
 //===----------------------------------------------------------------------===//
@@ -191,6 +235,48 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalStoreOp
+//===----------------------------------------------------------------------===//
+
+def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
+    DeclareOpInterfaceMethods<SymbolUserOpInterface>,
+    MemoryEffects<[MemWrite]>
+  ]> {
+  let summary = "Direct store of a value into a mutable global";
+  let description = [{
+    Performs a non-atomic, non-volatile, non-synchronized store to a mutable
+    global. The stored value's type must match the type of the global.
+
+    It is fully expected that these constraints are not suitable for
+    all situations, and alternative ops should be defined and used for more
+    advanced cases.
+
+    This op is side effecting and may not be valid to use in graph regions
+    without additional consideration to evaluation order constraints.
+
+    Example:
+
+    ```mlir
+    ml_program.global_store @foobar = %0 : tensor<?xi32>
+    ```
+  }];
+
+  let arguments = (ins
+    FlatSymbolRefAttr:$global,
+    AnyType:$value
+  );
+
+  let assemblyFormat = [{
+    $global `=` $value `:` type($value) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Gets the corresponding GlobalOp (or nullptr).
+    GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
--- a/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
+++ b/mlir/lib/Dialect/MLProgram/IR/MLProgramOps.cpp
@@ -111,6 +111,31 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalLoadOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalLoadOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+/// Verifies that the referenced global exists and that its type matches the
+/// result type. Loading is permitted from both mutable and immutable globals.
+LogicalResult
+GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referent = getGlobalOp(symbolTable);
+  if (!referent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (referent.getType() != getResult().getType())
+    return emitOpError() << "cannot load from global typed "
+                         << referent.getType() << " as "
+                         << getResult().getType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // GlobalLoadConstOp
 //===----------------------------------------------------------------------===//
@@ -138,6 +163,35 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GlobalStoreOp
+//===----------------------------------------------------------------------===//
+
+GlobalOp GlobalStoreOp::getGlobalOp(SymbolTableCollection &symbolTable) {
+  return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
+      getOperation()->getParentOp(), getGlobalAttr());
+}
+
+/// Verifies that the referenced global exists, is declared mutable, and has
+/// the same type as the stored value.
+LogicalResult
+GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+  GlobalOp referent = getGlobalOp(symbolTable);
+  if (!referent)
+    return emitOpError() << "undefined global: " << getGlobal();
+
+  if (!referent.getIsMutable())
+    return emitOpError() << "cannot store to an immutable global "
+                         << getGlobal();
+
+  if (referent.getType() != getValue().getType())
+    return emitOpError() << "cannot store to a global typed "
+                         << referent.getType() << " from "
+                         << getValue().getType();
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // SubgraphOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/MLProgram/invalid.mlir b/mlir/test/Dialect/MLProgram/invalid.mlir
--- a/mlir/test/Dialect/MLProgram/invalid.mlir
+++ b/mlir/test/Dialect/MLProgram/invalid.mlir
@@ -58,3 +58,41 @@
   %0 = ml_program.global_load_const @var : i32
   ml_program.return %0 : i32
 }
+
+// -----
+ml_program.func @load_undef() -> i32 {
+  // expected-error @+1 {{undefined global: nothere}}
+  %0 = ml_program.global_load @nothere : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.global private mutable @var(42 : i64) : i64
+ml_program.func @load_type_mismatch() -> i32 {
+  // expected-error @+1 {{cannot load from global typed 'i64' as 'i32'}}
+  %0 = ml_program.global_load @var : i32
+  ml_program.return %0 : i32
+}
+
+// -----
+ml_program.func @store_undef(%arg0: i32) {
+  // expected-error @+1 {{undefined global: nothere}}
+  ml_program.global_store @nothere = %arg0 : i32
+  ml_program.return
+}
+
+// -----
+ml_program.global private mutable @var(42 : i64) : i64
+ml_program.func @store_type_mismatch(%arg0: i32) {
+  // expected-error @+1 {{cannot store to a global typed 'i64' from 'i32'}}
+  ml_program.global_store @var = %arg0 : i32
+  ml_program.return
+}
+
+// -----
+ml_program.global private @var(42 : i64) : i64
+ml_program.func @store_immutable(%arg0: i64) {
+  // expected-error @+1 {{cannot store to an immutable global var}}
+  ml_program.global_store @var = %arg0 : i64
+  ml_program.return
+}
diff --git a/mlir/test/Dialect/MLProgram/ops.mlir b/mlir/test/Dialect/MLProgram/ops.mlir
--- a/mlir/test/Dialect/MLProgram/ops.mlir
+++ b/mlir/test/Dialect/MLProgram/ops.mlir
@@ -27,3 +27,19 @@
 
 // CHECK: ml_program.global private mutable @global_extern(#extern) : tensor<?xi32>
 ml_program.global private mutable @global_extern(#ml_program.extern : tensor<4xi32>) : tensor<?xi32>
+
+// CHECK-LABEL: @global_load_const
+ml_program.func @global_load_const() -> tensor<4xi32> {
+  // CHECK: ml_program.global_load_const @global_same_type : tensor<4xi32>
+  %0 = ml_program.global_load_const @global_same_type : tensor<4xi32>
+  ml_program.return %0 : tensor<4xi32>
+}
+
+// CHECK-LABEL: @global_load_store
+ml_program.func @global_load_store() {
+  // CHECK: %[[V:.+]] = ml_program.global_load @global_mutable_undef : tensor<?xi32>
+  %0 = ml_program.global_load @global_mutable_undef : tensor<?xi32>
+  // CHECK: ml_program.global_store @global_mutable_undef = %[[V]] : tensor<?xi32>
+  ml_program.global_store @global_mutable_undef = %0 : tensor<?xi32>
+  ml_program.return
+}