diff --git a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/LoopToStandard.cpp @@ -112,13 +112,21 @@ // blocks are respectively the first/last block of the enclosing region. The // operations following the loop.if are split into a continuation (subgraph // exit) block. The condition is lowered to a chain of blocks that implement the -// short-circuit scheme. Condition blocks are created by splitting out an empty -// block from the block that contains the loop.if operation. They -// conditionally branch to either the first block of the "then" region, or to -// the first block of the "else" region. If the latter is absent, they branch -// to the continuation block instead. The last blocks of "then" and "else" -// regions (which are known to be exit blocks thanks to the invariant we -// maintain). +// short-circuit scheme. The "loop.if" operation is replaced with a conditional +// branch to either the first block of the "then" region, or to the first block +// of the "else" region. In these, "loop.yield" becomes an unconditional branch +// to the post-dominating block. When the "loop.if" does not return values, the +// post-dominating block is the same as the continuation block. When it returns +// values, the post-dominating block is a new block with arguments that +// correspond to the values returned by the "loop.if" that unconditionally +// branches to the continuation block. This allows block arguments to dominate +// any uses of the hitherto "loop.if" results that they replaced. (Inserting a +// new block allows us to avoid modifying the argument list of an existing +// block, which is illegal in a conversion pattern). When the "else" region is +// empty, which is only allowed for "loop.if"s that don't return values, the +// condition branches directly to the continuation block. 
+// +// CFG for a loop.if with else and without results. // // +--------------------------------+ // | | @@ -148,6 +156,42 @@ // | | // +--------------------------------+ // +// CFG for a loop.if with results. +// +// +--------------------------------+ +// | | +// | cond_br %cond, %then, %else | +// +--------------------------------+ +// | | +// | --------------| +// v | +// +--------------------------------+ | +// | then: | | +// | | | +// | br dom(%args...) | | +// +--------------------------------+ | +// | | +// |---------- |------------- +// | V +// | +--------------------------------+ +// | | else: | +// | | | +// | | br dom(%args...) | +// | +--------------------------------+ +// | | +// ------| | +// v v +// +--------------------------------+ +// | dom(%args...): | +// | br continue | +// +--------------------------------+ +// | +// v +// +--------------------------------+ +// | continue: | +// | | +// +--------------------------------+ +// struct IfLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -238,15 +282,25 @@ // continuation point. auto *condBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); - auto *continueBlock = rewriter.splitBlock(condBlock, opPosition); + auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); + Block *continueBlock; + if (ifOp.getNumResults() == 0) { + continueBlock = remainingOpsBlock; + } else { + continueBlock = + rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes()); + rewriter.create(loc, remainingOpsBlock); + } // Move blocks from the "then" region to the region containing 'loop.if', // place it before the continuation block, and branch to it. 
auto &thenRegion = ifOp.thenRegion(); auto *thenBlock = &thenRegion.front(); - rewriter.eraseOp(thenRegion.back().getTerminator()); + Operation *thenTerminator = thenRegion.back().getTerminator(); + ValueRange thenTerminatorOperands = thenTerminator->getOperands(); rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, continueBlock); + rewriter.create(loc, continueBlock, thenTerminatorOperands); + rewriter.eraseOp(thenTerminator); rewriter.inlineRegionBefore(thenRegion, continueBlock); // Move blocks from the "else" region (if present) to the region containing @@ -256,9 +310,11 @@ auto &elseRegion = ifOp.elseRegion(); if (!elseRegion.empty()) { elseBlock = &elseRegion.front(); - rewriter.eraseOp(elseRegion.back().getTerminator()); + Operation *elseTerminator = elseRegion.back().getTerminator(); + ValueRange elseTerminatorOperands = elseTerminator->getOperands(); rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, continueBlock); + rewriter.create(loc, continueBlock, elseTerminatorOperands); + rewriter.eraseOp(elseTerminator); rewriter.inlineRegionBefore(elseRegion, continueBlock); } @@ -268,7 +324,7 @@ /*falseArgs=*/ArrayRef()); // Ok, we're done! 
- rewriter.eraseOp(ifOp); + rewriter.replaceOp(ifOp, continueBlock->getArguments()); return success(); } diff --git a/mlir/test/Conversion/convert-to-cfg.mlir b/mlir/test/Conversion/convert-to-cfg.mlir --- a/mlir/test/Conversion/convert-to-cfg.mlir +++ b/mlir/test/Conversion/convert-to-cfg.mlir @@ -148,6 +148,83 @@ return } +// CHECK-LABEL: func @simple_if_yield +func @simple_if_yield(%arg0: i1) -> (i1, i1) { +// CHECK: cond_br %{{.*}}, ^[[then:.*]], ^[[else:.*]] + %0:2 = loop.if %arg0 -> (i1, i1) { +// CHECK: ^[[then]]: +// CHECK: %[[v0:.*]] = constant 0 +// CHECK: %[[v1:.*]] = constant 1 +// CHECK: br ^[[dom:.*]](%[[v0]], %[[v1]] : i1, i1) + %c0 = constant 0 : i1 + %c1 = constant 1 : i1 + loop.yield %c0, %c1 : i1, i1 + } else { +// CHECK: ^[[else]]: +// CHECK: %[[v2:.*]] = constant 0 +// CHECK: %[[v3:.*]] = constant 1 +// CHECK: br ^[[dom]](%[[v3]], %[[v2]] : i1, i1) + %c0 = constant 0 : i1 + %c1 = constant 1 : i1 + loop.yield %c1, %c0 : i1, i1 + } +// CHECK: ^[[dom]](%[[arg1:.*]]: i1, %[[arg2:.*]]: i1): +// CHECK: br ^[[cont:.*]] +// CHECK: ^[[cont]]: +// CHECK: return %[[arg1]], %[[arg2]] + return %0#0, %0#1 : i1, i1 +} + +// CHECK-LABEL: func @nested_if_yield +func @nested_if_yield(%arg0: i1) -> (index) { +// CHECK: cond_br %{{.*}}, ^[[first_then:.*]], ^[[first_else:.*]] + %0 = loop.if %arg0 -> i1 { +// CHECK: ^[[first_then]]: + %1 = constant 1 : i1 +// CHECK: br ^[[first_dom:.*]]({{.*}}) + loop.yield %1 : i1 + } else { +// CHECK: ^[[first_else]]: + %2 = constant 0 : i1 +// CHECK: br ^[[first_dom]]({{.*}}) + loop.yield %2 : i1 + } +// CHECK: ^[[first_dom]](%[[arg1:.*]]: i1): +// CHECK: br ^[[first_cont:.*]] +// CHECK: ^[[first_cont]]: +// CHECK: cond_br %[[arg1]], ^[[second_outer_then:.*]], ^[[second_outer_else:.*]] + %1 = loop.if %0 -> index { +// CHECK: ^[[second_outer_then]]: +// CHECK: cond_br %arg0, ^[[second_inner_then:.*]], ^[[second_inner_else:.*]] + %3 = loop.if %arg0 -> index { +// CHECK: ^[[second_inner_then]]: + %4 = constant 40 : index +// CHECK: 
br ^[[second_inner_dom:.*]]({{.*}}) + loop.yield %4 : index + } else { +// CHECK: ^[[second_inner_else]]: + %5 = constant 41 : index +// CHECK: br ^[[second_inner_dom]]({{.*}}) + loop.yield %5 : index + } +// CHECK: ^[[second_inner_dom]](%[[arg2:.*]]: index): +// CHECK: br ^[[second_inner_cont:.*]] +// CHECK: ^[[second_inner_cont]]: +// CHECK: br ^[[second_outer_dom:.*]]({{.*}}) + loop.yield %3 : index + } else { +// CHECK: ^[[second_outer_else]]: + %6 = constant 42 : index +// CHECK: br ^[[second_outer_dom]]({{.*}}) + loop.yield %6 : index + } +// CHECK: ^[[second_outer_dom]](%[[arg3:.*]]: index): +// CHECK: br ^[[second_outer_cont:.*]] +// CHECK: ^[[second_outer_cont]]: +// CHECK: return %[[arg3]] : index + return %1 : index +} + // CHECK-LABEL: func @parallel_loop( // CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index) { // CHECK: [[VAL_5:%.*]] = constant 1 : index