diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -484,6 +484,74 @@
   }];
 }
 
+def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
+    SingleBlockImplicitTerminator<"AtomicYieldOp">,
+    TypesMatchWith<"result type matches element type of memref",
+                   "memref", "result",
+                   "$_self.cast<MemRefType>().getElementType()">
+  ]> {
+  let summary = "atomic read-modify-write operation with a region";
+  let description = [{
+    The `generic_atomic_rmw` operation provides a way to perform a
+    read-modify-write sequence that is free from data races. The memref operand
+    represents the buffer that the read and write will be performed against, as
+    accessed by the specified indices. The arity of the indices is the rank of
+    the memref. The result represents the latest value that was stored. The
+    region contains the code for the modification itself. Its entry block takes
+    a single argument holding the value currently stored at `memref[indices]`,
+    and the region must be terminated by an `atomic_yield` of the new value.
+
+    Example:
+
+    ```mlir
+      %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+        ^bb0(%current_value : f32):
+          %c1 = constant 1.0 : f32
+          %inc = addf %c1, %current_value : f32
+          atomic_yield %inc : f32
+      }
+    ```
+  }];
+
+  let arguments = (ins
+      MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
+      Variadic<Index>:$indices);
+
+  let results = (outs
+      AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);
+
+  let regions = (region AnyRegion:$body);
+
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "Value memref, ValueRange ivs">
+  ];
+
+  let extraClassDeclaration = [{
+    OpBuilder getBodyBuilder() {
+      assert(!body().empty() && "Unexpected empty 'body' region.");
+      Block &block = body().front();
+      return OpBuilder(&block, block.end());
+    }
+  }];
+}
+
+def AtomicYieldOp : Std_Op<"atomic_yield", [
+    HasParent<"GenericAtomicRMWOp">,
+    NoSideEffect,
+    Terminator
+  ]> {
+  let summary = "yield operation for GenericAtomicRMWOp";
+  let description = [{
+    "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region. The
+    yielded value must have the same type as the result of the parent op.
+  }];
+
+  let arguments = (ins AnyType:$result);
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -480,6 +480,105 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GenericAtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a generic_atomic_rmw on `memref[ivs]`. When `memref` has a memref
+/// type, the op gets a result of the element type and a body region holding a
+/// single block whose one argument has that element type. The block is left
+/// without a terminator; fill it through getBodyBuilder() and end it with an
+/// `atomic_yield`.
+void GenericAtomicRMWOp::build(Builder *builder, OperationState &result,
+                               Value memref, ValueRange ivs) {
+  result.addOperands(memref);
+  result.addOperands(ivs);
+
+  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
+    Type elementType = memrefType.getElementType();
+    result.addTypes(elementType);
+
+    Region *bodyRegion = result.addRegion();
+    bodyRegion->push_back(new Block());
+    bodyRegion->front().addArgument(elementType);
+  }
+}
+
+/// Checks that the body has an entry block with exactly one argument whose
+/// type matches the op's result type.
+static LogicalResult verify(GenericAtomicRMWOp op) {
+  // Guard front(): the region constraints do not rule out an empty body.
+  if (op.body().empty())
+    return op.emitOpError("expected non-empty body region");
+
+  auto &block = op.body().front();
+  if (block.getNumArguments() != 1)
+    return op.emitOpError("expected single number of entry block arguments");
+
+  if (op.getResult().getType() != block.getArgument(0).getType())
+    return op.emitOpError(
+        "expected block argument of the same type result type");
+  return success();
+}
+
+/// Parses `generic_atomic_rmw %memref[%i, ...] : memref-type { ... }` plus an
+/// optional trailing attribute dictionary, mirroring print() below.
+static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
+                                           OperationState &result) {
+  OpAsmParser::OperandType memref;
+  Type memrefType;
+  SmallVector<OpAsmParser::OperandType, 4> ivs;
+
+  Type indexType = parser.getBuilder().getIndexType();
+  if (parser.parseOperand(memref) ||
+      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
+      parser.parseColonType(memrefType))
+    return failure();
+
+  // Diagnose a non-memref type here rather than asserting when the result
+  // type is derived from its element type below.
+  auto memrefTy = memrefType.dyn_cast<MemRefType>();
+  if (!memrefTy)
+    return parser.emitError(parser.getNameLoc(), "expected memref type");
+
+  if (parser.resolveOperand(memref, memrefType, result.operands) ||
+      parser.resolveOperands(ivs, indexType, result.operands))
+    return failure();
+
+  // print() emits the attribute dictionary after the region; parse it at the
+  // same position so the printed form round-trips.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
+      parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  result.types.push_back(memrefTy.getElementType());
+  return success();
+}
+
+/// Prints the op as `generic_atomic_rmw %memref [%i, ...] : memref-type`,
+/// followed by the region and then any attributes.
+static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
+  p << op.getOperationName() << ' ' << op.memref() << " [" << op.indices()
+    << "] : " << op.memref().getType();
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicYieldOp
+//===----------------------------------------------------------------------===//
+
+/// Checks that the yielded value has the type of the parent op's result.
+static LogicalResult verify(AtomicYieldOp op) {
+  // HasParent<"GenericAtomicRMWOp"> makes the parent an op with one result.
+  Type parentType = op.getParentOp()->getResultTypes().front();
+  Type resultType = op.result().getType();
+  if (parentType != resultType)
+    return op.emitOpError() << "types mismatch between yield op: " << resultType
+                            << " and its parent: " << parentType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BranchOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -751,9 +751,23 @@
 }
 
 // CHECK-LABEL: func @atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index)
 func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
-  // CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}]
   %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32
+  // CHECK: atomic_rmw "addf" [[VAL]], [[BUF]]{{\[}}[[I]]]
+  return
+}
+
+// CHECK-LABEL: func @generic_atomic_rmw
+// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
+func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
+  %x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
+  // CHECK-NEXT: generic_atomic_rmw [[BUF]] {{\[}}[[I]], [[J]]] : memref
+    ^bb0(%old_value : f32):
+      %c1 = constant 1.0 : f32
+      %out = addf %c1, %old_value : f32
+      atomic_yield %out : f32
+  }
   return
 }
 
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1143,6 +1143,61 @@
 
 // -----
 
+func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected single number of entry block arguments}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected block argument of the same type result type}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : i32):
+      %c1 = constant 1.0 : f32
+      atomic_yield %c1 : f32
+  }
+  return
+}
+
+// -----
+
+func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
+  // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {
+    ^bb0(%old_value : f32):
+      %c1 = constant 1 : i32
+      atomic_yield %c1 : i32
+  }
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_empty_body(%I: memref<10xf32>, %i : index) {
+  // expected-error@+1 {{expected non-empty body region}}
+  %x = generic_atomic_rmw %I[%i] : memref<10xf32> {}
+  return
+}
+
+// -----
+
+func @generic_atomic_rmw_non_memref(%f: f32, %i : index) {
+  // expected-error@+1 {{expected memref type}}
+  %x = generic_atomic_rmw %f[%i] : f32 {
+    ^bb0(%old_value : f32):
+      atomic_yield %old_value : f32
+  }
+  return
+}
+
+// -----
+
 // alignment is not power of 2.
 func @assume_alignment(%0: memref<4x4xf16>) {
   // expected-error@+1 {{alignment must be power of 2}}