diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -633,8 +633,26 @@
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   //   - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
-  //   - Operands
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    // Operand order carries no meaning for a commutative op, so hash the
+    // operands as a sorted multiset: any permutation yields the same hash.
+    // Every operand still goes through `hashOperands`, exactly like the
+    // non-commutative path, so caller-supplied value hashing is honored.
+    // NOTE(review): the hash only buckets candidates; for CSE to merge
+    // permuted operands the equivalence check must also ignore their order.
+    llvm::SmallVector<size_t, 4> operandHashes;
+    operandHashes.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      operandHashes.push_back(static_cast<size_t>(hashOperands(operand)));
+    llvm::sort(operandHashes);
+    hash = llvm::hash_combine(
+        hash, llvm::hash_combine_range(operandHashes.begin(),
+                                       operandHashes.end()));
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
+
+  //   - Results
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -265,3 +265,16 @@
   }
   return
 }
+
+/// Check that commutative operations whose operands differ only in order are
+/// recognized as identical and merged by the CSE pass.
+// CHECK-LABEL: func @check_commutative_cse
+func @check_commutative_cse(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %1 = arith.addi %a, %b : i32
+  %2 = arith.addi %b, %a : i32
+  // CHECK-NEXT: %[[MUL:.*]] = arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %3 = arith.muli %1, %2 : i32
+  // CHECK-NEXT: return %[[MUL]] : i32
+  return %3 : i32
+}