diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -1366,6 +1366,7 @@ SmallVector trueYields; SmallVector falseYields; + rewriter.setInsertionPoint(replacement); for (const auto &it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) { Value trueVal = std::get<0>(it.value()); diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -321,10 +321,10 @@ // CHECK-LABEL: func @to_select_with_body // CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index // CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]] // CHECK: scf.if {{.*}} { // CHECK: "test.op"() : () -> () // CHECK: } -// CHECK: [[V0:%.*]] = arith.select {{.*}}, [[C0]], [[C1]] // CHECK: return [[V0]] : index // ----- @@ -556,10 +556,10 @@ // CHECK: %[[PRE0:.*]] = "test.op"() : () -> i32 // CHECK: %[[PRE1:.*]] = "test.op1"() : () -> i32 // CHECK: %[[COND:.*]] = arith.andi %[[ARG0]], %[[ARG1]] +// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]] // CHECK: scf.if %[[COND]] // CHECK: "test.run"() : () -> () // CHECK: } -// CHECK: %[[RES:.*]] = arith.select %[[COND]], %[[PRE0]], %[[PRE1]] // CHECK: return %[[RES]] %0 = "test.op"() : () -> (i32) %1 = "test.op1"() : () -> (i32) @@ -933,6 +933,7 @@ return %res#0, %res#1 : i32, i1 } // CHECK-NEXT: %true = arith.constant true +// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1 // CHECK-NEXT: %[[if:.+]] = scf.if %arg0 -> (i32) { // CHECK-NEXT: %[[sv1:.+]] = "test.get_some_value"() : () -> i32 // CHECK-NEXT: scf.yield %[[sv1]] : i32 @@ -940,7 +941,6 @@ // CHECK-NEXT: %[[sv2:.+]] = "test.get_some_value"() : () -> i32 // CHECK-NEXT: scf.yield %[[sv2]] : i32 // CHECK-NEXT: } -// CHECK-NEXT: %[[toret:.+]] = arith.xori %arg0, %true : i1 // CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1 // -----