diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1197,7 +1197,11 @@ // When provided, the location attached to the operation are ignored. IgnoreLocations = 1, - LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations) + // When provided, operands of commutative operations are sorted by pointer + // value before hashing/comparison, so operand order does not matter. + HandleCommutativeOps = 2, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ HandleCommutativeOps) }; /// Compute a hash for the given operation. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -635,7 +635,8 @@ // - Operands ValueRange operands = op->getOperands(); SmallVector operandStorage; - if (op->hasTrait()) { + if (op->hasTrait() && + (flags & OperationEquivalence::HandleCommutativeOps)) { operandStorage.append(operands.begin(), operands.end()); llvm::sort(operandStorage, [](Value a, Value b) -> bool { return a.getAsOpaquePointer() < b.getAsOpaquePointer(); @@ -722,7 +723,8 @@ ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands(); SmallVector lhsOperandStorage, rhsOperandStorage; - if (lhs->hasTrait()) { + if (lhs->hasTrait() && + (flags & OperationEquivalence::HandleCommutativeOps)) { lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end()); llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool { return a.getAsOpaquePointer() < b.getAsOpaquePointer(); diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -32,7 +32,8 @@ const_cast(opC), /*hashOperands=*/OperationEquivalence::directHashValue, /*hashResults=*/OperationEquivalence::ignoreHashValue, - OperationEquivalence::IgnoreLocations); + OperationEquivalence::IgnoreLocations | 
OperationEquivalence::HandleCommutativeOps); } static bool isEqual(const Operation *lhsC, const Operation *rhsC) { auto *lhs = const_cast(lhsC); @@ -46,7 +47,8 @@ const_cast(lhsC), const_cast(rhsC), /*mapOperands=*/OperationEquivalence::exactValueMatch, /*mapResults=*/OperationEquivalence::ignoreValueEquivalence, - OperationEquivalence::IgnoreLocations); + OperationEquivalence::IgnoreLocations | + OperationEquivalence::HandleCommutativeOps); } }; } // namespace