diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -131,7 +131,8 @@
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
-              "Value lowerBound, Value upperBound, Value step">
+              "Value lowerBound, Value upperBound, Value step, "
+              "ValueRange iterArgs = llvm::None">
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
--- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -274,29 +274,77 @@
   Location loc = parallelOp.getLoc();
   BlockAndValueMapping mapping;
 
-  if (parallelOp.getNumResults() != 0) {
-    // TODO: Implement lowering of parallelOp with reductions.
-    return matchFailure();
-  }
-
   // For a parallel loop, we essentially need to create an n-dimensional loop
   // nest. We do this by translating to loop.for ops and have those lowered in
-  // a further rewrite.
+  // a further rewrite. If a parallel loop contains reductions (and thus returns
+  // values), forward the initial values for the reductions down the loop
+  // hierarchy and bubble up the results by modifying the "yield" terminator.
+  SmallVector<Value, 4> iterArgs;
+  auto range = parallelOp.initVals();
+  iterArgs.assign(range.begin(), range.end());
+  bool first = true;
+  SmallVector<Value, 4> loopResults(iterArgs);
   for (auto loop_operands :
        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
                  parallelOp.upperBound(), parallelOp.step())) {
     Value iv, lower, upper, step;
     std::tie(iv, lower, upper, step) = loop_operands;
-    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step);
+    ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
     mapping.map(iv, forOp.getInductionVar());
+    auto iterRange = forOp.getRegionIterArgs();
+    iterArgs.assign(iterRange.begin(), iterRange.end());
+
+    if (first) {
+      // Store the results of the outermost loop that will be used to replace
+      // the results of the parallel loop when it is fully rewritten.
+      loopResults.assign(forOp.result_begin(), forOp.result_end());
+    } else {
+      // A loop is constructed with an empty "yield" terminator by default.
+      // Replace it with another "yield" that forwards the results of the
+      // nested loop to the parent loop. We need to explicitly make sure the
+      // new terminator is the last operation in the block because further
+      // transforms rely on this.
+      rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+      rewriter.replaceOpWithNewOp<YieldOp>(
+          rewriter.getInsertionBlock()->getTerminator(), forOp.getResults());
+    }
+    first = false;
+
     rewriter.setInsertionPointToStart(forOp.getBody());
   }
 
   // Now copy over the contents of the body.
-  for (auto &op : parallelOp.getBody()->without_terminator())
-    rewriter.clone(op, mapping);
+  SmallVector<Value, 4> yieldOperands;
+  yieldOperands.reserve(parallelOp.getNumResults());
+  for (auto &op : parallelOp.getBody()->without_terminator()) {
+    // Reduction blocks are handled differently.
+    auto reduce = dyn_cast<ReduceOp>(op);
+    if (!reduce) {
+      rewriter.clone(op, mapping);
+      continue;
+    }
+
+    // Clone the body of the reduction operation into the body of the loop,
+    // using the operands of "loop.reduce" and the iteration arguments
+    // corresponding to the reduction value to replace the arguments of the
+    // reduction block. Collect the operands of "loop.reduce.return" to be
+    // returned by a final "loop.yield" instead.
+    Value arg = iterArgs[yieldOperands.size()];
+    Block &reduceBlock = reduce.reductionOperator().front();
+    mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
+    mapping.map(reduceBlock.getArgument(1),
+                mapping.lookupOrDefault(reduce.operand()));
+    for (auto &nested : reduceBlock.without_terminator())
+      rewriter.clone(nested, mapping);
+    yieldOperands.push_back(
+        mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
+  }
+
+  rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
+  rewriter.replaceOpWithNewOp<YieldOp>(
+      rewriter.getInsertionBlock()->getTerminator(), yieldOperands);
 
-  rewriter.eraseOp(parallelOp);
+  rewriter.replaceOp(parallelOp, loopResults);
 
   return matchSuccess();
 }
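(Illustration, not part of the patch: after this parallel-to-for rewrite, and
before the loop.for ops are themselves lowered to a CFG, the one-dimensional
reduction from the tests below takes roughly the following intermediate form.
This sketch assumes the "iter_args" textual syntax of loop.for with results;
the value names are invented for readability.)

  %0 = loop.for %i = %lb to %ub step %step
      iter_args(%acc = %init) -> (f32) {
    %cst = constant 4.200000e+01 : f32
    // Inlined body of "loop.reduce": its first block argument is replaced by
    // the iteration argument %acc, its second by the reduced operand %cst.
    %1 = mulf %acc, %cst : f32
    loop.yield %1 : f32
  }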
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -61,11 +61,16 @@
 //===----------------------------------------------------------------------===//
 
 void ForOp::build(Builder *builder, OperationState &result, Value lb, Value ub,
-                  Value step) {
+                  Value step, ValueRange iterArgs) {
   result.addOperands({lb, ub, step});
+  result.addOperands(iterArgs);
+  for (Value v : iterArgs)
+    result.addTypes(v.getType());
   Region *bodyRegion = result.addRegion();
   ForOp::ensureTerminator(*bodyRegion, *builder, result.location);
   bodyRegion->front().addArgument(builder->getIndexType());
+  for (Value v : iterArgs)
+    bodyRegion->front().addArgument(v.getType());
 }
 
 static LogicalResult verify(ForOp op) {
diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir
--- a/mlir/test/Conversion/convert-to-cfg.mlir
+++ b/mlir/test/Conversion/convert-to-cfg.mlir
@@ -236,3 +236,88 @@
   }
   return %r : f32
 }
+
+func @generate() -> i64
+
+// CHECK-LABEL: @simple_parallel_reduce_loop
+// CHECK-SAME: %[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index, %[[INIT:.*]]: f32
+func @simple_parallel_reduce_loop(%arg0: index, %arg1: index,
+                                  %arg2: index, %arg3: f32) -> f32 {
+  // A parallel loop with a reduction is converted through sequential loops
+  // with reductions into a CFG of blocks where the partially reduced value is
+  // passed across as a block argument.
+
+  // Branch to the condition block, passing in the initial reduction value.
+  // CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT]]
+
+  // The condition branch takes as arguments the current value of the
+  // iteration variable and the current partially reduced value.
+  // CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG:.*]]: f32
+  // CHECK: %[[COMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]]
+  // CHECK: cond_br %[[COMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]]
+
+  // The bodies of loop.reduce operations are folded into the main loop body.
+  // The result of this partial reduction is passed as an argument to the
+  // condition block.
+  // CHECK: ^[[BODY]]:
+  // CHECK: %[[CST:.*]] = constant 4.2
+  // CHECK: %[[PROD:.*]] = mulf %[[ITER_ARG]], %[[CST]]
+  // CHECK: %[[INCR:.*]] = addi %[[ITER]], %[[STEP]]
+  // CHECK: br ^[[COND]](%[[INCR]], %[[PROD]]
+
+  // The continuation block has access to the last value of the reduction.
+  // CHECK: ^[[CONTINUE]]:
+  // CHECK: return %[[ITER_ARG]]
+  %0 = loop.parallel (%i) = (%arg0) to (%arg1) step (%arg2) init(%arg3) {
+    %cst = constant 42.0 : f32
+    loop.reduce(%cst) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = mulf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+  } : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: parallel_reduce_loop
+// CHECK-SAME: %[[INIT1:[0-9A-Za-z_]*]]: f32)
+func @parallel_reduce_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                           %arg3 : index, %arg4 : index, %arg5 : f32)
+                           -> (f32, i64) {
+  // Multiple reduction blocks should be folded into the same body, and the
+  // reduction values must be forwarded through the block structure.
+  // CHECK: %[[INIT2:.*]] = constant 42
+  // CHECK: br ^[[COND_OUT:.*]](%{{.*}}, %[[INIT1]], %[[INIT2]]
+  // CHECK: ^[[COND_OUT]](%{{.*}}: index, %[[ITER_ARG1_OUT:.*]]: f32, %[[ITER_ARG2_OUT:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]]
+  // CHECK: ^[[BODY_OUT]]:
+  // CHECK: br ^[[COND_IN:.*]](%{{.*}}, %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  // CHECK: ^[[COND_IN]](%{{.*}}: index, %[[ITER_ARG1_IN:.*]]: f32, %[[ITER_ARG2_IN:.*]]: i64
+  // CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]]
+  // CHECK: ^[[BODY_IN]]:
+  // CHECK: %[[REDUCE1:.*]] = addf %[[ITER_ARG1_IN]], %{{.*}}
+  // CHECK: %[[REDUCE2:.*]] = or %[[ITER_ARG2_IN]], %{{.*}}
+  // CHECK: br ^[[COND_IN]](%{{.*}}, %[[REDUCE1]], %[[REDUCE2]]
+  // CHECK: ^[[CONT_IN]]:
+  // CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ITER_ARG1_IN]], %[[ITER_ARG2_IN]]
+  // CHECK: ^[[CONT_OUT]]:
+  // CHECK: return %[[ITER_ARG1_OUT]], %[[ITER_ARG2_OUT]]
+  %step = constant 1 : index
+  %init = constant 42 : i64
+  %0:2 = loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                       step (%arg4, %step) init(%arg5, %init) {
+    %cf = constant 42.0 : f32
+    loop.reduce(%cf) {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %1 = addf %lhs, %rhs : f32
+      loop.reduce.return %1 : f32
+    } : f32
+
+    %2 = call @generate() : () -> i64
+    loop.reduce(%2) {
+    ^bb0(%lhs: i64, %rhs: i64):
+      %3 = or %lhs, %rhs : i64
+      loop.reduce.return %3 : i64
+    } : i64
+  } : f32, i64
+  return %0#0, %0#1 : f32, i64
+}
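(Illustration, not part of the patch: the two-dimensional test above passes
through an intermediate nest of loop.for ops before the CFG conversion. The
sketch below again assumes the "iter_args" syntax of loop.for and uses
invented value names; it shows how the rewritten "yield" of the outer loop
forwards the results of the inner loop, which is what the
replaceOpWithNewOp<YieldOp> call in the conversion produces.)

  %res:2 = loop.for %i0 = %lb0 to %ub0 step %s0
      iter_args(%acc0 = %init0, %acc1 = %init1) -> (f32, i64) {
    %inner:2 = loop.for %i1 = %lb1 to %ub1 step %s1
        iter_args(%acc2 = %acc0, %acc3 = %acc1) -> (f32, i64) {
      %cf = constant 4.200000e+01 : f32
      %0 = addf %acc2, %cf : f32
      %1 = call @generate() : () -> i64
      %2 = or %acc3, %1 : i64
      // Partial reductions are yielded back to the inner loop header.
      loop.yield %0, %2 : f32, i64
    }
    // Rewritten terminator: forward the inner loop's results outwards.
    loop.yield %inner#0, %inner#1 : f32, i64
  }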