diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -532,7 +532,8 @@ : rewriter.create(loc, /*value=*/1, /*width=*/1); bool hasElseRegion = !op.elseRegion().empty(); - auto ifOp = rewriter.create(loc, cond, hasElseRegion); + auto ifOp = rewriter.create(loc, op.getResultTypes(), cond, + hasElseRegion); rewriter.inlineRegionBefore(op.thenRegion(), &ifOp.thenRegion().back()); rewriter.eraseBlock(&ifOp.thenRegion().back()); if (hasElseRegion) { @@ -540,8 +541,8 @@ rewriter.eraseBlock(&ifOp.elseRegion().back()); } - // Ok, we're done! - rewriter.eraseOp(op); + // Replace the Affine IfOp finally. + rewriter.replaceOp(op, ifOp.results()); return success(); } }; diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -239,6 +239,33 @@ return } +// CHECK-LABEL: func @if_with_yield +// CHECK-NEXT: %[[c0_i64:.*]] = constant 0 : i64 +// CHECK-NEXT: %[[c1_i64:.*]] = constant 1 : i64 +// CHECK-NEXT: %[[v0:.*]] = call @get_idx() : () -> index +// CHECK-NEXT: %[[c0:.*]] = constant 0 : index +// CHECK-NEXT: %[[cm10:.*]] = constant -10 : index +// CHECK-NEXT: %[[v1:.*]] = addi %[[v0]], %[[cm10]] : index +// CHECK-NEXT: %[[v2:.*]] = cmpi sge, %[[v1]], %[[c0]] : index +// CHECK-NEXT: %[[v3:.*]] = scf.if %[[v2]] -> (i64) { +// CHECK-NEXT: scf.yield %[[c0_i64]] : i64 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[c1_i64]] : i64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[v3]] : i64 +// CHECK-NEXT: } +func @if_with_yield() -> (i64) { + %cst0 = constant 0 : i64 + %cst1 = constant 1 : i64 + %i = call @get_idx() : () -> (index) + %1 = affine.if #set2(%i) -> (i64) { + affine.yield %cst0 : i64 + } else { + affine.yield %cst1 : i64 + } + return %1 : i64 +} + #setN = affine_set<(d0)[N,M,K,L] : (N - d0 + 1 >= 0, N - 1 >= 0, M - 1 >= 0, K - 1 >= 0, L - 42 == 0)> // CHECK-LABEL: func @multi_cond