diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -949,6 +949,19 @@
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
 
+  /// Compare two regions (including their subregions) and return if they are
+  /// equivalent. See also `isEquivalentTo` for details.
+  static bool isRegionEquivalentTo(
+      Region *lhs, Region *rhs,
+      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<void(Value, Value)> markEquivalent,
+      OperationEquivalence::Flags flags);
+
+  /// Compare two regions and return if they are equivalent. Values are
+  /// equivalent if identical or marked equivalent during the comparison.
+  static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
+                                   OperationEquivalence::Flags flags);
+
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
   static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -655,11 +655,11 @@
   return hash;
 }
 
-static bool
-isRegionEquivalentTo(Region *lhs, Region *rhs,
-                     function_ref<LogicalResult(Value, Value)> checkEquivalent,
-                     function_ref<void(Value, Value)> markEquivalent,
-                     OperationEquivalence::Flags flags) {
+/*static*/ bool OperationEquivalence::isRegionEquivalentTo(
+    Region *lhs, Region *rhs,
+    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<void(Value, Value)> markEquivalent,
+    OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
     // Check block arguments.
@@ -706,7 +706,49 @@
   return llvm::all_of_zip(*lhs, *rhs, blocksEquivalent);
 }
 
-bool OperationEquivalence::isEquivalentTo(
+namespace {
+// Value equivalence cache to be used with `isRegionEquivalentTo` and
+// `isEquivalentTo`.
+struct ValueEquivalenceCache {
+  /// Maps each lhs value to the rhs value it has been marked equivalent to.
+  DenseMap<Value, Value> equivalentValues;
+
+  /// Succeeds if `lhsValue` and `rhsValue` are the same value, or if
+  /// `lhsValue` was previously marked equivalent to `rhsValue`.
+  LogicalResult checkEquivalent(Value lhsValue, Value rhsValue) {
+    return success(lhsValue == rhsValue ||
+                   equivalentValues.lookup(lhsValue) == rhsValue);
+  }
+
+  /// Records that `lhsResult` is equivalent to `rhsResult`. An existing entry
+  /// for `lhsResult` is never overwritten.
+  void markEquivalent(Value lhsResult, Value rhsResult) {
+    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+    // Make sure that the value was not already marked equivalent to some other
+    // value.
+    (void)insertion;
+    assert(insertion.first->second == rhsResult &&
+           "inconsistent OperationEquivalence state");
+  }
+};
+} // namespace
+
+/*static*/ bool
+OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs,
+                                           OperationEquivalence::Flags flags) {
+  ValueEquivalenceCache cache;
+  return isRegionEquivalentTo(
+      lhs, rhs,
+      [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+        return cache.checkEquivalent(lhsValue, rhsValue);
+      },
+      [&](Value lhsResult, Value rhsResult) {
+        cache.markEquivalent(lhsResult, rhsResult);
+      },
+      flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
     function_ref<LogicalResult(Value, Value)> checkEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
@@ -790,24 +832,19 @@
   return true;
 }
 
-bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
-                                          Flags flags) {
-  // Equivalent values in lhs and rhs.
-  DenseMap<Value, Value> equivalentValues;
-  auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
-    return success(lhsValue == rhsValue ||
-                   equivalentValues.lookup(lhsValue) == rhsValue);
-  };
-  auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
-    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
-    // Make sure that the value was not already marked equivalent to some other
-    // value.
-    (void)insertion;
-    assert(insertion.first->second == rhsResult &&
-           "inconsistent OperationEquivalence state");
-  };
-  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
-                                              markEquivalent, flags);
+/*static*/ bool OperationEquivalence::isEquivalentTo(Operation *lhs,
+                                                     Operation *rhs,
+                                                     Flags flags) {
+  ValueEquivalenceCache cache;
+  return OperationEquivalence::isEquivalentTo(
+      lhs, rhs,
+      [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+        return cache.checkEquivalent(lhsValue, rhsValue);
+      },
+      [&](Value lhsResult, Value rhsResult) {
+        cache.markEquivalent(lhsResult, rhsResult);
+      },
+      flags);
 }
 
 //===----------------------------------------------------------------------===//