diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -30,6 +30,16 @@ using namespace mlir; +static bool hasNoOrEmptyRegions(Operation *op) { + return op->getNumRegions() == 0 || + llvm::all_of(op->getRegions(), + [](auto &region) { return region.empty(); }); +} + +static bool hasSingleRegionWithSingleBlock(Operation *op) { + return op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)); +} + namespace { struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> { static unsigned getHashValue(const Operation *opC) { @@ -48,9 +58,9 @@ rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; - // If op has no regions, operation equivalence w.r.t operands alone is - // enough. - if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) { + // If op has no or empty regions, operation equivalence + // w.r.t operands alone is enough. + if (hasNoOrEmptyRegions(lhs) && hasNoOrEmptyRegions(rhs)) { return OperationEquivalence::isEquivalentTo( const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC), OperationEquivalence::exactValueMatch, @@ -60,9 +70,8 @@ // If lhs or rhs does not have a single region with a single block, they // aren't CSEed for now. - if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 || - !llvm::hasSingleElement(lhs->getRegion(0)) || - !llvm::hasSingleElement(rhs->getRegion(0))) + if (!hasSingleRegionWithSingleBlock(lhs) || + !hasSingleRegionWithSingleBlock(rhs)) return false; // Compare the two blocks. @@ -263,8 +272,7 @@ // Don't simplify operations with nested blocks. We don't currently model // equality comparisons correctly among other things. It is also unclear // whether we would want to CSE such operations. 
- if (op->getNumRegions() != 0 && - (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0)))) + if (!hasSingleRegionWithSingleBlock(op) && !hasNoOrEmptyRegions(op)) return failure(); // Some simple use case of operation with memory side-effect are dealt with diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -468,3 +468,17 @@ // CHECK: %[[OP:.+]] = test.cse_of_single_block_op // CHECK: test.region_yield %[[TRUE]] // CHECK: return %[[OP]], %[[OP]] + +// Check that operations with empty regions are CSE'd. +// CHECK-LABEL: func @cse_ops_with_empty_region +// CHECK: %[[OP0:.*[0]]] = "test.any_cond" +// CHECK-NOT: %[[OP1:.*[0]]] = "test.any_cond" +// CHECK: return %[[OP0]], %[[OP0]] +func.func @cse_ops_with_empty_region() + -> (i32, i32) { + %0 = "test.any_cond"() ({ + }) : () -> i32 + %1 = "test.any_cond"() ({ + }) : () -> i32 + return %0, %1 : i32, i32 +}