NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed and whose angle-bracket spans (C++ template
arguments and the `<...>` placeholders in the deleted diagram) had been
stripped. Line counts of every hunk were re-checked against its @@ header.
The restored template arguments, diagram placeholders and exact whitespace are
inferred from context -- confirm them against the pre-image before applying.

Summary of the change: `atomic_rmw` ops that AtomicRMWOpLowering cannot handle
("minf"/"maxf") are no longer expanded into a hand-built llvm.cmpxchg loop.
Instead a std-to-std pattern rewrites them into `generic_atomic_rmw` with a
cmpf/select body, and that pattern is registered from
populateStdToLLVMConversionPatterns.

diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -2640,109 +2640,54 @@
   }
 };
 
-/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
-/// retried until it succeeds in atomically storing a new value into memory.
+/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
+/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
+/// `generic_atomic_rmw` with the expanded code.
 ///
-///      +---------------------------------+
-///      |   <code before the AtomicRMWOp> |
-///      |   <compute initial %loaded>     |
-///      |   br loop(%loaded)              |
-///      +---------------------------------+
-///             |
-///  -------|   |
-///  |      v   v
-///  |   +--------------------------------+
-///  |   | loop(%loaded):                 |
-///  |   |   <compare & swap>             |
-///  |   |   %pair = cmpxchg              |
-///  |   |   %ok = %pair[0]               |
-///  |   |   %new = %pair[1]              |
-///  |   |   cond_br %ok, end, loop(%new) |
-///  |   +--------------------------------+
-///  |          |        |
-///  |-----------        |
-///                      v
-///      +--------------------------------+
-///      | end:                           |
-///      |   <code after the AtomicRMWOp> |
-///      +--------------------------------+
+/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
 ///
-struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
-  using Base::Base;
+/// will be lowered to
+///
+/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
+/// ^bb0(%current: f32):
+///   %cmp = cmpf "ogt", %current, %fval : f32
+///   %new_value = select %cmp, %current, %fval : f32
+///   atomic_yield %new_value : f32
+/// }
+struct AtomicToGenericAtomicRMWOp : public OpConversionPattern<AtomicRMWOp> {
+public:
+  using OpConversionPattern<AtomicRMWOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto atomicOp = cast<AtomicRMWOp>(op);
-    auto maybeKind = matchSimpleAtomicOp(atomicOp);
-    if (maybeKind)
+  matchAndRewrite(AtomicRMWOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    if (matchSimpleAtomicOp(op))
       return failure();
 
-    LLVM::FCmpPredicate predicate;
-    switch (atomicOp.kind()) {
+    CmpFPredicate predicate;
+    switch (op.kind()) {
     case AtomicRMWKind::maxf:
-      predicate = LLVM::FCmpPredicate::ogt;
+      predicate = CmpFPredicate::OGT;
       break;
     case AtomicRMWKind::minf:
-      predicate = LLVM::FCmpPredicate::olt;
+      predicate = CmpFPredicate::OLT;
       break;
     default:
       return failure();
     }
 
-    OperandAdaptor<AtomicRMWOp> adaptor(operands);
-    auto loc = op->getLoc();
-    auto valueType = adaptor.value().getType().cast<LLVM::LLVMType>();
+    auto loc = op.getLoc();
+    auto genericOp =
+        rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
+    OpBuilder bodyBuilder = genericOp.getBodyBuilder();
 
-    // Split the block into initial, loop, and ending parts.
-    auto *initBlock = rewriter.getInsertionBlock();
-    auto initPosition = rewriter.getInsertionPoint();
-    auto *loopBlock = rewriter.splitBlock(initBlock, initPosition);
-    auto loopArgument = loopBlock->addArgument(valueType);
-    auto loopPosition = rewriter.getInsertionPoint();
-    auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition);
-
-    // Compute the loaded value and branch to the loop block.
-    rewriter.setInsertionPointToEnd(initBlock);
-    auto memRefType = atomicOp.getMemRefType();
-    auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
-                              adaptor.indices(), rewriter, getModule());
-    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
-    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
-
-    // Prepare the body of the loop block.
-    rewriter.setInsertionPointToStart(loopBlock);
-    auto predicateI64 =
-        rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate));
-    auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect());
-    auto lhs = loopArgument;
-    auto rhs = adaptor.value();
-    auto cmp =
-        rewriter.create<LLVM::FCmpOp>(loc, boolType, predicateI64, lhs, rhs);
-    auto select = rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
-
-    // Prepare the epilog of the loop block.
-    rewriter.setInsertionPointToEnd(loopBlock);
-    // Append the cmpxchg op to the end of the loop block.
-    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
-    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
-    auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType);
-    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
-        loc, pairType, dataPtr, loopArgument, select, successOrdering,
-        failureOrdering);
-    // Extract the %new_loaded and %ok values from the pair.
-    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
-        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
-    Value ok = rewriter.create<LLVM::ExtractValueOp>(
-        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));
-
-    // Conditionally branch to the end or back to the loop depending on %ok.
-    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
-                                    loopBlock, newLoaded);
-
-    // The 'result' of the atomic_rmw op is the newly loaded value.
-    rewriter.replaceOp(op, {newLoaded});
+    Value lhs = genericOp.getCurrentValue();
+    Value rhs = op.value();
+    Value cmp = bodyBuilder.create<CmpFOp>(loc, predicate, lhs, rhs);
+    Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
+    bodyBuilder.create<AtomicYieldOp>(loc, select);
 
+    rewriter.replaceOp(op, genericOp.getResult());
     return success();
   }
 };
@@ -2858,7 +2803,6 @@
       AddIOpLowering,
       AllocaOpLowering,
       AndOpLowering,
-      AtomicCmpXchgOpLowering,
       AtomicRMWOpLowering,
       BranchOpLowering,
       CallIndirectOpLowering,
@@ -2929,6 +2873,12 @@
   // clang-format on
 }
 
+static void
+populateStdToStdConversionPatterns(LLVMTypeConverter &converter,
+                                   OwningRewritePatternList &patterns) {
+  patterns.insert<AtomicToGenericAtomicRMWOp>(&converter.getContext());
+}
+
 void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     bool emitCWrappers) {
@@ -2938,6 +2888,7 @@
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     bool emitCWrappers, bool useAlignedAlloc) {
+  populateStdToStdConversionPatterns(converter, patterns);
   populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns,
                                                   emitCWrappers);
   populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);