diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp --- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp +++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp @@ -404,7 +404,6 @@ ParallelLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); - BlockAndValueMapping mapping; // For a parallel loop, we essentially need to create an n-dimensional loop // nest. We do this by translating to scf.for ops and have those lowered in @@ -412,6 +411,8 @@ // values), forward the initial values for the reductions down the loop // hierarchy and bubble up the results by modifying the "yield" terminator. SmallVector iterArgs = llvm::to_vector<4>(parallelOp.initVals()); + SmallVector ivs; + ivs.reserve(parallelOp.getNumLoops()); bool first = true; SmallVector loopResults(iterArgs); for (auto loop_operands : @@ -420,7 +421,7 @@ Value iv, lower, upper, step; std::tie(iv, lower, upper, step) = loop_operands; ForOp forOp = rewriter.create(loc, lower, upper, step, iterArgs); - mapping.map(iv, forOp.getInductionVar()); + ivs.push_back(forOp.getInductionVar()); auto iterRange = forOp.getRegionIterArgs(); iterArgs.assign(iterRange.begin(), iterRange.end()); @@ -439,33 +440,33 @@ rewriter.setInsertionPointToStart(forOp.getBody()); } - // Now copy over the contents of the body. + // First, merge reduction blocks into the main region. SmallVector yieldOperands; yieldOperands.reserve(parallelOp.getNumResults()); - for (auto &op : parallelOp.getBody()->without_terminator()) { - // Reduction blocks are handled differently. + for (auto &op : *parallelOp.getBody()) { auto reduce = dyn_cast(op); - if (!reduce) { - rewriter.clone(op, mapping); + if (!reduce) continue; - } - // Clone the body of the reduction operation into the body of the loop, - // using operands of "scf.reduce" and iteration arguments corresponding - // to the reduction value to replace arguments of the reduction block. - // Collect operands of "scf.reduce.return" to be returned by a final - // "scf.yield" instead. - Value arg = iterArgs[yieldOperands.size()]; Block &reduceBlock = reduce.reductionOperator().front(); - mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg)); - mapping.map(reduceBlock.getArgument(1), - mapping.lookupOrDefault(reduce.operand())); - for (auto &nested : reduceBlock.without_terminator()) - rewriter.clone(nested, mapping); - yieldOperands.push_back( - mapping.lookup(reduceBlock.getTerminator()->getOperand(0))); + Value arg = iterArgs[yieldOperands.size()]; + yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); + rewriter.eraseOp(reduceBlock.getTerminator()); + rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.operand()}); + rewriter.eraseOp(reduce); } + // Then merge the loop body without the terminator. + rewriter.eraseOp(parallelOp.getBody()->getTerminator()); + Block *newBody = rewriter.getInsertionBlock(); + if (newBody->empty()) + rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); + else + rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(), + ivs); + + // Finally, create the terminator if required (for loops with no results, it + // has been already created in loop construction). if (!yieldOperands.empty()) { rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); rewriter.create(loc, yieldOperands); diff --git a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir --- a/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToStandard/convert-to-cfg.mlir @@ -546,3 +546,44 @@ return %0 : i64 } +// CHECK-LABEL: @ifs_in_parallel +// CHECK: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: i1, %[[ARG4:.*]]: i1) +func @ifs_in_parallel(%arg1: index, %arg2: index, %arg3: index, %arg4: i1, %arg5: i1) { + // CHECK: br ^[[LOOP_LATCH:.*]](%[[ARG0]] : index) + // CHECK: ^[[LOOP_LATCH]](%[[LOOP_IV:.*]]: index): + // CHECK: %[[LOOP_COND:.*]] = cmpi "slt", %[[LOOP_IV]], %[[ARG1]] : index + // CHECK: cond_br %[[LOOP_COND]], ^[[LOOP_BODY:.*]], ^[[LOOP_CONT:.*]] + // CHECK: ^[[LOOP_BODY]]: + // CHECK: cond_br %[[ARG3]], ^[[IF1_THEN:.*]], ^[[IF1_CONT:.*]] + // CHECK: ^[[IF1_THEN]]: + // CHECK: cond_br %[[ARG4]], ^[[IF2_THEN:.*]], ^[[IF2_ELSE:.*]] + // CHECK: ^[[IF2_THEN]]: + // CHECK: %{{.*}} = "test.if2"() : () -> index + // CHECK: br ^[[IF2_MERGE:.*]](%{{.*}} : index) + // CHECK: ^[[IF2_ELSE]]: + // CHECK: %{{.*}} = "test.else2"() : () -> index + // CHECK: br ^[[IF2_MERGE]](%{{.*}} : index) + // CHECK: ^[[IF2_MERGE]](%{{.*}}: index): + // CHECK: br ^[[IF2_CONT:.*]] + // CHECK: ^[[IF2_CONT]]: + // CHECK: br ^[[IF1_CONT]] + // CHECK: ^[[IF1_CONT]]: + // CHECK: %{{.*}} = addi %[[LOOP_IV]], %[[ARG2]] : index + // CHECK: br ^[[LOOP_LATCH]](%{{.*}} : index) + scf.parallel (%i) = (%arg1) to (%arg2) step (%arg3) { + scf.if %arg4 { + %0 = scf.if %arg5 -> (index) { + %1 = "test.if2"() : () -> index + scf.yield %1 : index + } else { + %2 = "test.else2"() : () -> index + scf.yield %2 : index + } + } + scf.yield + } + + // CHECK: ^[[LOOP_CONT]]: + // CHECK: return + return +}