diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -924,23 +924,32 @@
   static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
 
   /// Compare two operations and return if they are equivalent.
-  /// `mapOperands` and `mapResults` are optional callbacks that allows the
-  /// caller to check the mapping of SSA value between the lhs and rhs
-  /// operations. It is expected to return success if the mapping is valid and
-  /// failure if it conflicts with a previous mapping.
+  ///
+  /// `checkEquivalent` is a callback to check if two values are equivalent.
+  /// `markEquivalent` is an optional (nullable) callback to inform the caller
+  /// that the analysis determined that two values are equivalent.
+  ///
+  /// Note: Additional information regarding value equivalence can be injected
+  /// into the analysis via `checkEquivalent`. Typically, callers may want
+  /// values that were determined to be equivalent as per `markEquivalent` to be
+  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> mapOperands,
-                 function_ref<LogicalResult(Value, Value)> mapResults,
+                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);

-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Compare two operations and return if they are equivalent.
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent even if their operands are not equivalent.
   static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
     return success();
   }
-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent only if their operands are the exact same SSA values.
   static LogicalResult exactValueMatch(Value lhs, Value rhs) {
     return success(lhs == rhs);
   }
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -652,8 +652,8 @@
 
 static bool
 isRegionEquivalentTo(Region *lhs, Region *rhs,
-                     function_ref<LogicalResult(Value, Value)> mapOperands,
-                     function_ref<LogicalResult(Value, Value)> mapResults,
+                     function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                     function_ref<void(Value, Value)> markEquivalent,
                      OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
@@ -675,15 +675,15 @@
       if (!(flags & OperationEquivalence::IgnoreLocations) &&
           curArg.getLoc() != otherArg.getLoc())
         return false;
-      // Check if this value was already mapped to another value.
-      if (failed(mapOperands(curArg, otherArg)))
-        return false;
+      // Corresponding bbArgs are equivalent (`markEquivalent` may be null).
+      if (markEquivalent)
+        markEquivalent(curArg, otherArg);
     }

     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
-                                                mapResults, flags))
+      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
+                                                markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -703,12 +703,12 @@
 
 bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> mapOperands,
-    function_ref<LogicalResult(Value, Value)> mapResults, Flags flags) {
+    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // Compare the operation properties.
+  // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
       lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
@@ -719,6 +719,7 @@
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  // 2. Compare operands.
   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
   SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
   if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
@@ -752,32 +753,56 @@
     rhsOperandStorage = sortValues(rhsOperands);
     rhsOperands = rhsOperandStorage;
   }
-  auto checkValueRangeMapping =
-      [](ValueRange lhs, ValueRange rhs,
-         function_ref<LogicalResult(Value, Value)> mapValues) {
-        for (auto operandPair : llvm::zip(lhs, rhs)) {
-          Value curArg = std::get<0>(operandPair);
-          Value otherArg = std::get<1>(operandPair);
-          if (curArg.getType() != otherArg.getType())
-            return false;
-          if (failed(mapValues(curArg, otherArg)))
-            return false;
-        }
-        return true;
-      };
-  // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
-    return false;
-  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
-    return false;
+
+  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+    Value curArg = std::get<0>(operandPair);
+    Value otherArg = std::get<1>(operandPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (failed(checkEquivalent(curArg, otherArg)))
+      return false;
+  }
+
+  // 3. Compare result types and mark results as equivalent.
+  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
+    Value curArg = std::get<0>(resultPair);
+    Value otherArg = std::get<1>(resultPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (markEquivalent)
+      markEquivalent(curArg, otherArg);
+  }
+
+  // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), mapOperands, mapResults,
-                              flags))
+                              &std::get<1>(regionPair), checkEquivalent,
+                              markEquivalent, flags))
       return false;
+
   return true;
 }

+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
+                                          Flags flags) {
+  // Equivalent values in lhs and rhs.
+  DenseMap<Value, Value> equivalentValues;
+  auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+    return success(lhsValue == rhsValue ||
+                   equivalentValues.lookup(lhsValue) == rhsValue);
+  };
+  auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
+    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+    // Make sure that the value was not already marked equivalent to some other
+    // value.
+    (void)insertion;
+    assert(insertion.first->second == rhsResult &&
+           "inconsistent OperationEquivalence state");
+  };
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
+                                              markEquivalent, flags);
+}
+
 //===----------------------------------------------------------------------===//
 // OperationFingerPrint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -47,70 +47,9 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-
-    // If op has no regions, operation equivalence w.r.t operands alone is
-    // enough.
-    if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
-      return OperationEquivalence::isEquivalentTo(
-          const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-          OperationEquivalence::exactValueMatch,
-          OperationEquivalence::ignoreValueEquivalence,
-          OperationEquivalence::IgnoreLocations);
-    }
-
-    // If lhs or rhs does not have a single region with a single block, they
-    // aren't CSEed for now.
-    if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
-        !llvm::hasSingleElement(lhs->getRegion(0)) ||
-        !llvm::hasSingleElement(rhs->getRegion(0)))
-      return false;
-
-    // Compare the two blocks.
-    Block &lhsBlock = lhs->getRegion(0).front();
-    Block &rhsBlock = rhs->getRegion(0).front();
-
-    // Don't CSE if number of arguments differ.
-    if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
-      return false;
-
-    // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
-    // `rhsBlock`. `Value`s from `lhsBlock` are the key.
-    DenseMap<Value, Value> areEquivalentValues;
-    for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
-                                 rhs->getRegion(0).getArguments())) {
-      areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
-    }
-
-    // Helper function to get the parent operation.
-    auto getParent = [](Value v) -> Operation * {
-      if (auto blockArg = v.dyn_cast<BlockArgument>())
-        return blockArg.getParentBlock()->getParentOp();
-      return v.getDefiningOp()->getParentOp();
-    };
-
-    // Callback to compare if operands of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
-      if (lhsValue == rhsValue)
-        return success();
-      if (areEquivalentValues.lookup(lhsValue) == rhsValue)
-        return success();
-      return failure();
-    };
-
-    // Callback to compare if results of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
-      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
-        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
-        return success(insertion.first->second == rhsResult);
-      }
-      return success();
-    };
-
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-        mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace
@@ -260,11 +199,10 @@
     return success();
   }
 
-  // Don't simplify operations with nested blocks. We don't currently model
-  // equality comparisons correctly among other things. It is also unclear
-  // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0 &&
-      (op->getNumRegions() != 1 || !llvm::hasSingleElement(op->getRegion(0))))
+  // Don't simplify operations with regions that have multiple blocks.
+  // TODO: We need additional tests to verify that we handle such IR correctly.
+  if (llvm::any_of(op->getRegions(),
+                   [](Region &r) { return !r.empty() && !r.hasOneBlock(); }))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -468,3 +468,29 @@
 //       CHECK:   %[[OP:.+]] = test.cse_of_single_block_op
 //       CHECK:     test.region_yield %[[TRUE]]
 //       CHECK:   return %[[OP]], %[[OP]]
+
+// Ops with multiple regions (scf.if with then/else regions) can be CSE'd.
+func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
+  %r1 = scf.if %c -> (tensor<5xf32>) {
+    %0 = tensor.empty() : tensor<5xf32>
+    scf.yield %0 : tensor<5xf32>
+  } else {
+    scf.yield %t : tensor<5xf32>
+  }
+  %r2 = scf.if %c -> (tensor<5xf32>) {
+    %0 = tensor.empty() : tensor<5xf32>
+    scf.yield %0 : tensor<5xf32>
+  } else {
+    scf.yield %t : tensor<5xf32>
+  }
+  return %r1, %r2 : tensor<5xf32>, tensor<5xf32>
+}
+// CHECK-LABEL: func @cse_multiple_regions
+//       CHECK:   %[[if:.*]] = scf.if {{.*}} {
+//       CHECK:     tensor.empty
+//       CHECK:     scf.yield
+//       CHECK:   } else {
+//       CHECK:     scf.yield
+//       CHECK:   }
+//   CHECK-NOT:   scf.if
+//       CHECK:   return %[[if]], %[[if]]
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -28,11 +28,6 @@
                          << opCount;
       return signalPassFailure();
     }
-    DenseMap<Value, Value> valuesMap;
-    auto mapValue = [&](Value lhs, Value rhs) {
-      auto insertion = valuesMap.insert({lhs, rhs});
-      return success(insertion.first->second == rhs);
-    };
 
     Operation *first = &module.getBody()->front();
     llvm::outs() << first->getName().getStringRef() << " with attr "
@@ -41,7 +36,8 @@
     if (!first->hasAttr("strict_loc_check"))
       flags |= OperationEquivalence::IgnoreLocations;
+    // This overload tracks equivalent results/block arguments internally.
     if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
-                                             mapValue, mapValue, flags))
+                                             flags))
       llvm::outs() << " compares equals.\n";
     else
       llvm::outs() << " compares NOT equals!\n";