diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -199,11 +199,13 @@
     return success();
   }
 
-  // Don't simplify operations with nested blocks. We don't currently model
-  // equality comparisons correctly among other things. It is also unclear
-  // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0 &&
-      (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))))
+  // Bail out if any region attached to the op holds more than one block; only
+  // empty and single-block regions are handled here.
+  // TODO: We need additional tests to verify that we handle such IR correctly.
+  if (llvm::any_of(op->getRegions(), [](Region &region) {
+        auto &blocks = region.getBlocks();
+        return !blocks.empty() && !llvm::hasSingleElement(blocks);
+      }))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -468,3 +468,30 @@
 // CHECK: %[[OP:.+]] = test.cse_of_single_block_op
 // CHECK: test.region_yield %[[TRUE]]
 // CHECK: return %[[OP]], %[[OP]]
+
+// Two scf.if ops with the same condition and equivalent then/else regions:
+// the second one is expected to be eliminated in favor of the first.
+func.func @cse_multiple_regions(%cond: i1, %fallback: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %first = scf.if %cond -> (tensor<5xf32>) {
+    %empty = tensor.empty() : tensor<5xf32>
+    scf.yield %empty : tensor<5xf32>
+  } else {
+    scf.yield %fallback : tensor<5xf32>
+  }
+  %second = scf.if %cond -> (tensor<5xf32>) {
+    %empty = tensor.empty() : tensor<5xf32>
+    scf.yield %empty : tensor<5xf32>
+  } else {
+    scf.yield %fallback : tensor<5xf32>
+  }
+  return %first, %second : tensor<5xf32>, tensor<5xf32>
+}
+// CHECK-LABEL: func @cse_multiple_regions
+// CHECK: %[[IF:.*]] = scf.if {{.*}} {
+// CHECK: tensor.empty
+// CHECK: scf.yield
+// CHECK: } else {
+// CHECK: scf.yield
+// CHECK: }
+// CHECK-NOT: scf.if
+// CHECK: return %[[IF]], %[[IF]]