diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -218,6 +218,72 @@
   let hasFolder = 1;
 }
 
+def AtomicRMWOp : Std_Op<"atomic_rmw"> {
+  let summary = "atomic read-modify-write operation";
+  let description = [{
+    The "atomic_rmw" operation provides a way to perform a read-modify-write
+    sequence that is free from data races. The memref operand represents the
+    buffer that the read and write are performed against, as accessed by the
+    specified indices. The number of indices must match the rank of the
+    memref. The result represents the latest value that was stored.
+
+    The "atomic_rmw" operation contains a region whose entry block expects one
+    argument of the same type as the element type of the memref; this argument
+    represents the value loaded from the memref. The body region may use this
+    argument to compute a modified value. The body must be terminated with the
+    "atomic_rmw_yield" op, which specifies the new value to be stored.
+
+    Example:
+
+    ```mlir
+    %cst = constant 1.0 : f32
+    %x = atomic_rmw %loaded = %I[%i] : memref<10xf32> {
+      %new_val = addf %loaded, %cst : f32
+      atomic_rmw_yield %new_val : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins AnyMemRef:$memref, Variadic<Index>:$indices);
+  let results = (outs AnyType:$res);
+  let regions = (region SizedRegion<1>:$body);
+
+  let extraClassDeclaration = [{
+    Block *getBody() { return &body().front(); }
+
+    Value getLoadedValue() { return getBody()->getArgument(0); }
+
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
+  }];
+}
+
+def AtomicRMWYieldOp :
+    Std_Op<"atomic_rmw_yield", [HasParent<"AtomicRMWOp">, Terminator]> {
+  let summary = "terminator for atomic_rmw operation";
+  let description = [{
+    "atomic_rmw_yield" is a special terminator operation for the block inside
+    "atomic_rmw" which terminates the region. Its operand must have the same
+    type as the entry block argument of the "atomic_rmw" body region.
+
+    Example:
+
+    ```mlir
+    %cst = constant 1.0 : f32
+    %x = atomic_rmw %loaded = %I[%i] : memref<10xf32> {
+      %new_val = addf %loaded, %cst : f32
+      atomic_rmw_yield %new_val : f32
+    }
+    ```
+  }];
+
+  let arguments = (ins AnyType:$result);
+
+  let assemblyFormat = "$result attr-dict `:` type($result)";
+}
+
 def BranchOp : Std_Op<"br", [Terminator]> {
   let summary = "branch operation";
   let description = [{
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -1161,7 +1161,8 @@
   }
 };
 
-template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+template <typename SourceOp, unsigned OpCount>
+struct OpCountValidator {
   static_assert(
       std::is_base_of<
           typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
@@ -1169,12 +1170,14 @@
       "wrong operand count");
 };
 
-template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+template <typename SourceOp>
+struct OpCountValidator<SourceOp, 1> {
   static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
                 "expected a single operand");
 };
 
-template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+template <typename SourceOp, unsigned OpCount>
+void ValidateOpCount() {
   OpCountValidator<SourceOp, OpCount>();
 }
 
@@ -1540,11 +1543,10 @@
     if (strides[index] == MemRefType::getDynamicStrideOrOffset())
       // Identity layout map is enforced in the match function, so we compute:
       // `runningStride *= sizes[index + 1]`
-      runningStride =
-          runningStride
-              ? rewriter.create<LLVM::MulOp>(loc, runningStride,
-                                             sizes[index + 1])
-              : createIndexConstant(rewriter, loc, 1);
+      runningStride = runningStride
+                          ? rewriter.create<LLVM::MulOp>(loc, runningStride,
+                                                         sizes[index + 1])
+                          : createIndexConstant(rewriter, loc, 1);
     else
       runningStride = createIndexConstant(rewriter, loc, strides[index]);
     strideValues[index] = runningStride;
@@ -2507,6 +2509,230 @@
   }
 };
 
+struct SimpleAtomicMatch {
+  LLVM::AtomicBinOp binOp;
+  Value val;
+};
+
 } // namespace
 
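+/// Match a binary op in the atomic_rmw body of the form `loaded <op> x` or
+/// `x <op> loaded`, where `x` is a value defined outside the body region, and
+/// map it onto the given llvm.atomicrmw kind.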
+template <typename OpType>
+static Optional<SimpleAtomicMatch>
+matchAtomicBinaryOp(AtomicRMWOp atomicOp, OpType op, LLVM::AtomicBinOp binOp) {
+  Value lhs = op.lhs();
+  Value rhs = op.rhs();
+  if (lhs == atomicOp.getLoadedValue() &&
+      rhs.getParentRegion() != &atomicOp.body()) {
+    return SimpleAtomicMatch{binOp, rhs};
+  }
+  if (rhs == atomicOp.getLoadedValue() &&
+      lhs.getParentRegion() != &atomicOp.body()) {
+    return SimpleAtomicMatch{binOp, lhs};
+  }
+  return llvm::None;
+}
+
+static Optional<SimpleAtomicMatch> matchAtomicCmpOp(AtomicRMWOp atomicOp,
+                                                    SelectOp selectOp,
+                                                    CmpIOp cmpOp,
+                                                    LLVM::AtomicBinOp binOp) {
+  auto loaded = atomicOp.getLoadedValue();
+  auto trueValue = selectOp.true_value();
+  auto falseValue = selectOp.false_value();
+  auto lhs = cmpOp.lhs();
+  auto rhs = cmpOp.rhs();
+  if (trueValue == loaded && trueValue == lhs && falseValue == rhs &&
+      rhs.getParentRegion() != &atomicOp.body()) {
+    return SimpleAtomicMatch{binOp, rhs};
+  }
+  return llvm::None;
+}
+
+static Optional<SimpleAtomicMatch> matchAtomicSelectOp(AtomicRMWOp atomicOp,
+                                                       SelectOp op) {
+  auto cmpOp = dyn_cast_or_null<CmpIOp>(op.condition().getDefiningOp());
+  if (!cmpOp)
+    return llvm::None;
+  auto predicate = cmpOp.getPredicate();
+  switch (predicate) {
+  case CmpIPredicate::sgt:
+    return matchAtomicCmpOp(atomicOp, op, cmpOp, LLVM::AtomicBinOp::max);
+  case CmpIPredicate::ugt:
+    return matchAtomicCmpOp(atomicOp, op, cmpOp, LLVM::AtomicBinOp::umax);
+  case CmpIPredicate::slt:
+    return matchAtomicCmpOp(atomicOp, op, cmpOp, LLVM::AtomicBinOp::min);
+  case CmpIPredicate::ult:
+    return matchAtomicCmpOp(atomicOp, op, cmpOp, LLVM::AtomicBinOp::umin);
+  default:
+    break;
+  }
+  return llvm::None;
+}
+
+static Optional<SimpleAtomicMatch> matchSimpleAtomicOp(AtomicRMWOp atomicOp) {
+  auto *body = atomicOp.getBody();
+  auto *terminator = body->getTerminator();
+  auto yieldOp = cast<AtomicRMWYieldOp>(terminator);
+  if (yieldOp.result().getParentRegion() != &atomicOp.body())
+    return SimpleAtomicMatch{LLVM::AtomicBinOp::xchg, yieldOp.result()};
+  // The yielded value may be the region argument itself, in which case it has
+  // no defining op.
+  auto defOp = yieldOp.result().getDefiningOp();
+  if (!defOp)
+    return llvm::None;
+  return TypeSwitch<Operation *, Optional<SimpleAtomicMatch>>(defOp)
+      .Case<AddIOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::add);
+      })
+      .Case<SubIOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::sub);
+      })
+      .Case<AndOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::_and);
+      })
+      // TODO: nand (~(*ptr & val)) is not matched since std has no 'not' op.
+      .Case<OrOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::_or);
+      })
+      .Case<XOrOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::_xor);
+      })
+      .Case<SelectOp>(
+          [&](auto op) { return matchAtomicSelectOp(atomicOp, op); })
+      .Case<AddFOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::fadd);
+      })
+      .Case<SubFOp>([&](auto op) {
+        return matchAtomicBinaryOp(atomicOp, op, LLVM::AtomicBinOp::fsub);
+      })
+      .Default([](Operation *op) { return llvm::None; });
+}
+
+namespace {
+
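+/// Lower an atomic_rmw whose body matches one of the native llvm.atomicrmw
+/// kinds (see matchSimpleAtomicOp above) to a single llvm.atomicrmw
+/// instruction, e.g. `%0 = addi %loaded, %ival` becomes `llvm.atomicrmw add`.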
+struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<AtomicRMWOp>(op);
+    auto simpleMatch = matchSimpleAtomicOp(atomicOp);
+    if (!simpleMatch) {
+      return matchFailure();
+    }
+    OperandAdaptor<AtomicRMWOp> adaptor(operands);
+    auto type = atomicOp.getMemRefType();
+    auto resultType = lowering.convertType(simpleMatch->val.getType());
+    auto dataPtr = getDataPtr(op->getLoc(), type, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
+        op, resultType, simpleMatch->binOp, dataPtr, simpleMatch->val,
+        LLVM::AtomicOrdering::acq_rel);
+    return matchSuccess();
+  }
+};
+
+/// Wrap an llvm.cmpxchg operation in a loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+///      +---------------------------------+
+///      |   <code before the atomic_rmw>  |
+///      |   <compute initial %loaded>     |
+///      |   br loop(%loaded)              |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | loop(%loaded):                 |
+///  |   |   <body contents>              |
+///  |   |   %pair = cmpxchg              |
+///  |   |   %new = %pair[0]              |
+///  |   |   %ok = %pair[1]               |
+///  |   |   cond_br %ok, end, loop(%new) |
+///  |   +--------------------------------+
+///  |          |        |
+///  |----------         |
+///                      v
+///      +--------------------------------+
+///      | end:                           |
+///      |   <code after the atomic_rmw>  |
+///      +--------------------------------+
+///
+struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
+  using Base::Base;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<AtomicRMWOp>(op);
+    auto simpleMatch = matchSimpleAtomicOp(atomicOp);
+    if (simpleMatch) {
+      return matchFailure();
+    }
+
+    // Split the block into initial and ending parts.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPosition = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
+
+    // The body of the atomic_rmw op will form the basis of the loop block.
+    auto &loopRegion = atomicOp.body();
+    auto *loopBlock = &loopRegion.front();
+    auto *terminator = loopBlock->getTerminator();
+    auto yieldOp = cast<AtomicRMWYieldOp>(terminator);
+    auto loc = op->getLoc();
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    OperandAdaptor<AtomicRMWOp> adaptor(operands);
+    auto type = atomicOp.getMemRefType();
+    auto dataPtr = getDataPtr(loc, type, adaptor.memref(), adaptor.indices(),
+                              rewriter, getModule());
+    auto init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    ArrayRef<Value> brProperOperands{};
+    std::array<Block *, 1> brDestinations{loopBlock};
+    std::array<Value, 1> brRegionOperands{init.res()};
+    std::array<ValueRange, 1> brOperands{brRegionOperands};
+    rewriter.create<LLVM::BrOp>(loc, brProperOperands, brDestinations,
+                                brOperands);
+
+    // Prepare the epilog of the loop block.
+    rewriter.setInsertionPointToEnd(loopBlock);
+    // Append the cmpxchg op to the end of the loop block.
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto loaded = atomicOp.getLoadedValue();
+    auto resultType =
+        lowering.convertType(loaded.getType()).cast<LLVM::LLVMType>();
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+    auto pairType = LLVM::LLVMType::getStructTy(resultType, boolType);
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loaded, yieldOp.result(), successOrdering,
+        failureOrdering);
+    // Extract the %new_loaded and %ok values from the pair.
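+    // The cmpxchg result is a {value, success} pair: position 0 holds the
+    // value read from memory and position 1 holds the success flag.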
+    auto newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, resultType, cmpxchg.res(), rewriter.getI64ArrayAttr({0}));
+    auto ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg.res(), rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    std::array<Value, 1> condBrProperOperands{ok.res()};
+    std::array<Block *, 2> condBrDestinations{endBlock, loopBlock};
+    std::array<Value, 1> condBrRegionOperands{newLoaded.res()};
+    std::array<ValueRange, 2> condBrOperands{ArrayRef<Value>{},
+                                             condBrRegionOperands};
+    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
+        terminator, condBrProperOperands, condBrDestinations, condBrOperands);
+
+    // Now move the body of the atomic_rmw op into the outer region just before
+    // the ending block.
+    rewriter.inlineRegionBefore(loopRegion, endBlock);
+
+    // The result of the atomic_rmw op is the value loaded by the final,
+    // successful cmpxchg.
+    rewriter.replaceOp(op, {newLoaded.res()});
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
@@ -2566,6 +2792,8 @@
       AddFOpLowering,
       AddIOpLowering,
       AndOpLowering,
+      AtomicCmpXchgOpLowering,
+      AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
       CallOpLowering,
diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -135,7 +135,8 @@
 }
 
 /// A custom cast operation verifier.
-template <typename T> static LogicalResult verifyCastOp(T op) {
+template <typename T>
+static LogicalResult verifyCastOp(T op) {
   auto opType = op.getOperand().getType();
   auto resType = op.getType();
   if (!T::areCastCompatible(opType, resType))
@@ -2951,6 +2952,75 @@
   return false;
 }
 
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, AtomicRMWOp op) {
+  p << op.getOperation()->getName() << ' ' << op.getLoadedValue();
+  p << " = " << op.memref() << '[' << op.indices() << ']';
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.memref().getType();
+  p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
+}
+
+// <operation> ::= `atomic_rmw` argument `=` ssa-use `[` ssa-use-list `]`
+//                 attribute-dict? `:` type region
+static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
+                                    OperationState &result) {
+  MemRefType type;
+  OpAsmParser::OperandType loaded;
+  OpAsmParser::OperandType memref;
+  auto indexTy = parser.getBuilder().getIndexType();
+  SmallVector<OpAsmParser::OperandType, 4> idxs;
+  Region *body = result.addRegion();
+  if (parser.parseRegionArgument(loaded) || parser.parseEqual() ||
+      parser.parseOperand(memref) ||
+      parser.parseOperandList(idxs, OpAsmParser::Delimiter::Square) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(type) ||
+      parser.resolveOperand(memref, type, result.operands) ||
+      parser.resolveOperands(idxs, indexTy, result.operands) ||
+      parser.parseRegion(*body, loaded, type.getElementType())) {
+    return failure();
+  }
+  result.addTypes(type.getElementType());
+  return success();
+}
+
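+/// Verify that the subscript count matches the memref rank and that the body
+/// is a single-argument block of the memref's element type, terminated by
+/// 'atomic_rmw_yield'.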
+static LogicalResult verify(AtomicRMWOp op) {
+  if (op.getMemRefType().getRank() != op.getNumOperands() - 1)
+    return op.emitOpError(
+        "expects the number of subscripts to be equal to memref rank");
+  auto block = op.getBody();
+  if (block->empty()) {
+    return op.emitOpError("expects a non-empty body");
+  }
+  auto elementType = op.getMemRefType().getElementType();
+  if (block->getNumArguments() != 1 ||
+      block->getArgument(0).getType() != elementType) {
+    return op.emitOpError()
+           << "expects a body with one argument of type " << elementType;
+  }
+  if (!llvm::isa<AtomicRMWYieldOp>(block->getTerminator())) {
+    return op.emitOpError(
+        "expects the body to be terminated with an 'atomic_rmw_yield' op");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWYieldOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(AtomicRMWYieldOp op) {
+  auto parentOp = op.getParentOfType<AtomicRMWOp>();
+  Type elementType = parentOp.getMemRefType().getElementType();
+  if (elementType != op.result().getType())
+    return op.emitOpError() << "needs to have type " << elementType;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -855,3 +855,95 @@
 // CHECK: llvm.func @tanhf(!llvm.float) -> !llvm.float
 // CHECK-LABEL: func @check_tanh_func_added_only_once_to_symbol_table
 }
+
+// -----
+
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) {
+  // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel
+  %x0 = atomic_rmw %loaded = %F[%i] : memref<10xf32> {
+    atomic_rmw_yield %fval : f32
+  }
+  // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel
+  %x1 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %0 = addi %loaded, %ival : i32
+    atomic_rmw_yield %0 : i32
+  }
+  // CHECK: llvm.atomicrmw sub %{{.*}}, %{{.*}} acq_rel
+  %x2 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %0 = subi %loaded, %ival : i32
+    atomic_rmw_yield %0 : i32
+  }
+  // CHECK: llvm.atomicrmw _and %{{.*}}, %{{.*}} acq_rel
+  %x3 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %0 = and %loaded, %ival : i32
+    atomic_rmw_yield %0 : i32
+  }
+  // CHECK: llvm.atomicrmw _or %{{.*}}, %{{.*}} acq_rel
+  %x4 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %0 = or %loaded, %ival : i32
+    atomic_rmw_yield %0 : i32
+  }
+  // CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} acq_rel
+  %x5 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %0 = xor %loaded, %ival : i32
+    atomic_rmw_yield %0 : i32
+  }
+  // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel
+  %x6 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %cmp = cmpi "sgt", %loaded, %ival : i32
+    %max = select %cmp, %loaded, %ival : i32
+    atomic_rmw_yield %max : i32
+  }
+  // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel
+  %x7 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %cmp = cmpi "slt", %loaded, %ival : i32
+    %min = select %cmp, %loaded, %ival : i32
+    atomic_rmw_yield %min : i32
+  }
+  // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel
+  %x8 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %cmp = cmpi "ugt", %loaded, %ival : i32
+    %max = select %cmp, %loaded, %ival : i32
+    atomic_rmw_yield %max : i32
+  }
+  // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel
+  %x9 = atomic_rmw %loaded = %I[%i] : memref<10xi32> {
+    %cmp = cmpi "ult", %loaded, %ival : i32
+    %min = select %cmp, %loaded, %ival : i32
+    atomic_rmw_yield %min : i32
+  }
+  // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel
+  %x10 = atomic_rmw %loaded = %F[%i] : memref<10xf32> {
+    %0 = addf %loaded, %fval : f32
+    atomic_rmw_yield %0 : f32
+  }
+  // CHECK: llvm.atomicrmw fsub %{{.*}}, %{{.*}} acq_rel
+  %x11 = atomic_rmw %loaded = %F[%i] : memref<10xf32> {
+    %0 = subf %loaded, %fval : f32
+    atomic_rmw_yield %0 : f32
+  }
+  return
+}
+
+// -----
+
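+// A body that does not match a native llvm.atomicrmw kind (e.g. a
+// floating-point max expressed as cmpf + select) falls back to the
+// compare-exchange loop lowering.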
+// CHECK-LABEL: func @cmpxchg
+func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) {
+  // CHECK: llvm.br ^bb1(%{{.*}} : !llvm.float)
+  %x = atomic_rmw %loaded = %F[%i] : memref<10xf32> {
+    %cmp = cmpf "ogt", %loaded, %fval : f32
+    %max = select %cmp, %loaded, %fval : f32
+    atomic_rmw_yield %max : f32
+    // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float):
+    // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float
+    // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float
+    // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[max]] acq_rel monotonic : !llvm.float
+    // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }">
+    // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }">
+    // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float)
+  }
+  // CHECK-NEXT: ^bb2:
+  // CHECK-NEXT: llvm.return
+  return
+}
diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir
--- a/mlir/test/IR/core-ops.mlir
+++ b/mlir/test/IR/core-ops.mlir
@@ -740,3 +740,14 @@
   tensor_store %1, %0 : memref<4x4xi32>
   return
 }
+
+// CHECK-LABEL: func @atomic_rmw
+func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
+  // CHECK: %{{.*}} = std.atomic_rmw %{{.*}} = %{{.*}}[%{{.*}}] : memref<10xf32>
+  %x = atomic_rmw %loaded = %I[%i] : memref<10xf32> {
+    %new_val = addf %loaded, %val : f32
+    // CHECK: atomic_rmw_yield %{{.*}} : f32
+    atomic_rmw_yield %new_val : f32
+  }
+  return
+}
diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir
--- a/mlir/test/IR/invalid-ops.mlir
+++ b/mlir/test/IR/invalid-ops.mlir
@@ -1036,3 +1036,71 @@
   %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0>
   return
 }
+
+// -----
+
+func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index) {
+  %cst = constant 1.0 : f32
+  // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}}
+  %x = atomic_rmw %val = %I[%i] : memref<16x10xf32> {
+    %0 = addf %val, %cst : f32
+    atomic_rmw_yield %0 : f32
+  }
+  return
+}
+
+// -----
+
+func @atomic_rmw_empty_body(%I: memref<16x10xf32>, %i : index, %j : index) {
+  %cst = constant 1.0 : f32
+  // expected-error@+1 {{expects a non-empty body}}
+  %x = atomic_rmw %val = %I[%i, %j] : memref<16x10xf32> {}
+  return
+}
+
+// -----
+
+func @atomic_rmw_region_arg_missing(%I: memref<16x10xf32>, %i : index, %j : index) {
+  %cst = constant 1.0 : f32
+  // expected-error@+1 {{expects a body with one argument of type 'f32'}}
+  %x = "std.atomic_rmw"(%I, %i, %j) ({
+    atomic_rmw_yield %cst : f32
+  }) : (memref<16x10xf32>, index, index) -> f32
+  return
+}
+
+// -----
+
+func @atomic_rmw_region_arg_type_mismatch(%I: memref<16x10xf32>, %i : index, %j : index) {
+  %cst = constant 1.0 : f32
+  // expected-error@+1 {{expects a body with one argument of type 'f32'}}
+  %x = "std.atomic_rmw"(%I, %i, %j) ({
+  ^bb0(%val: i32):
+    atomic_rmw_yield %cst : f32
+  }) : (memref<16x10xf32>, index, index) -> f32
+  return
+}
+
+// -----
+
+func @atomic_rmw_missing_yield(%I: memref<16x10xf32>, %i : index, %j : index) {
+  %cst = constant 1.0 : f32
+  // expected-error@+1 {{expects the body to be terminated with an 'atomic_rmw_yield' op}}
+  %x = atomic_rmw %val = %I[%i, %j] : memref<16x10xf32> {
+    %0 = addf %val, %cst : f32
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+
+// -----
+
+func @atomic_rmw_yield_type_mismatch(%I: memref<16x10xf32>, %i : index, %j : index) {
+  %c0 = constant 1 : i32
+  %cst = constant 1.0 : f32
+  %x = atomic_rmw %val = %I[%i, %j] : memref<16x10xf32> {
+    // expected-error@+1 {{needs to have type 'f32'}}
+    atomic_rmw_yield %c0 : i32
+  }
+  return
+}