diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -633,8 +633,20 @@
       op->getName(), op->getAttrDictionary(), op->getResultTypes());
 
   // - Operands
-  for (Value operand : op->getOperands())
-    hash = llvm::hash_combine(hash, hashOperands(operand));
-  // - Operands
+  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    // Reduce the per-operand hashes with an order-independent sum so that any
+    // permutation of the operands yields the same hash. This keeps the hash
+    // consistent with `isEquivalentTo`, which compares commutative operands in
+    // a canonical order, without materializing and sorting a copy.
+    size_t operandHash = 0;
+    for (Value operand : op->getOperands())
+      operandHash += hashOperands(operand);
+    hash = llvm::hash_combine(hash, operandHash);
+  } else {
+    for (Value operand : op->getOperands())
+      hash = llvm::hash_combine(hash, hashOperands(operand));
+  }
+
+  // - Results
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
@@ -710,6 +722,40 @@
   if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc())
     return false;
 
+  // Commutative operations are compared on a canonical ordering of their
+  // operands, so that e.g. `addi %a, %b` and `addi %b, %a` are equivalent.
+  // Both operand lists are ordered with the same comparator, independently of
+  // `mapOperands`: block arguments come first, ordered by owner block and then
+  // argument number, so that corresponding arguments of two distinct regions
+  // line up deterministically; any other value is ordered by address, which
+  // matches exactly when both operations use the same values (the CSE case)
+  // and is merely conservative otherwise.
+  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
+  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
+  if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
+    auto sortOperands = [](ValueRange values,
+                           SmallVectorImpl<Value> &storage) {
+      storage.append(values.begin(), values.end());
+      llvm::sort(storage, [](Value a, Value b) -> bool {
+        BlockArgument aArg = a.dyn_cast<BlockArgument>();
+        BlockArgument bArg = b.dyn_cast<BlockArgument>();
+        if (aArg && bArg) {
+          if (aArg.getOwner() != bArg.getOwner())
+            return aArg.getOwner() < bArg.getOwner();
+          return aArg.getArgNumber() < bArg.getArgNumber();
+        }
+        // Exactly one of them is a block argument: it sorts first.
+        if (aArg || bArg)
+          return static_cast<bool>(aArg);
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+    };
+    sortOperands(lhsOperands, lhsOperandStorage);
+    lhsOperands = lhsOperandStorage;
+    sortOperands(rhsOperands, rhsOperandStorage);
+    rhsOperands = rhsOperandStorage;
+  }
+
   auto checkValueRangeMapping =
       [](ValueRange lhs, ValueRange rhs,
          function_ref<LogicalResult(Value, Value)> mapValues) {
@@ -724,8 +770,7 @@
         return true;
       };
   // Check mapping of operands and results.
-  if (!checkValueRangeMapping(lhs->getOperands(), rhs->getOperands(),
-                              mapOperands))
+  if (!checkValueRangeMapping(lhsOperands, rhsOperands, mapOperands))
     return false;
   if (!checkValueRangeMapping(lhs->getResults(), rhs->getResults(), mapResults))
     return false;
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -310,3 +310,27 @@
   %2 = arith.addi %0, %1 : i32
   return %2 : i32
 }
+
+/// Check that commutative operations whose operands only differ in order are
+/// recognized as identical by the CSE pass.
+// CHECK-LABEL: func @check_commutative_cse
+func @check_commutative_cse(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %1 = arith.addi %a, %b : i32
+  %2 = arith.addi %b, %a : i32
+  // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %3 = arith.muli %1, %2 : i32
+  return %3 : i32
+}
+
+/// Operand order must still be significant for non-commutative operations.
+// CHECK-LABEL: func @check_non_commutative_cse
+func @check_non_commutative_cse(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[SUB1:.*]] = arith.subi %[[A:.*]], %[[B:.*]] : i32
+  // CHECK-NEXT: %[[SUB2:.*]] = arith.subi %[[B]], %[[A]] : i32
+  %1 = arith.subi %a, %b : i32
+  %2 = arith.subi %b, %a : i32
+  // CHECK-NEXT: arith.muli %[[SUB1]], %[[SUB2]] : i32
+  %3 = arith.muli %1, %2 : i32
+  return %3 : i32
+}