diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -539,6 +539,10 @@
     Value getCurrentValue() {
       return body().front().getArgument(0);
     }
+    // The type of the `memref` operand, cast to MemRefType.
+    MemRefType getMemRefType() {
+      return memref().getType().cast<MemRefType>();
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
@@ -2746,6 +2747,104 @@
   }
 };
 
+/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
+/// retried until it succeeds in atomically storing a new value into memory.
+///
+///      +---------------------------------+
+///      |   code before the atomic op     |
+///      |   compute initial %loaded       |
+///      |   br loop(%loaded)              |
+///      +---------------------------------+
+///             |
+///  -------|   |
+///  |      v   v
+///  |   +--------------------------------+
+///  |   | loop(%loaded):                 |
+///  |   |   compute next value           |
+///  |   |   %pair = cmpxchg              |
+///  |   |   %new = %pair[0]              |
+///  |   |   %ok = %pair[1]               |
+///  |   |   cond_br %ok, end, loop(%new) |
+///  |   +--------------------------------+
+///  |          |        |
+///  |-----------        |
+///                      v
+///      +--------------------------------+
+///      | end:                           |
+///      |   code after the atomic op     |
+///      +--------------------------------+
+///
+struct GenericAtomicRMWOpLowering
+    : public LoadStoreOpLowering<GenericAtomicRMWOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto atomicOp = cast<GenericAtomicRMWOp>(op);
+
+    auto loc = op->getLoc();
+    OperandAdaptor<GenericAtomicRMWOp> adaptor(operands);
+    LLVM::LLVMType valueType =
+        typeConverter.convertType(atomicOp.getResult().getType())
+            .cast<LLVM::LLVMType>();
+
+    // Split the block into initial, loop, and ending parts.
+    auto *initBlock = rewriter.getInsertionBlock();
+    auto initPosition = rewriter.getInsertionPoint();
+    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
+    auto loopArgument = loopBlock->addArgument(valueType);
+    auto loopPosition = rewriter.getInsertionPoint();
+    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
+
+    // Compute the loaded value and branch to the loop block.
+    rewriter.setInsertionPointToEnd(initBlock);
+    auto memRefType = atomicOp.getMemRefType();
+    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
+                              adaptor.indices(), rewriter, getModule());
+    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
+    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
+
+    // Prepare the body of the loop block.
+    rewriter.setInsertionPointToStart(loopBlock);
+    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
+
+    // Clone the GenericAtomicRMWOp region and extract the result.
+    BlockAndValueMapping mapping;
+    mapping.map(atomicOp.getCurrentValue(), loopArgument);
+    Block &entryBlock = atomicOp.body().front();
+    for (auto &nestedOp : entryBlock.without_terminator()) {
+      Operation *clone = rewriter.clone(nestedOp, mapping);
+      mapping.map(nestedOp.getResults(), clone->getResults());
+    }
+    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));
+
+    // Prepare the epilog of the loop block.
+    rewriter.setInsertionPointToEnd(loopBlock);
+    // Append the cmpxchg op to the end of the loop block.
+    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
+    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
+    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
+    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
+        loc, pairType, dataPtr, loopArgument, result, successOrdering,
+        failureOrdering);
+    // Extract %new_loaded (pair[0]) and %ok (pair[1]) from the pair.
+    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
+        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
+    Value ok = rewriter.create<LLVM::ExtractValueOp>(
+        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
+
+    // Conditionally branch to the end or back to the loop depending on %ok.
+    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
+                                    loopBlock, newLoaded);
+
+    // The 'result' of the atomic_rmw op is the newly loaded value.
+    rewriter.replaceOp(op, {newLoaded});
+
+    return success();
+  }
+};
+
 } // namespace
 
 /// Collect a set of patterns to convert from the Standard dialect to LLVM.
@@ -2775,6 +2874,7 @@ DivFOpLowering, ExpOpLowering, Exp2OpLowering, + GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, Log2OpLowering, diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1029,6 +1029,30 @@ // ----- +// CHECK-LABEL: func @generic_atomic_rmw +// CHECK32-LABEL: func @generic_atomic_rmw +func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 { + %x = generic_atomic_rmw %I[%i] : memref<10xf32> { + ^bb0(%old_value : f32): + %c1 = constant 1.0 : f32 + atomic_yield %c1 : f32 + } + // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm<"float*"> + // CHECK-NEXT: llvm.br ^bb1([[init]] : !llvm.float) + // CHECK-NEXT: ^bb1([[loaded:%.*]]: !llvm.float): + // CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1.000000e+00 : f32) + // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]] + // CHECK-SAME: acq_rel monotonic : !llvm.float + // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0] + // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1] + // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float) + // CHECK-NEXT: ^bb2: + return %x : f32 + // CHECK-NEXT: llvm.return [[new]] +} + +// ----- + // CHECK-LABEL: func @assume_alignment func @assume_alignment(%0 : memref<4x4xf16>) { // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm<"{ half*, half*, i64, [2 x i64], [2 x i64] }">