diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -45,21 +45,26 @@ // are split out into a separate continuation (exit) block. A condition block is // created before the continuation block. It checks the exit condition of the // loop and branches either to the continuation block, or to the first block of -// the body. Induction variable modification is appended to the last block of -// the body (which is the exit block from the body subgraph thanks to the +// the body. The condition block takes as arguments the values of the induction +// variable followed by loop-carried values. Since it dominates both the body +// blocks and the continuation block, loop-carried values are visible in all of +// those blocks. Induction variable modification is appended to the last block +// of the body (which is the exit block from the body subgraph thanks to the // invariant we maintain) along with a branch that loops back to the condition -// block. +// block. Loop-carried values are the loop terminator operands, which are +// forwarded to the branch. // // +---------------------------------+ // | <code before the ForOp> | +// | <definitions of %init...> | // | <compute initial %iv value> | -// | br cond(%iv) | +// | br cond(%iv, %init...) 
| // +---------------------------------+ // | // -------| | // | v v // | +--------------------------------+ -// | | cond(%iv): | +// | | cond(%iv, %init...): | // | | <compare %iv to upper bound> | // | | cond_br %r, body, end | // | +--------------------------------+ @@ -68,6 +73,7 @@ // | v | // | +--------------------------------+ | // | | body-first: | | +// | | <%init visible by dominance> | | // | | <body contents> | | // | +--------------------------------+ | // | | | @@ -76,15 +82,17 @@ // | +--------------------------------+ | // | | body-last: | | // | | <body contents> | | +// | | <operands of yield = %yields>| | // | | %new_iv =<add step to %iv> | | -// | | br cond(%new_iv) | | +// | | br cond(%new_iv, %yields) | | // | +--------------------------------+ | // | | | // |----------- |-------------------- // v // +--------------------------------+ // | end: | -// | <code after the ForOp> | +// | <code after the ForOp> | +// | <%init visible by dominance> | // +--------------------------------+ // struct ForLowering : public OpRewritePattern<ForOp> { @@ -133,7 +141,7 @@ // v v // +--------------------------------+ // | continue: | -// | <code after the IfOp> | +// | <code after the IfOp> | // +--------------------------------+ // struct IfLowering : public OpRewritePattern<IfOp> { @@ -162,10 +170,10 @@ auto initPosition = rewriter.getInsertionPoint(); auto *endBlock = rewriter.splitBlock(initBlock, initPosition); - // Use the first block of the loop body as the condition block since it is - // the block that has the induction variable as its argument. Split out - // all operations from the first block into a new block. Move all body - // blocks from the loop body region to the region containing the loop. + // Use the first block of the loop body as the condition block since it is the + // block that has the induction variable and loop-carried values as arguments. + // Split out all operations from the first block into a new block. Move all + // body blocks from the loop body region to the region containing the loop.
auto *conditionBlock = &forOp.region().front(); auto *firstBodyBlock = rewriter.splitBlock(conditionBlock, conditionBlock->begin()); @@ -174,15 +182,20 @@ auto iv = conditionBlock->getArgument(0); // Append the induction variable stepping logic to the last body block and - // branch back to the condition block. Construct an expression f : - // (x -> x+step) and apply this expression to the induction variable. - rewriter.eraseOp(lastBodyBlock->getTerminator()); + // branch back to the condition block. Loop-carried values are taken from + // operands of the loop terminator. + Operation *terminator = lastBodyBlock->getTerminator(); rewriter.setInsertionPointToEnd(lastBodyBlock); auto step = forOp.step(); auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult(); if (!stepped) return matchFailure(); - rewriter.create<BranchOp>(loc, conditionBlock, stepped); + + SmallVector<Value, 8> loopCarried; + loopCarried.push_back(stepped); + loopCarried.append(terminator->operand_begin(), terminator->operand_end()); + rewriter.create<BranchOp>(loc, conditionBlock, loopCarried); + rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. rewriter.setInsertionPointToEnd(initBlock); @@ -190,7 +203,14 @@ Value upperBound = forOp.upperBound(); if (!lowerBound || !upperBound) return matchFailure(); - rewriter.create<BranchOp>(loc, conditionBlock, lowerBound); + + // The initial values of loop-carried values are obtained from the operands + // of the loop operation. + SmallVector<Value, 8> destOperands; + destOperands.push_back(lowerBound); + auto iterOperands = forOp.getIterOperands(); + destOperands.append(iterOperands.begin(), iterOperands.end()); + rewriter.create<BranchOp>(loc, conditionBlock, destOperands); // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); @@ -199,8 +219,9 @@ rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock, ArrayRef<Value>()); - // Ok, we're done! 
- rewriter.eraseOp(forOp); + // The results of the loop operation are the values of the condition block + // arguments, except the induction variable, on the last iteration. + rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); return matchSuccess(); } diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir --- a/mlir/test/Conversion/convert-to-cfg.mlir +++ b/mlir/test/Conversion/convert-to-cfg.mlir @@ -180,3 +180,59 @@ } return } + +// CHECK-LABEL: @for_yield +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[INIT0:.*]] = constant 0 +// CHECK: %[[INIT1:.*]] = constant 1 +// CHECK: br ^[[COND:.*]](%[[LB]], %[[INIT0]], %[[INIT1]] : index, f32, f32) +// +// CHECK: ^[[COND]](%[[ITER:.*]]: index, %[[ITER_ARG0:.*]]: f32, %[[ITER_ARG1:.*]]: f32): +// CHECK: %[[CMP:.*]] = cmpi "slt", %[[ITER]], %[[UB]] : index +// CHECK: cond_br %[[CMP]], ^[[BODY:.*]], ^[[CONTINUE:.*]] +// +// CHECK: ^[[BODY]]: +// CHECK: %[[SUM:.*]] = addf %[[ITER_ARG0]], %[[ITER_ARG1]] : f32 +// CHECK: %[[STEPPED:.*]] = addi %[[ITER]], %[[STEP]] : index +// CHECK: br ^[[COND]](%[[STEPPED]], %[[SUM]], %[[SUM]] : index, f32, f32) +// +// CHECK: ^[[CONTINUE]]: +// CHECK: return %[[ITER_ARG0]], %[[ITER_ARG1]] : f32, f32 +func @for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> (f32, f32) { + %s0 = constant 0.0 : f32 + %s1 = constant 1.0 : f32 + %result:2 = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0, %sj = %s1) -> (f32, f32) { + %sn = addf %si, %sj : f32 + loop.yield %sn, %sn : f32, f32 + } + return %result#0, %result#1 : f32, f32 +} + +// CHECK-LABEL: @nested_for_yield +// CHECK-SAME: (%[[LB:.*]]: index, %[[UB:.*]]: index, %[[STEP:.*]]: index) +// CHECK: %[[INIT:.*]] = constant +// CHECK: br ^[[COND_OUT:.*]](%[[LB]], %[[INIT]] : index, f32) +// CHECK: ^[[COND_OUT]](%[[ITER_OUT:.*]]: index, %[[ARG_OUT:.*]]: f32): +// CHECK: cond_br %{{.*}}, ^[[BODY_OUT:.*]], ^[[CONT_OUT:.*]] +// CHECK: 
^[[BODY_OUT]]: +// CHECK: br ^[[COND_IN:.*]](%[[LB]], %[[ARG_OUT]] : index, f32) +// CHECK: ^[[COND_IN]](%[[ITER_IN:.*]]: index, %[[ARG_IN:.*]]: f32): +// CHECK: cond_br %{{.*}}, ^[[BODY_IN:.*]], ^[[CONT_IN:.*]] +// CHECK: ^[[BODY_IN]]: +// CHECK: %[[RES:.*]] = addf %[[ARG_IN]], %[[ARG_IN]] : f32 +// CHECK: br ^[[COND_IN]](%{{.*}}, %[[RES]] : index, f32) +// CHECK: ^[[CONT_IN]]: +// CHECK: br ^[[COND_OUT]](%{{.*}}, %[[ARG_IN]] : index, f32) +// CHECK: ^[[CONT_OUT]]: +// CHECK: return %[[ARG_OUT]] : f32 +func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 { + %s0 = constant 1.0 : f32 + %r = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%iter = %s0) -> (f32) { + %result = loop.for %i1 = %arg0 to %arg1 step %arg2 iter_args(%si = %iter) -> (f32) { + %sn = addf %si, %si : f32 + loop.yield %sn : f32 + } + loop.yield %result : f32 + } + return %r : f32 +}