diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -814,20 +814,6 @@ /// Allow access to `offset_base` and `dereference_iterator`. friend RangeBaseT; }; - -//===----------------------------------------------------------------------===// -// Operation Equivalency -//===----------------------------------------------------------------------===// - -/// This class provides utilities for computing if two operations are -/// equivalent. -struct OperationEquivalence { - /// Compute a hash for the given operation. - static llvm::hash_code computeHash(Operation *op); - - /// Compare two operations and return if they are equivalent. - static bool isEquivalentTo(Operation *lhs, Operation *rhs); -}; } // end namespace mlir namespace llvm { diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -395,78 +395,3 @@ Operation *operation = reinterpret_cast(owner.ptr.get()); return operation->getResult(owner.startIndex + index); } - -//===----------------------------------------------------------------------===// -// Operation Equivalency -//===----------------------------------------------------------------------===// - -llvm::hash_code OperationEquivalence::computeHash(Operation *op) { - // Hash operations based upon their: - // - Operation Name - // - Attributes - llvm::hash_code hash = llvm::hash_combine( - op->getName(), op->getMutableAttrDict().getDictionary()); - - // - Result Types - ArrayRef resultTypes = op->getResultTypes(); - switch (resultTypes.size()) { - case 0: - // We don't need to add anything to the hash. - break; - case 1: - // Add in the result type. - hash = llvm::hash_combine(hash, resultTypes.front()); - break; - default: - // Use the type buffer as the hash, as we can guarantee it is the same for - // any given range of result types. This takes advantage of the fact the - // result types >1 are stored in a TupleType and uniqued. - hash = llvm::hash_combine(hash, resultTypes.data()); - break; - } - - // - Operands - // TODO: Allow commutative operations to have different ordering. - return llvm::hash_combine( - hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end())); -} - -bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) { - if (lhs == rhs) - return true; - - // Compare the operation name. - if (lhs->getName() != rhs->getName()) - return false; - // Check operand counts. - if (lhs->getNumOperands() != rhs->getNumOperands()) - return false; - // Compare attributes. - if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict()) - return false; - // Compare result types. - ArrayRef lhsResultTypes = lhs->getResultTypes(); - ArrayRef rhsResultTypes = rhs->getResultTypes(); - if (lhsResultTypes.size() != rhsResultTypes.size()) - return false; - switch (lhsResultTypes.size()) { - case 0: - break; - case 1: - // Compare the single result type. - if (lhsResultTypes.front() != rhsResultTypes.front()) - return false; - break; - default: - // Use the type buffer for the comparison, as we can guarantee it is the - // same for any given range of result types. This takes advantage of the - // fact the result types >1 are stored in a TupleType and uniqued. - if (lhsResultTypes.data() != rhsResultTypes.data()) - return false; - break; - } - // Compare operands. - // TODO: Allow commutative operations to have different ordering. - return std::equal(lhs->operand_begin(), lhs->operand_end(), - rhs->operand_begin()); -} diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -26,9 +26,19 @@ using namespace mlir; namespace { +// TODO(riverriddle) Handle commutative operations. struct SimpleOperationInfo : public llvm::DenseMapInfo { static unsigned getHashValue(const Operation *opC) { - return OperationEquivalence::computeHash(const_cast(opC)); + auto *op = const_cast(opC); + // Hash the operations based upon their: + // - Operation Name + // - Attributes + // - Result Types + // - Operands + return llvm::hash_combine( + op->getName(), op->getMutableAttrDict().getDictionary(), + op->getResultTypes(), + llvm::hash_combine_range(op->operand_begin(), op->operand_end())); } static bool isEqual(const Operation *lhsC, const Operation *rhsC) { auto *lhs = const_cast(lhsC); @@ -38,8 +48,24 @@ if (lhs == getTombstoneKey() || lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; - return OperationEquivalence::isEquivalentTo(const_cast(lhsC), - const_cast(rhsC)); + + // Compare the operation name. + if (lhs->getName() != rhs->getName()) + return false; + // Check operand and result type counts. + if (lhs->getNumOperands() != rhs->getNumOperands() || + lhs->getNumResults() != rhs->getNumResults()) + return false; + // Compare attributes. + if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict()) + return false; + // Compare operands. + if (!std::equal(lhs->operand_begin(), lhs->operand_end(), + rhs->operand_begin())) + return false; + // Compare result types. + return std::equal(lhs->result_type_begin(), lhs->result_type_end(), + rhs->result_type_begin()); } }; } // end anonymous namespace