diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -580,15 +580,12 @@
 
     // Split the block into initial, loop, and ending parts.
     auto *initBlock = rewriter.getInsertionBlock();
-    auto *loopBlock = rewriter.createBlock(
-        initBlock->getParent(), std::next(Region::iterator(initBlock)),
-        valueType, loc);
-    auto *endBlock = rewriter.createBlock(
-        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));
+    auto *loopBlock = rewriter.splitBlock(initBlock, Block::iterator(atomicOp));
+    loopBlock->addArgument(valueType, loc);
 
-    // Operations range to be moved to `endBlock`.
-    auto opsToMoveStart = atomicOp->getIterator();
-    auto opsToMoveEnd = initBlock->back().getIterator();
+    // `atomicOp` itself becomes the first op of `endBlock` and is replaced by
+    // the newly loaded value below.
+    auto *endBlock = rewriter.splitBlock(loopBlock, Block::iterator(atomicOp));
 
     // Compute the loaded value and branch to the loop block.
     rewriter.setInsertionPointToEnd(initBlock);
@@ -628,30 +625,12 @@
                                     loopBlock, newLoaded);
 
     rewriter.setInsertionPointToEnd(endBlock);
-    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
-                 std::next(opsToMoveEnd), rewriter);
 
     // The 'result' of the atomic_rmw op is the newly loaded value.
     rewriter.replaceOp(atomicOp, {newLoaded});
 
     return success();
   }
-
-private:
-  // Clones a segment of ops [start, end) and erases the original.
-  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
-                    Block::iterator start, Block::iterator end,
-                    ConversionPatternRewriter &rewriter) const {
-    IRMapping mapping;
-    mapping.map(oldResult, newResult);
-    SmallVector opsToErase;
-    for (auto it = start; it != end; ++it) {
-      rewriter.clone(*it, mapping);
-      opsToErase.push_back(&*it);
-    }
-    for (auto *it : opsToErase)
-      rewriter.eraseOp(it);
-  }
 };
 
 /// Returns the LLVM type of the global variable given the memref type `type`.
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -362,16 +362,47 @@
     ^bb0(%old_value : i32):
       memref.atomic_yield %old_value : i32
   }
-  // CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
-  // CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
-  // CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
-  // CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
-  // CHECK-SAME: acq_rel monotonic : !llvm.ptr, i32
-  // CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
-  // CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
-  // CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
   llvm.return
 }
+// CHECK: %[[INIT:.*]] = llvm.load %{{.*}} : !llvm.ptr -> i32
+// CHECK-NEXT: llvm.br ^bb1(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb1(%[[LOADED:.*]]: i32):
+// CHECK-NEXT: %[[PAIR:.*]] = llvm.cmpxchg %{{.*}}, %[[LOADED]], %[[LOADED]]
+// CHECK-SAME: acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT: %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT: %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT: llvm.cond_br %[[OK]], ^bb2, ^bb1(%[[NEW]] : i32)
+
+// -----
+
+// CHECK-LABEL: func @generic_atomic_rmw_in_alloca_scope
+func.func @generic_atomic_rmw_in_alloca_scope(){
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() : memref<2x3xi32>
+  memref.alloca_scope {
+    %0 = memref.generic_atomic_rmw %alloc[%c1, %c1] : memref<2x3xi32> {
+    ^bb0(%arg0: i32):
+      memref.atomic_yield %arg0 : i32
+    }
+  }
+  return
+}
+// CHECK: %[[STACK_SAVE:.*]] = llvm.intr.stacksave : !llvm.ptr
+// CHECK-NEXT: llvm.br ^bb1
+// CHECK: ^bb1:
+// CHECK: %[[INIT:.*]] = llvm.load %[[BUF:.*]] : !llvm.ptr -> i32
+// CHECK-NEXT: llvm.br ^bb2(%[[INIT]] : i32)
+// CHECK-NEXT: ^bb2(%[[LOADED:.*]]: i32):
+// CHECK-NEXT: %[[PAIR:.*]] = llvm.cmpxchg %[[BUF]], %[[LOADED]], %[[LOADED]]
+// CHECK-SAME: acq_rel monotonic : !llvm.ptr, i32
+// CHECK-NEXT: %[[NEW:.*]] = llvm.extractvalue %[[PAIR]][0]
+// CHECK-NEXT: %[[OK:.*]] = llvm.extractvalue %[[PAIR]][1]
+// CHECK-NEXT: llvm.cond_br %[[OK]], ^bb3, ^bb2(%[[NEW]] : i32)
+// CHECK-NEXT: ^bb3:
+// CHECK-NEXT: llvm.intr.stackrestore %[[STACK_SAVE]] : !llvm.ptr
+// CHECK-NEXT: llvm.br ^bb4
+// CHECK-NEXT: ^bb4:
+// CHECK-NEXT: return
 
 // -----
 