Index: mlir/include/mlir/IR/OperationSupport.h
===================================================================
--- mlir/include/mlir/IR/OperationSupport.h
+++ mlir/include/mlir/IR/OperationSupport.h
@@ -1194,16 +1194,40 @@
   };
 
   /// Compute a hash for the given operation.
-  /// The `hashOperands` and `hashResults` callbacks are expected to return a
-  /// unique hash_code for a given Value.
+  /// The `hashOp` is a callback to compute a hash for structural properties of
+  /// `op` such as op name, result types and attributes. The `hashOperands` and
+  /// `hashResults` callbacks are expected to return a unique hash_code for a
+  /// given Value.
+  ///
+  /// When the hash is used together with `isEquivalentTo` (e.g. for a hash
+  /// map), `hashOp` should be consistent with the `checkOpStructureEquivalent`
+  /// callback passed there: operations it accepts as equivalent must hash to
+  /// the same value.
   static llvm::hash_code computeHash(
-      Operation *op,
+      Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
       function_ref<llvm::hash_code(Value)> hashOperands =
           [](Value v) { return hash_value(v); },
       function_ref<llvm::hash_code(Value)> hashResults =
           [](Value v) { return hash_value(v); },
       Flags flags = Flags::None);
+
+  /// Compute a hash for the given operation, using `simpleOpHash` to hash its
+  /// structural properties. See the overload above for details.
+  static llvm::hash_code computeHash(
+      Operation *op,
+      function_ref<llvm::hash_code(Value)> hashOperands =
+          [](Value v) { return hash_value(v); },
+      function_ref<llvm::hash_code(Value)> hashResults =
+          [](Value v) { return hash_value(v); },
+      Flags flags = Flags::None) {
+    return computeHash(op, simpleOpHash, hashOperands, hashResults, flags);
+  }
+
+  /// Helper that can be used as the `hashOp` callback of `computeHash` above.
+  /// It combines the hashes of the operation name, result types, properties
+  /// and discardable attributes.
+  static llvm::hash_code simpleOpHash(Operation *op);
 
   /// Helper that can be used with `computeHash` above to ignore operation
   /// operands/result mapping.
   static llvm::hash_code ignoreHashValue(Value) { return llvm::hash_code{}; }
@@ -1213,38 +1237,59 @@
 
   /// Compare two operations (including their regions) and return if they are
   /// equivalent.
-  ///
-  /// * `checkEquivalent` is a callback to check if two values are equivalent.
+  /// * `checkOpStructureEquivalent` is a callback to check if the structures of
+  ///   two operations are equivalent.
+  /// * `checkValueEquivalent` is a callback to check if two values are
+  ///   equivalent.
   ///   For two operations to be equivalent, their operands must be the same SSA
   ///   value or this callback must return `success`.
   /// * `markEquivalent` is a callback to inform the caller that the analysis
   ///   determined that two values are equivalent.
   ///
   /// Note: Additional information regarding value equivalence can be injected
-  /// into the analysis via `checkEquivalent`. Typically, callers may want
-  /// values that were determined to be equivalent as per `markEquivalent` to be
-  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// into the analysis via `checkOpStructureEquivalent` and
+  /// `checkValueEquivalent`. Typically, callers may want values that were
+  /// determined to be equivalent as per `markEquivalent` to be reflected in
+  /// `checkValueEquivalent`, unless `exactValueMatch` or a different
   /// equivalence relationship is desired.
   static bool
   isEquivalentTo(Operation *lhs, Operation *rhs,
-                 function_ref<LogicalResult(Value, Value)> checkEquivalent,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
                  function_ref<void(Value, Value)> markEquivalent = nullptr,
                  Flags flags = Flags::None);
 
   /// Compare two operations and return if they are equivalent.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+  /// Same as above, but uses `checkOpStructureEquivalent` instead of
+  /// `simpleOpEquivalence` to compare the structure of the two operations.
+  static bool
+  isEquivalentTo(Operation *lhs, Operation *rhs,
+                 function_ref<LogicalResult(Operation *, Operation *)>
+                     checkOpStructureEquivalent,
+                 Flags flags);
 
   /// Compare two regions (including their subregions) and return if they are
   /// equivalent. See also `isEquivalentTo` for details.
   static bool isRegionEquivalentTo(
       Region *lhs, Region *rhs,
-      function_ref<LogicalResult(Value, Value)> checkEquivalent,
+      function_ref<LogicalResult(Operation *, Operation *)>
+          checkOpStructureEquivalent,
+      function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
       function_ref<void(Value, Value)> markEquivalent,
       OperationEquivalence::Flags flags);
 
   /// Compare two regions and return if they are equivalent.
   static bool isRegionEquivalentTo(Region *lhs, Region *rhs,
                                    OperationEquivalence::Flags flags);
+
+  /// Helper that can be used as the `checkOpStructureEquivalent` callback of
+  /// `isEquivalentTo` above. It compares the operation names, discardable
+  /// attributes, numbers of regions, successors, operands and results, and the
+  /// hashes of the operation properties. Locations, operands and regions are
+  /// checked separately by `isEquivalentTo`.
+  static LogicalResult simpleOpEquivalence(Operation *lhs, Operation *rhs);
 
   /// Helper that can be used with `isEquivalentTo` above to consider ops
   /// equivalent even if their operands are not equivalent.
Index: mlir/lib/IR/OperationSupport.cpp
===================================================================
--- mlir/lib/IR/OperationSupport.cpp
+++ mlir/lib/IR/OperationSupport.cpp
@@ -646,15 +646,11 @@
 //===----------------------------------------------------------------------===//
 
 llvm::hash_code OperationEquivalence::computeHash(
-    Operation *op, function_ref<llvm::hash_code(Value)> hashOperands,
+    Operation *op, function_ref<llvm::hash_code(Operation *)> hashOp,
+    function_ref<llvm::hash_code(Value)> hashOperands,
     function_ref<llvm::hash_code(Value)> hashResults, Flags flags) {
-  // Hash operations based upon their:
-  //   - Operation Name
-  //   - Attributes
-  //   - Result Types
-  llvm::hash_code hash =
-      llvm::hash_combine(op->getName(), op->getDiscardableAttrDictionary(),
-                         op->getResultTypes(), op->hashProperties());
+  // Hash operations based upon their structural properties using `hashOp`.
+  llvm::hash_code hash = hashOp(op);
 
   //   - Location if required
   if (!(flags & Flags::IgnoreLocations))
@@ -672,7 +668,9 @@
 
 /*static*/ bool OperationEquivalence::isRegionEquivalentTo(
     Region *lhs, Region *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent,
     OperationEquivalence::Flags flags) {
   DenseMap<Block *, Block *> blocksMap;
@@ -702,8 +700,9 @@
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
-                                                markEquivalent, flags))
+      if (!OperationEquivalence::isEquivalentTo(
+              &lOp, &rOp, checkOpStructureEquivalent, checkValueEquivalent,
+              markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -744,7 +743,7 @@
                                            OperationEquivalence::Flags flags) {
   ValueEquivalenceCache cache;
   return isRegionEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, simpleOpEquivalence,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
@@ -754,23 +753,38 @@
       flags);
 }
 
+/*static*/ llvm::hash_code OperationEquivalence::simpleOpHash(Operation *op) {
+  return llvm::hash_combine(op->getName(), op->getResultTypes(),
+                            op->hashProperties(),
+                            op->getDiscardableAttrDictionary());
+}
+
+/*static*/ LogicalResult
+OperationEquivalence::simpleOpEquivalence(Operation *lhs, Operation *rhs) {
+  return LogicalResult::success(
+      lhs->getName() == rhs->getName() &&
+      lhs->getDiscardableAttrDictionary() ==
+          rhs->getDiscardableAttrDictionary() &&
+      lhs->getNumRegions() == rhs->getNumRegions() &&
+      lhs->getNumSuccessors() == rhs->getNumSuccessors() &&
+      lhs->getNumOperands() == rhs->getNumOperands() &&
+      lhs->getNumResults() == rhs->getNumResults() &&
+      lhs->hashProperties() == rhs->hashProperties());
+}
+
 /*static*/ bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref<LogicalResult(Value, Value)> checkEquivalent,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    function_ref<LogicalResult(Value, Value)> checkValueEquivalent,
     function_ref<void(Value, Value)> markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // 1. Compare the operation properties.
-  if (lhs->getName() != rhs->getName() ||
-      lhs->getDiscardableAttrDictionary() !=
-          rhs->getDiscardableAttrDictionary() ||
-      lhs->getNumRegions() != rhs->getNumRegions() ||
-      lhs->getNumSuccessors() != rhs->getNumSuccessors() ||
-      lhs->getNumOperands() != rhs->getNumOperands() ||
-      lhs->getNumResults() != rhs->getNumResults() ||
-      lhs->hashProperties() != rhs->hashProperties())
+  // 1. Compare the operation structural properties.
+  if (failed(checkOpStructureEquivalent(lhs, rhs)))
     return false;
+
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
@@ -782,7 +796,7 @@
       continue;
     if (curArg.getType() != otherArg.getType())
       return false;
-    if (failed(checkEquivalent(curArg, otherArg)))
+    if (failed(checkValueEquivalent(curArg, otherArg)))
       return false;
   }
 
@@ -799,7 +813,8 @@
   // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), checkEquivalent,
+                              &std::get<1>(regionPair),
+                              checkOpStructureEquivalent, checkValueEquivalent,
                               markEquivalent, flags))
       return false;
 
@@ -811,7 +826,24 @@
                                                      Flags flags) {
   ValueEquivalenceCache cache;
   return OperationEquivalence::isEquivalentTo(
-      lhs, rhs,
+      lhs, rhs, simpleOpEquivalence,
+      [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+        return cache.checkEquivalent(lhsValue, rhsValue);
+      },
+      [&](Value lhsResult, Value rhsResult) {
+        cache.markEquivalent(lhsResult, rhsResult);
+      },
+      flags);
+}
+
+/*static*/ bool OperationEquivalence::isEquivalentTo(
+    Operation *lhs, Operation *rhs,
+    function_ref<LogicalResult(Operation *, Operation *)>
+        checkOpStructureEquivalent,
+    Flags flags) {
+  ValueEquivalenceCache cache;
+  return OperationEquivalence::isEquivalentTo(
+      lhs, rhs, checkOpStructureEquivalent,
       [&](Value lhsValue, Value rhsValue) -> LogicalResult {
         return cache.checkEquivalent(lhsValue, rhsValue);
       },
Index: mlir/lib/Transforms/Utils/RegionUtils.cpp
===================================================================
--- mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -592,7 +592,8 @@
   for (int opI = 0; lhsIt != lhsE && rhsIt != rhsE; ++lhsIt, ++rhsIt, ++opI) {
     // Check that the operations are equivalent.
     if (!OperationEquivalence::isEquivalentTo(
-            &*lhsIt, &*rhsIt, OperationEquivalence::ignoreValueEquivalence,
+            &*lhsIt, &*rhsIt, OperationEquivalence::simpleOpEquivalence,
+            OperationEquivalence::ignoreValueEquivalence,
             /*markEquivalent=*/nullptr,
             OperationEquivalence::Flags::IgnoreLocations))
       return failure();