diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -814,6 +814,22 @@
   /// Allow access to `offset_base` and `dereference_iterator`.
   friend RangeBaseT;
 };
+
+//===----------------------------------------------------------------------===//
+// Operation Equivalency
+//===----------------------------------------------------------------------===//
+
+/// This class provides utilities for computing if two operations are
+/// equivalent. Two operations are considered equivalent if they have the same
+/// name, attributes, result types, and operands (compared in order).
+struct OperationEquivalence {
+  /// Compute a hash for the given operation from its name, attributes, result
+  /// types, and operands.
+  static llvm::hash_code computeHash(Operation *op);
+
+  /// Compare two operations and return true if they are equivalent.
+  static bool isEquivalentTo(Operation *lhs, Operation *rhs);
+};
 } // end namespace mlir
 
 namespace llvm {
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -395,3 +395,78 @@
   Operation *operation = reinterpret_cast<Operation *>(owner.ptr.get<void *>());
   return operation->getResult(owner.startIndex + index);
 }
+
+//===----------------------------------------------------------------------===//
+// Operation Equivalency
+//===----------------------------------------------------------------------===//
+
+llvm::hash_code OperationEquivalence::computeHash(Operation *op) {
+  // Hash operations based upon their:
+  //   - Operation Name
+  //   - Attributes
+  llvm::hash_code hash = llvm::hash_combine(
+      op->getName(), op->getMutableAttrDict().getDictionary());
+
+  //   - Result Types
+  ArrayRef<Type> resultTypes = op->getResultTypes();
+  switch (resultTypes.size()) {
+  case 0:
+    // We don't need to add anything to the hash.
+    break;
+  case 1:
+    // Add in the result type.
+    hash = llvm::hash_combine(hash, resultTypes.front());
+    break;
+  default:
+    // Use the type buffer as the hash, as we can guarantee it is the same for
+    // any given range of result types. This takes advantage of the fact the
+    // result types >1 are stored in a TupleType and uniqued.
+    hash = llvm::hash_combine(hash, resultTypes.data());
+    break;
+  }
+
+  //   - Operands
+  // TODO: Allow commutative operations to have different ordering.
+  return llvm::hash_combine(
+      hash, llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+}
+
+bool OperationEquivalence::isEquivalentTo(Operation *lhs, Operation *rhs) {
+  if (lhs == rhs)
+    return true;
+
+  // Compare the operation name.
+  if (lhs->getName() != rhs->getName())
+    return false;
+  // Check operand counts.
+  if (lhs->getNumOperands() != rhs->getNumOperands())
+    return false;
+  // Compare attributes.
+  if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
+    return false;
+  // Compare result types.
+  ArrayRef<Type> lhsResultTypes = lhs->getResultTypes();
+  ArrayRef<Type> rhsResultTypes = rhs->getResultTypes();
+  if (lhsResultTypes.size() != rhsResultTypes.size())
+    return false;
+  switch (lhsResultTypes.size()) {
+  case 0:
+    break;
+  case 1:
+    // Compare the single result type.
+    if (lhsResultTypes.front() != rhsResultTypes.front())
+      return false;
+    break;
+  default:
+    // Use the type buffer for the comparison, as we can guarantee it is the
+    // same for any given range of result types. This takes advantage of the
+    // fact the result types >1 are stored in a TupleType and uniqued.
+    if (lhsResultTypes.data() != rhsResultTypes.data())
+      return false;
+    break;
+  }
+  // Compare operands.
+  // TODO: Allow commutative operations to have different ordering.
+  return std::equal(lhs->operand_begin(), lhs->operand_end(),
+                    rhs->operand_begin());
+}
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -26,19 +26,9 @@
 using namespace mlir;
 
 namespace {
-// TODO(riverriddle) Handle commutative operations.
 struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
   static unsigned getHashValue(const Operation *opC) {
-    auto *op = const_cast<Operation *>(opC);
-    // Hash the operations based upon their:
-    //   - Operation Name
-    //   - Attributes
-    //   - Result Types
-    //   - Operands
-    return llvm::hash_combine(
-        op->getName(), op->getMutableAttrDict().getDictionary(),
-        op->getResultTypes(),
-        llvm::hash_combine_range(op->operand_begin(), op->operand_end()));
+    return OperationEquivalence::computeHash(const_cast<Operation *>(opC));
   }
   static bool isEqual(const Operation *lhsC, const Operation *rhsC) {
     auto *lhs = const_cast<Operation *>(lhsC);
@@ -48,24 +38,7 @@
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
-
-    // Compare the operation name.
-    if (lhs->getName() != rhs->getName())
-      return false;
-    // Check operand and result type counts.
-    if (lhs->getNumOperands() != rhs->getNumOperands() ||
-        lhs->getNumResults() != rhs->getNumResults())
-      return false;
-    // Compare attributes.
-    if (lhs->getMutableAttrDict() != rhs->getMutableAttrDict())
-      return false;
-    // Compare operands.
-    if (!std::equal(lhs->operand_begin(), lhs->operand_end(),
-                    rhs->operand_begin()))
-      return false;
-    // Compare result types.
-    return std::equal(lhs->result_type_begin(), lhs->result_type_end(),
-                      rhs->result_type_begin());
+    return OperationEquivalence::isEquivalentTo(lhs, rhs);
   }
 };
 } // end anonymous namespace