Index: mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
===================================================================
--- mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
+++ mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp
@@ -114,7 +114,9 @@
 // the first block of the "else" region. If the latter is absent, they branch
 // to the continuation block instead. The last blocks of "then" and "else"
 // regions (which are known to be exit blocks thanks to the invariant we
-// maintain).
+// maintain). The operands of the "then" and "else" region terminators
+// (loop.yield) provide the results of the loop.if operation; they are
+// forwarded to the continuation block as block arguments.
 //
 //      +--------------------------------+
 //      | <code before the IfOp>         |
@@ -126,7 +128,7 @@
 //      +--------------------------------+  |
 //      | then:                          |  |
 //      |   <then contents>              |  |
-//      |   br continue                  |  |
+//      |   br continue(%yields...)      |  |
 //      +--------------------------------+  |
 //             |                            |
 //   |----------               |-------------
@@ -134,13 +136,13 @@
 //   |        +--------------------------------+
 //   |        | else:                          |
 //   |        |   <else contents>              |
-//   |        |   br continue                  |
+//   |        |   br continue(%yields...)      |
 //   |        +--------------------------------+
 //   |               |
 //   ------|         |
 //         v         v
 //      +--------------------------------+
-//      | continue:                      |
+//      | continue(%yields...):          |
 //      |   <code after the IfOp>        |
 //      +--------------------------------+
 //
@@ -236,25 +238,39 @@
   auto opPosition = rewriter.getInsertionPoint();
   auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
 
+  // Replace each loop.if result by a new argument of the continuation block.
+  for (auto result : ifOp.results())
+    result.replaceAllUsesWith(continueBlock->addArgument(result.getType()));
+
   // Move blocks from the "then" region to the region containing 'loop.if',
-  // place it before the continuation block, and branch to it.
+  // place it before the continuation block, and branch to it. The values
+  // yielded by its last block's terminator become the branch operands.
   auto &thenRegion = ifOp.thenRegion();
   auto *thenBlock = &thenRegion.front();
-  rewriter.eraseOp(thenRegion.back().getTerminator());
+  Operation *thenTerminator = thenRegion.back().getTerminator();
+  // Copy the yielded values: the terminator is erased right below.
+  SmallVector<Value, 4> thenResults(thenTerminator->operand_begin(),
+                                    thenTerminator->operand_end());
+  rewriter.eraseOp(thenTerminator);
   rewriter.setInsertionPointToEnd(&thenRegion.back());
-  rewriter.create<BranchOp>(loc, continueBlock);
+  rewriter.create<BranchOp>(loc, continueBlock, thenResults);
   rewriter.inlineRegionBefore(thenRegion, continueBlock);
 
   // Move blocks from the "else" region (if present) to the region containing
   // 'loop.if', place it before the continuation block and branch to it.  It
-  // will be placed after the "then" regions.
+  // will be placed after the "then" regions. The values yielded by its last
+  // block's terminator become the branch operands.
   auto *elseBlock = continueBlock;
   auto &elseRegion = ifOp.elseRegion();
   if (!elseRegion.empty()) {
     elseBlock = &elseRegion.front();
-    rewriter.eraseOp(elseRegion.back().getTerminator());
+    Operation *elseTerminator = elseRegion.back().getTerminator();
+    // Copy the yielded values: the terminator is erased right below.
+    SmallVector<Value, 4> elseResults(elseTerminator->operand_begin(),
+                                      elseTerminator->operand_end());
+    rewriter.eraseOp(elseTerminator);
     rewriter.setInsertionPointToEnd(&elseRegion.back());
-    rewriter.create<BranchOp>(loc, continueBlock);
+    rewriter.create<BranchOp>(loc, continueBlock, elseResults);
     rewriter.inlineRegionBefore(elseRegion, continueBlock);
   }
 
Index: mlir/test/Conversion/convert-to-cfg.mlir
===================================================================
--- mlir/test/Conversion/convert-to-cfg.mlir
+++ mlir/test/Conversion/convert-to-cfg.mlir
@@ -81,6 +81,31 @@
   return
 }
 
+// CHECK-LABEL: func @simple_yield_if_else(%{{.*}}: i1) -> (index, index) {
+// CHECK-NEXT:   cond_br %{{.*}}, ^bb1, ^bb2
+// CHECK-NEXT: ^bb1:  // pred: ^bb0
+// CHECK-NEXT:   [[C1:%.*]] = constant 1 : index
+// CHECK-NEXT:   [[C2:%.*]] = constant 2 : index
+// CHECK-NEXT:   br ^bb3([[C1]], [[C2]] : index, index)
+// CHECK-NEXT: ^bb2:  // pred: ^bb0
+// CHECK-NEXT:   [[C3:%.*]] = constant 3 : index
+// CHECK-NEXT:   [[C4:%.*]] = constant 4 : index
+// CHECK-NEXT:   br ^bb3([[C3]], [[C4]] : index, index)
+// CHECK-NEXT: ^bb3([[R1:%.*]]: index, [[R2:%.*]]: index):  // 2 preds: ^bb1, ^bb2
+// CHECK-NEXT:   return [[R1]], [[R2]] : index, index
+func @simple_yield_if_else(%arg0 : i1) -> (index, index) {
+  %result:2 = loop.if %arg0 -> (index, index) {
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    loop.yield %c1, %c2 : index, index
+  } else {
+    %c3 = constant 3 : index
+    %c4 = constant 4 : index
+    loop.yield %c3, %c4 : index, index
+  }
+  return %result#0, %result#1 : index, index
+}
+
 // CHECK-LABEL: func @simple_std_2_ifs(%{{.*}}: i1) {
 // CHECK-NEXT:   cond_br %{{.*}}, ^bb1, ^bb5
 // CHECK-NEXT: ^bb1:   // pred: ^bb0