diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -633,8 +633,18 @@ op->getName(), op->getAttrDictionary(), op->getResultTypes()); // - Operands - for (Value operand : op->getOperands()) - hash = llvm::hash_combine(hash, hashOperands(operand)); + if (op->hasTrait()) { + llvm::SmallVector vec; + for (auto i = op->operand_begin(), e = op->operand_end(); i != e; ++i) + vec.push_back((*i).getAsOpaquePointer()); + llvm::sort(vec.begin(), vec.end()); + hash = llvm::hash_combine(hash, + llvm::hash_combine_range(vec.begin(), vec.end())); + } else { + for (Value operand : op->getOperands()) + hash = llvm::hash_combine(hash, hashOperands(operand)); + } + // - Results for (Value result : op->getResults()) hash = llvm::hash_combine(hash, hashResults(result)); @@ -723,10 +733,23 @@ } return true; }; - // Check mapping of operands and results. - if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(), - mapOperands)) - return false; + if (lhs->hasTrait()) { + SmallVector lops; + for (auto lod : lhs->getOperands()) + lops.push_back(lod.getAsOpaquePointer()); + llvm::sort(lops.begin(), lops.end()); + SmallVector rops; + for (auto rod : rhs->getOperands()) + rops.push_back(rod.getAsOpaquePointer()); + llvm::sort(rops.begin(), rops.end()); + if (!std::equal(lops.begin(), lops.end(), rops.begin())) + return false; + } else { + // Check mapping of operands and results.
+ if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(), + mapOperands)) + return false; + } if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults)) return false; for (auto regionPair : llvm::zip(lhs->getRegions(), rhs->getRegions())) diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -310,3 +310,15 @@ %2 = arith.addi %0, %1 : i32 return %2 : i32 } + +/// This test checks that identical commutative operations are gracefully +/// handled by the CSE pass. +// CHECK-LABEL: func @check_cummutative_cse +func @check_cummutative_cse(%a : i32, %b : i32) -> i32 { + // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32 + %1 = arith.addi %a, %b : i32 + %2 = arith.addi %b, %a : i32 + // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32 + %3 = arith.muli %1, %2 : i32 + return %3 : i32 +}