diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -218,6 +218,69 @@ let hasFolder = 1; } +def ATOMIC_RMW_KIND_ADDF : I64EnumAttrCase<"addf", 0>; +def ATOMIC_RMW_KIND_ADDI : I64EnumAttrCase<"addi", 1>; +def ATOMIC_RMW_KIND_ASSIGN : I64EnumAttrCase<"assign", 2>; +def ATOMIC_RMW_KIND_MAXF : I64EnumAttrCase<"maxf", 3>; +def ATOMIC_RMW_KIND_MAXS : I64EnumAttrCase<"maxs", 4>; +def ATOMIC_RMW_KIND_MAXU : I64EnumAttrCase<"maxu", 5>; +def ATOMIC_RMW_KIND_MINF : I64EnumAttrCase<"minf", 6>; +def ATOMIC_RMW_KIND_MINS : I64EnumAttrCase<"mins", 7>; +def ATOMIC_RMW_KIND_MINU : I64EnumAttrCase<"minu", 8>; +def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>; +def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>; + +def AtomicRMWKindAttr : I64EnumAttr< + "AtomicRMWKind", "", + [ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN, + ATOMIC_RMW_KIND_MAXF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU, + ATOMIC_RMW_KIND_MINF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU, + ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI]> { + let cppNamespace = "::mlir"; +} + +def AtomicRMWOp : Std_Op<"atomic_rmw", [ + AllTypesMatch<["value", "result"]>, + TypesMatchWith<"value type matches element type of memref", + "memref", "value", + "$_self.cast().getElementType()"> + ]> { + let summary = "atomic read-modify-write operation"; + let description = [{ + The "atomic_rmw" operation provides a way to perform a read-modify-write + sequence that is free from data races. The kind enumeration specifies the + modification to perform. The value operand represents the new value to be + applied during the modification. The memref operand represents the buffer + that the read and write will be performed against, as accessed by the + specified indices. The arity of the indices is the rank of the memref. 
The + result is the value held in memory immediately before the modification. + + Example: + + ```mlir + %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 + ``` + }]; + + let arguments = (ins + AtomicRMWKindAttr:$kind, + AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, + MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, + Variadic:$indices); + let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); + + let assemblyFormat = [{ + $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,` + type($memref) `)` `->` type($result) + }]; + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref().getType().cast(); + } + }]; +} + def BranchOp : Std_Op<"br", [Terminator]> { let summary = "branch operation"; let description = [{ diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1143,7 +1143,8 @@ } }; -template struct OpCountValidator { +template +struct OpCountValidator { static_assert( std::is_base_of< typename OpTrait::NOperands::template Impl, @@ -1151,12 +1152,14 @@ "wrong operand count"); }; -template struct OpCountValidator { +template +struct OpCountValidator { static_assert(std::is_base_of, SourceOp>::value, "expected a single operand"); }; -template void ValidateOpCount() { +template +void ValidateOpCount() { OpCountValidator(); } @@ -1524,11 +1527,10 @@ if (strides[index] == MemRefType::getDynamicStrideOrOffset()) // Identity layout map is enforced in the match function, so we compute: // `runningStride *= sizes[index + 1]` - runningStride = - runningStride - ? rewriter.create(loc, runningStride, - sizes[index + 1]) - : createIndexConstant(rewriter, loc, 1); + runningStride = runningStride + ? 
rewriter.create(loc, runningStride, + sizes[index + 1]) + : createIndexConstant(rewriter, loc, 1); else runningStride = createIndexConstant(rewriter, loc, strides[index]); strideValues[index] = runningStride; @@ -2537,6 +2539,170 @@ } // namespace +/// Map the kind of a std.atomic_rmw to the equivalent llvm.atomicrmw binop. +/// Returns llvm::None if there is none; such kinds need an llvm.cmpxchg loop. +static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { + switch (atomicOp.kind()) { + case AtomicRMWKind::addf: + return LLVM::AtomicBinOp::fadd; + case AtomicRMWKind::addi: + return LLVM::AtomicBinOp::add; + case AtomicRMWKind::assign: + return LLVM::AtomicBinOp::xchg; + case AtomicRMWKind::maxs: + return LLVM::AtomicBinOp::max; + case AtomicRMWKind::maxu: + return LLVM::AtomicBinOp::umax; + case AtomicRMWKind::mins: + return LLVM::AtomicBinOp::min; + case AtomicRMWKind::minu: + return LLVM::AtomicBinOp::umin; + default: + return llvm::None; + } + llvm_unreachable("Invalid AtomicRMWKind"); +} + +namespace { + +struct AtomicRMWOpLowering : public LoadStoreOpLowering { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto atomicOp = cast(op); + auto maybeKind = matchSimpleAtomicOp(atomicOp); + if (!maybeKind) + return matchFailure(); + OperandAdaptor adaptor(operands); + auto resultType = adaptor.value().getType(); + auto memRefType = atomicOp.getMemRefType(); + auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), + adaptor.indices(), rewriter, getModule()); + rewriter.replaceOpWithNewOp( + op, resultType, *maybeKind, dataPtr, adaptor.value(), + LLVM::AtomicOrdering::acq_rel); + return matchSuccess(); + } +}; + +/// Wrap an llvm.cmpxchg operation in a while loop so that the operation can be +/// retried until it succeeds in atomically storing a new value into memory.
+/// +/// +---------------------------------+ +/// | (code before the atomic_rmw) | +/// | %loaded = load (initial value) | +/// | br loop(%loaded) | +/// +---------------------------------+ +/// | +/// -------| | +/// | v v +/// | +--------------------------------+ +/// | | loop(%loaded): | +/// | | %v = select(fcmp, %loaded, ..)| +/// | | %pair = cmpxchg | +/// | | %new = %pair[0] | +/// | | %ok = %pair[1] | +/// | | cond_br %ok, end, loop(%new) | +/// | +--------------------------------+ +/// | | | +/// |----------- | +/// v +/// +--------------------------------+ +/// | end: | +/// | (code after the atomic_rmw) | +/// +--------------------------------+ +/// +struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering { + using Base::Base; + + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto atomicOp = cast(op); + auto maybeKind = matchSimpleAtomicOp(atomicOp); + if (maybeKind) + return matchFailure(); + + LLVM::FCmpPredicate predicate; + switch (atomicOp.kind()) { + case AtomicRMWKind::maxf: + predicate = LLVM::FCmpPredicate::ogt; + break; + case AtomicRMWKind::minf: + predicate = LLVM::FCmpPredicate::olt; + break; + default: + return matchFailure(); + } + + OperandAdaptor adaptor(operands); + auto loc = op->getLoc(); + auto valueType = adaptor.value().getType().cast(); + + // Split the block at the op into initial, loop, and ending parts. + auto *initBlock = rewriter.getInsertionBlock(); + auto initPosition = rewriter.getInsertionPoint(); + auto *loopBlock = rewriter.splitBlock(initBlock, initPosition); + auto loopArgument = loopBlock->addArgument(valueType); + auto loopPosition = rewriter.getInsertionPoint(); + auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition); + + // Load the initial expected value and branch to the loop block with it. 
+ rewriter.setInsertionPointToEnd(initBlock); + auto memRefType = atomicOp.getMemRefType(); + auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter, getModule()); + auto init = rewriter.create(loc, dataPtr); + std::array brRegionOperands{init}; + std::array brOperands{brRegionOperands}; + rewriter.create(loc, ArrayRef{}, loopBlock, brOperands); + + // Loop body: new value = fcmp(predicate, loaded, value) ? loaded : value. + rewriter.setInsertionPointToStart(loopBlock); + auto predicateI64 = + rewriter.getI64IntegerAttr(static_cast(predicate)); + auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); + auto lhs = loopArgument; + auto rhs = adaptor.value(); + auto cmp = + rewriter.create(loc, boolType, predicateI64, lhs, rhs); + auto select = rewriter.create(loc, cmp, lhs, rhs); + + // Prepare the epilog of the loop block. + rewriter.setInsertionPointToEnd(loopBlock); + // Append the cmpxchg op: expects the loaded value, stores the selected one. + auto successOrdering = LLVM::AtomicOrdering::acq_rel; + auto failureOrdering = LLVM::AtomicOrdering::monotonic; + auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); + auto cmpxchg = rewriter.create( + loc, pairType, dataPtr, loopArgument, select, successOrdering, + failureOrdering); + // Extract the value read from memory ([0]) and the success flag ([1]). + auto newLoaded = rewriter.create( + loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); + auto ok = rewriter.create( + loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); + + // On success exit to the end block; otherwise retry with the value just read. + std::array condBrProperOperands{ok}; + std::array condBrDestinations{endBlock, loopBlock}; + std::array condBrRegionOperands{newLoaded}; + std::array condBrOperands{ArrayRef{}, + condBrRegionOperands}; + rewriter.create(loc, condBrProperOperands, + condBrDestinations, condBrOperands); + + // The op's result is the value the successful cmpxchg read (pre-update value). 
+ rewriter.replaceOp(op, {newLoaded}); + + return matchSuccess(); + } +}; + +} // namespace + static void ensureDistinctSuccessors(Block &bb) { auto *terminator = bb.getTerminator(); @@ -2594,6 +2760,8 @@ AddFOpLowering, AddIOpLowering, AndOpLowering, + AtomicCmpXchgOpLowering, + AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -135,7 +135,8 @@ } /// A custom cast operation verifier. -template static LogicalResult verifyCastOp(T op) { +template +static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand().getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) @@ -2615,6 +2616,41 @@ } //===----------------------------------------------------------------------===// +// AtomicRMWOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AtomicRMWOp op) { + if (op.getMemRefType().getRank() != op.getNumOperands() - 2) // Operands: $value, $memref, then the indices. + return op.emitOpError( + "expects the number of subscripts to be equal to memref rank"); + switch (op.kind()) { + case AtomicRMWKind::addf: + case AtomicRMWKind::maxf: + case AtomicRMWKind::minf: + case AtomicRMWKind::mulf: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "' expects a floating-point type"; + break; + case AtomicRMWKind::addi: + case AtomicRMWKind::maxs: + case AtomicRMWKind::maxu: + case AtomicRMWKind::mins: + case AtomicRMWKind::minu: + case AtomicRMWKind::muli: + if (!op.value().getType().isa()) + return op.emitOpError() + << "with kind '" << stringifyAtomicRMWKind(op.kind()) + << "' expects an integer type"; + break; + default: // Only 'assign' remains; no extra type restriction is applied. + break; + } + return success(); +} + 
+//===----------------------------------------------------------------------===// // 
TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -858,6 +858,65 @@ // ----- +// CHECK-LABEL: func @atomic_rmw +func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fval : f32, %i : index) { + atomic_rmw "assign" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw xchg %{{.*}}, %{{.*}} acq_rel + atomic_rmw "addi" %ival, %I[%i] : (i32, memref<10xi32>) -> i32 + // CHECK: llvm.atomicrmw add %{{.*}}, %{{.*}} acq_rel + atomic_rmw "maxs" %ival, %I[%i] : (i32, memref<10xi32>) -> i32 + // CHECK: llvm.atomicrmw max %{{.*}}, %{{.*}} acq_rel + atomic_rmw "mins" %ival, %I[%i] : (i32, memref<10xi32>) -> i32 + // CHECK: llvm.atomicrmw min %{{.*}}, %{{.*}} acq_rel + atomic_rmw "maxu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32 + // CHECK: llvm.atomicrmw umax %{{.*}}, %{{.*}} acq_rel + atomic_rmw "minu" %ival, %I[%i] : (i32, memref<10xi32>) -> i32 + // CHECK: llvm.atomicrmw umin %{{.*}}, %{{.*}} acq_rel + atomic_rmw "addf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} acq_rel + return +} + +// ----- + +// CHECK-LABEL: func @cmpxchg +func @cmpxchg(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 { + %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*"> + // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float) + // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float): + // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "ogt" %[[loaded]], %{{.*}} : !llvm.float + // CHECK-NEXT: %[[max:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float + // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, 
%[[loaded]], %[[max]] acq_rel monotonic : !llvm.float + // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }"> + // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }"> + // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float) + // CHECK-NEXT: ^bb2: + return %x : f32 + // CHECK-NEXT: llvm.return %[[new]] +} + +// ----- + +// CHECK-LABEL: func @cmpxchg_min +func @cmpxchg_min(%F : memref<10xf32>, %fval : f32, %i : index) -> f32 { + %x = atomic_rmw "minf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32 + // CHECK: %[[init:.*]] = llvm.load %{{.*}} : !llvm<"float*"> + // CHECK-NEXT: llvm.br ^bb1(%[[init]] : !llvm.float) + // CHECK-NEXT: ^bb1(%[[loaded:.*]]: !llvm.float): + // CHECK-NEXT: %[[cmp:.*]] = llvm.fcmp "olt" %[[loaded]], %{{.*}} : !llvm.float + // CHECK-NEXT: %[[min:.*]] = llvm.select %[[cmp]], %[[loaded]], %{{.*}} : !llvm.i1, !llvm.float + // CHECK-NEXT: %[[pair:.*]] = llvm.cmpxchg %{{.*}}, %[[loaded]], %[[min]] acq_rel monotonic : !llvm.float + // CHECK-NEXT: %[[new:.*]] = llvm.extractvalue %[[pair]][0] : !llvm<"{ float, i1 }"> + // CHECK-NEXT: %[[ok:.*]] = llvm.extractvalue %[[pair]][1] : !llvm<"{ float, i1 }"> + // CHECK-NEXT: llvm.cond_br %[[ok]], ^bb2, ^bb1(%[[new]] : !llvm.float) + // CHECK-NEXT: ^bb2: + return %x : f32 + // CHECK-NEXT: llvm.return %[[new]] +} + +// ----- + // CHECK-LABEL: func @assume_alignment func @assume_alignment(%0 : memref<4x4xf16>) { // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }"> diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -741,6 +741,13 @@ return } +// CHECK-LABEL: func @atomic_rmw +func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) { + // CHECK: %{{.*}} = atomic_rmw "addf" %{{.*}}, %{{.*}}[%{{.*}}] + %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<10xf32>) -> f32 + return +} + // CHECK-LABEL: func @assume_alignment // CHECK-SAME: %[[MEMREF:.*]]: 
memref<4x4xf16> func @assume_alignment(%0: memref<4x4xf16>) { diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1039,6 +1039,30 @@ // ----- +func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) { + // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}} + %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<16x10xf32>) -> f32 + return +} + +// ----- + +func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) { + // expected-error@+1 {{expects a floating-point type}} + %x = atomic_rmw "addf" %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32 + return +} + +// ----- + +func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) { + // expected-error@+1 {{expects an integer type}} + %x = atomic_rmw "addi" %val, %I[%i, 
%i] : (f32, memref<16x10xf32>) -> f32 + return +} + +// ----- + // alignment is not power of 2. func @assume_alignment(%0: memref<4x4xf16>) { // expected-error@+1 {{alignment must be power of 2}}