diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -44,44 +44,21 @@ } // Replace the loop. + auto omp = rewriter.create(parallelOp.getLoc()); + Block *block = rewriter.createBlock(&omp.getRegion()); + rewriter.setInsertionPointToStart(block); auto loop = rewriter.create( parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(), parallelOp.step()); rewriter.inlineRegionBefore(parallelOp.region(), loop.region(), loop.region().begin()); + rewriter.create(parallelOp.getLoc()); + rewriter.eraseOp(parallelOp); return success(); } }; -/// Inserts OpenMP "parallel" operations around top-level SCF "parallel" -/// operations in the given function. This is implemented as a direct IR -/// modification rather than as a conversion pattern because it does not -/// modify the top-level operation it matches, which is a requirement for -/// rewrite patterns. -// -// TODO: consider creating nested parallel operations when necessary. -static void insertOpenMPParallel(FuncOp func) { - // Collect top-level SCF "parallel" ops. - SmallVector topLevelParallelOps; - func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) { - // Ignore ops that are already within OpenMP parallel construct. - if (!parallelOp->getParentOfType()) - topLevelParallelOps.push_back(parallelOp); - }); - - // Wrap SCF ops into OpenMP "parallel" ops. - for (scf::ParallelOp parallelOp : topLevelParallelOps) { - OpBuilder builder(parallelOp); - auto omp = builder.create(parallelOp.getLoc()); - Block *block = builder.createBlock(&omp.getRegion()); - builder.create(parallelOp.getLoc()); - block->getOperations().splice(block->begin(), - parallelOp->getBlock()->getOperations(), - parallelOp.getOperation()); - } -} - /// Applies the conversion patterns in the given function. static LogicalResult applyPatterns(FuncOp func) { ConversionTarget target(*func.getContext()); @@ -100,7 +77,6 @@ struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase { /// Pass entry point. void runOnFunction() override { - insertOpenMPParallel(getFunction()); if (failed(applyPatterns(getFunction()))) signalPassFailure(); } diff --git a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir --- a/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir +++ b/mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir @@ -21,8 +21,8 @@ %arg3: index, %arg4: index, %arg5: index) { // CHECK: omp.parallel { // CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) { - // CHECK-NOT: omp.parallel scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) { + // CHECK: omp.parallel // CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) { scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) { // CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()