diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -924,23 +924,33 @@
   static llvm::hash_code directHashValue(Value v) { return hash_value(v); }
 
   /// Compare two operations and return if they are equivalent.
-  /// `mapOperands` and `mapResults` are optional callbacks that allows the
-  /// caller to check the mapping of SSA value between the lhs and rhs
-  /// operations. It is expected to return success if the mapping is valid and
-  /// failure if it conflicts with a previous mapping.
+  ///
+  /// `checkEquivalent` is a callback to check if two values are equivalent.
+  /// `markEquivalent` is a callback to inform the caller that the analysis
+  /// determined that two values are equivalent. It may be null, in which case
+  /// equivalences discovered by the analysis are not reported.
+  ///
+  /// Note: Additional information regarding value equivalence can be injected
+  /// into the analysis via `checkEquivalent`. Typically, callers may want
+  /// values that were determined to be equivalent as per `markEquivalent` to be
+  /// reflected in `checkEquivalent`, unless `exactValueMatch` or a different
+  /// equivalence relationship is desired.
   static bool isEquivalentTo(Operation *lhs, Operation *rhs,
-                             function_ref mapOperands,
-                             function_ref mapResults,
+                             function_ref checkEquivalent,
+                             function_ref markEquivalent = nullptr,
                              Flags flags = Flags::None);
 
-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Compare two operations and return if they are equivalent.
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs, Flags flags);
+
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent even if their operands are not equivalent.
   static LogicalResult ignoreValueEquivalence(Value lhs, Value rhs) {
     return success();
   }
 
-  /// Helper that can be used with `isEquivalentTo` above to ignore operation
-  /// operands/result mapping.
+  /// Helper that can be used with `isEquivalentTo` above to consider ops
+  /// equivalent only if their operands are the exact same SSA values.
   static LogicalResult exactValueMatch(Value lhs, Value rhs) {
     return success(lhs == rhs);
   }
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -652,8 +652,8 @@
 
 static bool
 isRegionEquivalentTo(Region *lhs, Region *rhs,
-                     function_ref mapOperands,
-                     function_ref mapResults,
+                     function_ref checkEquivalent,
+                     function_ref markEquivalent,
                      OperationEquivalence::Flags flags) {
   DenseMap blocksMap;
   auto blocksEquivalent = [&](Block &lBlock, Block &rBlock) {
@@ -675,15 +675,16 @@
       if (!(flags & OperationEquivalence::IgnoreLocations) &&
           curArg.getLoc() != otherArg.getLoc())
         return false;
-      // Check if this value was already mapped to another value.
-      if (failed(mapOperands(curArg, otherArg)))
-        return false;
+      // Corresponding bbArgs are equivalent. The callback is optional, so
+      // only report the equivalence if one was provided.
+      if (markEquivalent)
+        markEquivalent(curArg, otherArg);
     }
 
     auto opsEquivalent = [&](Operation &lOp, Operation &rOp) {
       // Check for op equality (recursively).
-      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, mapOperands,
-                                                mapResults, flags))
+      if (!OperationEquivalence::isEquivalentTo(&lOp, &rOp, checkEquivalent,
+                                                markEquivalent, flags))
         return false;
       // Check successor mapping.
       for (auto successorsPair :
@@ -703,12 +704,12 @@
 
 bool OperationEquivalence::isEquivalentTo(
     Operation *lhs, Operation *rhs,
-    function_ref mapOperands,
-    function_ref mapResults, Flags flags) {
+    function_ref checkEquivalent,
+    function_ref markEquivalent, Flags flags) {
   if (lhs == rhs)
     return true;
 
-  // Compare the operation properties.
+  // 1. Compare the operation properties.
   if (lhs->getName() != rhs->getName() ||
       lhs->getAttrDictionary() != rhs->getAttrDictionary() ||
       lhs->getNumRegions() != rhs->getNumRegions() ||
@@ -719,6 +720,7 @@
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  // 2. Compare operands.
   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
   SmallVector lhsOperandStorage, rhsOperandStorage;
   if (lhs->hasTrait()) {
@@ -752,32 +754,57 @@
     rhsOperandStorage = sortValues(rhsOperands);
     rhsOperands = rhsOperandStorage;
   }
-  auto checkValueRangeMapping =
-      [](ValueRange lhs, ValueRange rhs,
-         function_ref mapValues) {
-        for (auto operandPair : llvm::zip(lhs, rhs)) {
-          Value curArg = std::get<0>(operandPair);
-          Value otherArg = std::get<1>(operandPair);
-          if (curArg.getType() != otherArg.getType())
-            return false;
-          if (failed(mapValues(curArg, otherArg)))
-            return false;
-        }
-        return true;
-      };
-  // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
-    return false;
-  if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
-    return false;
+
+  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+    Value curArg = std::get<0>(operandPair);
+    Value otherArg = std::get<1>(operandPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    if (failed(checkEquivalent(curArg, otherArg)))
+      return false;
+  }
+
+  // 3. Compare result types and mark results as equivalent.
+  for (auto resultPair : llvm::zip(lhs->getResults(), rhs->getResults())) {
+    Value curArg = std::get<0>(resultPair);
+    Value otherArg = std::get<1>(resultPair);
+    if (curArg.getType() != otherArg.getType())
+      return false;
+    // `markEquivalent` may be null; skip the notification in that case.
+    if (markEquivalent)
+      markEquivalent(curArg, otherArg);
+  }
+
+  // 4. Compare regions.
   for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions()))
     if (!isRegionEquivalentTo(&std::get<0>(regionPair),
-                              &std::get<1>(regionPair), mapOperands, mapResults,
-                              flags))
+                              &std::get<1>(regionPair), checkEquivalent,
+                              markEquivalent, flags))
       return false;
+
   return true;
 }
 
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs,
+                                          Flags flags) {
+  // Equivalent values in lhs and rhs.
+  DenseMap equivalentValues;
+  auto checkEquivalent = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+    return success(lhsValue == rhsValue ||
+                   equivalentValues.lookup(lhsValue) == rhsValue);
+  };
+  auto markEquivalent = [&](Value lhsResult, Value rhsResult) {
+    auto insertion = equivalentValues.insert({lhsResult, rhsResult});
+    // Make sure that the value was not already marked equivalent to some other
+    // value.
+    (void)insertion;
+    assert(insertion.first->second == rhsResult &&
+           "inconsistent OperationEquivalence state");
+  };
+  return OperationEquivalence::isEquivalentTo(lhs, rhs, checkEquivalent,
+                                              markEquivalent, flags);
+}
+
 //===----------------------------------------------------------------------===//
 // OperationFingerPrint
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -47,70 +47,9 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-
-    // If op has no regions, operation equivalence w.r.t operands alone is
-    // enough.
-    if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
-      return OperationEquivalence::isEquivalentTo(
-          const_cast(lhsC), const_cast(rhsC),
-          OperationEquivalence::exactValueMatch,
-          OperationEquivalence::ignoreValueEquivalence,
-          OperationEquivalence::IgnoreLocations);
-    }
-
-    // If lhs or rhs does not have a single region with a single block, they
-    // aren't CSEed for now.
-    if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
-        !llvm::hasSingleElement(lhs->getRegion(0)) ||
-        !llvm::hasSingleElement(rhs->getRegion(0)))
-      return false;
-
-    // Compare the two blocks.
-    Block &lhsBlock = lhs->getRegion(0).front();
-    Block &rhsBlock = rhs->getRegion(0).front();
-
-    // Don't CSE if number of arguments differ.
-    if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
-      return false;
-
-    // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
-    // `rhsBlock`. `Value`s from `lhsBlock` are the key.
-    DenseMap areEquivalentValues;
-    for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
-                                 rhs->getRegion(0).getArguments())) {
-      areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
-    }
-
-    // Helper function to get the parent operation.
-    auto getParent = [](Value v) -> Operation * {
-      if (auto blockArg = v.dyn_cast())
-        return blockArg.getParentBlock()->getParentOp();
-      return v.getDefiningOp()->getParentOp();
-    };
-
-    // Callback to compare if operands of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
-      if (lhsValue == rhsValue)
-        return success();
-      if (areEquivalentValues.lookup(lhsValue) == rhsValue)
-        return success();
-      return failure();
-    };
-
-    // Callback to compare if results of ops in the region of `lhs` and `rhs`
-    // are equivalent.
-    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
-      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
-        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
-        return success(insertion.first->second == rhsResult);
-      }
-      return success();
-    };
-
     return OperationEquivalence::isEquivalentTo(
         const_cast(lhsC), const_cast(rhsC),
-        mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
+        OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace
diff --git a/mlir/test/lib/IR/TestOperationEquals.cpp b/mlir/test/lib/IR/TestOperationEquals.cpp
--- a/mlir/test/lib/IR/TestOperationEquals.cpp
+++ b/mlir/test/lib/IR/TestOperationEquals.cpp
@@ -28,11 +28,6 @@
                             << opCount;
       return signalPassFailure();
     }
-    DenseMap valuesMap;
-    auto mapValue = [&](Value lhs, Value rhs) {
-      auto insertion = valuesMap.insert({lhs, rhs});
-      return success(insertion.first->second == rhs);
-    };
 
     Operation *first = &module.getBody()->front();
     llvm::outs() << first->getName().getStringRef() << " with attr "
@@ -41,7 +36,7 @@
     if (!first->hasAttr("strict_loc_check"))
       flags |= OperationEquivalence::IgnoreLocations;
     if (OperationEquivalence::isEquivalentTo(first, &module.getBody()->back(),
-                                             mapValue, mapValue, flags))
+                                             flags))
       llvm::outs() << " compares equals.\n";
     else
       llvm::outs() << " compares NOT equals!\n";