diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -335,16 +335,6 @@ LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, PatternRewriter &rewriter) const override { - // Replace SCF yield with OpenMP yield. - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToEnd(parallelOp.getBody()); - assert(llvm::hasSingleElement(parallelOp.getRegion()) && - "expected scf.parallel to have one block"); - rewriter.replaceOpWithNewOp( - parallelOp.getBody()->getTerminator(), ValueRange()); - } - // Declare reductions. // TODO: consider checking it here is already a compatible reduction // declaration and use it instead of redeclaring. @@ -394,22 +384,31 @@ OpBuilder::InsertionGuard guard(rewriter); rewriter.createBlock(&ompParallel.region()); + // Replace the loop. { - auto scope = rewriter.create(parallelOp.getLoc(), - TypeRange()); - rewriter.create(loc); OpBuilder::InsertionGuard allocaGuard(rewriter); - rewriter.createBlock(&scope.getBodyRegion()); - rewriter.setInsertionPointToStart(&scope.getBodyRegion().front()); - - // Replace the loop. auto loop = rewriter.create( parallelOp.getLoc(), parallelOp.getLowerBound(), parallelOp.getUpperBound(), parallelOp.getStep()); - rewriter.create(loc); + rewriter.create(loc); rewriter.inlineRegionBefore(parallelOp.getRegion(), loop.region(), loop.region().begin()); + + Block *ops = rewriter.splitBlock(&*loop.region().begin(), + loop.region().begin()->begin()); + + rewriter.setInsertionPointToStart(&*loop.region().begin()); + + auto scope = rewriter.create(parallelOp.getLoc(), + TypeRange()); + rewriter.create(loc, ValueRange()); + Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); + rewriter.mergeBlocks(ops, scopeBlock); + auto oldYield = cast(scopeBlock->getTerminator()); + rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); + rewriter.replaceOpWithNewOp( + oldYield, oldYield->getOperands()); if (!reductionVariables.empty()) { loop.reductionsAttr( ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); diff --git a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir --- a/mlir/test/Conversion/SCFToOpenMP/reductions.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/reductions.mlir @@ -26,9 +26,9 @@ %step = arith.constant 1 : index %zero = arith.constant 0.0 : f32 // CHECK: omp.parallel - // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF]] -> %[[BUF]] + // CHECK: memref.alloca_scope scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero) -> (f32) { // CHECK: %[[CST_INNER:.*]] = arith.constant 1.0 @@ -161,10 +161,10 @@ // CHECK: llvm.store %[[IONE]], %[[BUF2]] // CHECK: omp.parallel - // CHECK: memref.alloca_scope // CHECK: omp.wsloop // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]] // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]] + // CHECK: memref.alloca_scope %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %step) init (%zero, %ione) -> (f32, i64) { %one = arith.constant 1.0 : f32 diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -4,8 +4,8 @@ func @parallel(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { - // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR1:.*]], %[[LVAR2:.*]]) : index = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { + // CHECK: memref.alloca_scope scf.parallel (%i, %j) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) { // CHECK: "test.payload"(%[[LVAR1]], %[[LVAR2]]) : (index, index) -> () "test.payload"(%i, %j) : (index, index) -> () @@ -21,12 +21,12 @@ func @nested_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { - // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { + // CHECK: memref.alloca_scope scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: omp.parallel - // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { + // CHECK: memref.alloca_scope scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> () "test.payload"(%i, %j) : (index, index) -> () @@ -44,8 +44,8 @@ func @adjacent_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { - // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { + // CHECK: memref.alloca_scope scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { // CHECK: "test.payload1"(%[[LVAR_AL1]]) : (index) -> () "test.payload1"(%i) : (index) -> () @@ -56,8 +56,8 @@ // CHECK: } // CHECK: omp.parallel { - // CHECK: memref.alloca_scope // CHECK: omp.wsloop (%[[LVAR_AL2:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { + // CHECK: memref.alloca_scope scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload2"(%[[LVAR_AL2]]) : (index) -> () "test.payload2"(%j) : (index) -> ()