// Review notes for this patch (preamble text placed before the `diff --git` header).
//
// Purpose: rework how the generic_atomic_rmw lowering builds its loop/end
// blocks so that operations following the atomic op end up in the end block.
//
// StandardToLLVM.cpp
//   * Hunk 1: the two `rewriter.splitBlock` calls are replaced by
//     `rewriter.createBlock`. `loopBlock` is created right after `initBlock`
//     with one block argument of `valueType`; `endBlock` is created right
//     after `loopBlock`. Two iterators are recorded before anything is
//     appended to `initBlock`: `opsToMoveStart` (the atomic op itself) and
//     `opsToMoveEnd` (`initBlock->back()`).
//   * Hunk 2: `loopArgument` is now read back with `loopBlock->getArgument(0)`
//     instead of being produced by `addArgument`; the early `boolType`
//     declaration is dropped.
//   * Hunk 3: the `setInsertionPointToEnd(loopBlock)` call is removed and
//     `boolType` is declared next to its first use (`pairType`).
//   * Hunk 4: after the conditional branch, the insertion point is set to the
//     end of `endBlock` and `MoveOpsRange` is called on the half-open range
//     [std::next(opsToMoveStart), std::next(opsToMoveEnd)), i.e. everything
//     after the atomic op up to and including the op recorded as
//     `initBlock->back()`. `MoveOpsRange` maps `oldResult` to `newResult`,
//     clones each op in the range through that mapping at the rewriter's
//     current insertion point, collects the originals, and erases them in a
//     second loop (so the range is not mutated while it is being walked).
//
// convert-to-llvmir.mlir
//   * The `CHECK32-LABEL` line for @generic_atomic_rmw is removed.
//   * The test now uses the atomic result in a `constant` + `addf` before the
//     return, and the CHECK-NEXT lines expect `llvm.mlir.constant`,
//     `llvm.fadd ... [[new]]` and `llvm.return [[add]]` after `^bb2:`.
//
// NOTE(review): this text is whitespace-collapsed (many diff lines are fused
//   onto one physical line); it presumably cannot be applied with `git apply`
//   as-is -- regenerate from the original commit before use.
// NOTE(review): `rewriter.create(`, `SmallVector opsToErase` and `ArrayRef()`
//   look like their template arguments were stripped during extraction --
//   TODO confirm against the original commit.
// NOTE(review): `MoveOpsRange` starts with an upper-case letter while every
//   other function name visible here is lowerCamelCase; consider
//   `moveOpsRange` -- verify against the project's naming rules.
// NOTE(review): removing the CHECK32-LABEL line presumably drops the 32-bit
//   prefix coverage for this function -- confirm that is intended.
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2791,11 +2791,15 @@ // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); - auto initPosition = rewriter.getInsertionPoint(); - auto *loopBlock = rewriter.splitBlock(initBlock, initPosition); - auto loopArgument = loopBlock->addArgument(valueType); - auto loopPosition = rewriter.getInsertionPoint(); - auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition); + auto *loopBlock = + rewriter.createBlock(initBlock->getParent(), + std::next(Region::iterator(initBlock)), valueType); + auto *endBlock = rewriter.createBlock( + loopBlock->getParent(), std::next(Region::iterator(loopBlock))); + + // Operations range to be moved to `endBlock`. + auto opsToMoveStart = atomicOp.getOperation()->getIterator(); + auto opsToMoveEnd = initBlock->back().getIterator(); // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); @@ -2807,9 +2811,9 @@ // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); - auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); // Clone the GenericAtomicRMWOp region and extract the result. + auto loopArgument = loopBlock->getArgument(0); BlockAndValueMapping mapping; mapping.map(atomicOp.getCurrentValue(), loopArgument); Block &entryBlock = atomicOp.body().front(); @@ -2820,10 +2824,10 @@ Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); // Prepare the epilog of the loop block. - rewriter.setInsertionPointToEnd(loopBlock); // Append the cmpxchg op to the end of the loop block. 
auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; + auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, @@ -2838,11 +2842,31 @@ rewriter.create(loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); + rewriter.setInsertionPointToEnd(endBlock); + MoveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), + std::next(opsToMoveEnd), rewriter); + // The 'result' of the atomic_rmw op is the newly loaded value. rewriter.replaceOp(op, {newLoaded}); return success(); } + +private: + // Clones a segment of ops [start, end) and erases the original. + void MoveOpsRange(ValueRange oldResult, ValueRange newResult, + Block::iterator start, Block::iterator end, + ConversionPatternRewriter &rewriter) const { + BlockAndValueMapping mapping; + mapping.map(oldResult, newResult); + SmallVector opsToErase; + for (auto it = start; it != end; ++it) { + rewriter.clone(*it, mapping); + opsToErase.push_back(&*it); + } + for (auto *it : opsToErase) + rewriter.eraseOp(it); + } }; } // namespace diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1030,7 +1030,6 @@ // ----- // CHECK-LABEL: func @generic_atomic_rmw -// CHECK32-LABEL: func @generic_atomic_rmw func @generic_atomic_rmw(%I : memref<10xf32>, %i : index) -> f32 { %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%old_value : f32): @@ -1047,8 +1046,12 @@ // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1] // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : !llvm.float) // CHECK-NEXT: ^bb2: - return %x : f32 - // CHECK-NEXT: llvm.return [[new]] + %c2 = constant 2.0 : f32 + 
%add = addf %c2, %x : f32 + return %add : f32 + // CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2.000000e+00 : f32) + // CHECK-NEXT: [[add:%.*]] = llvm.fadd [[c2]], [[new]] : !llvm.float + // CHECK-NEXT: llvm.return [[add]] } // -----