[mlir] Drop commutative-operand sorting from OperationEquivalence

OperationEquivalence::computeHash() sorted the operands of ops with the
IsCommutative trait by their opaque pointer value, while
OperationEquivalence::isEquivalentTo() sorted them with a different
comparator (block arguments first, ordered by parent block and argument
number, everything else by opaque pointer). The two orderings can disagree,
and any ordering keyed on pointer values depends on where the values happen
to be allocated. This patch removes both sortings so operands are hashed and
compared in their original order. It also:
  * deletes the check_cummutative_cse test that relied on the sorting;
  * fixes the comment above the result-hashing loop, which said "Operands"
    where it meant "Results".

diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -661,19 +661,10 @@
     hash = llvm::hash_combine(hash, op->getLoc());
 
   //   - Operands
-  ValueRange operands = op->getOperands();
-  SmallVector<Value> operandStorage;
-  if (op->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    operandStorage.append(operands.begin(), operands.end());
-    llvm::sort(operandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
-    operands = operandStorage;
-  }
-  for (Value operand : operands)
+  for (Value operand : op->getOperands())
     hash = llvm::hash_combine(hash, hashOperands(operand));
 
-  //   - Operands
+  //   - Results
   for (Value result : op->getResults())
     hash = llvm::hash_combine(hash, hashResults(result));
   return hash;
@@ -784,41 +775,7 @@
     return false;
 
   // 2. Compare operands.
-  ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
-  SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
-  if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    auto sortValues = [](ValueRange values) {
-      SmallVector<Value> sortedValues = llvm::to_vector(values);
-      llvm::sort(sortedValues, [](Value a, Value b) {
-        auto aArg = llvm::dyn_cast<BlockArgument>(a);
-        auto bArg = llvm::dyn_cast<BlockArgument>(b);
-
-        // Case 1. Both `a` and `b` are `BlockArgument`s.
-        if (aArg && bArg) {
-          if (aArg.getParentBlock() == bArg.getParentBlock())
-            return aArg.getArgNumber() < bArg.getArgNumber();
-          return aArg.getParentBlock() < bArg.getParentBlock();
-        }
-
-        // Case 2. One of then is a `BlockArgument` and other is not. Treat
-        // `BlockArgument` as lesser.
-        if (aArg && !bArg)
-          return true;
-        if (bArg && !aArg)
-          return false;
-
-        // Case 3. Both are values.
-        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-      });
-      return sortedValues;
-    };
-    lhsOperandStorage = sortValues(lhsOperands);
-    lhsOperands = lhsOperandStorage;
-    rhsOperandStorage = sortValues(rhsOperands);
-    rhsOperands = rhsOperandStorage;
-  }
-
-  for (auto operandPair : llvm::zip(lhsOperands, rhsOperands)) {
+  for (auto operandPair : llvm::zip(lhs->getOperands(), rhs->getOperands())) {
     Value curArg = std::get<0>(operandPair);
     Value otherArg = std::get<1>(operandPair);
     if (curArg == otherArg)
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -311,18 +311,6 @@
   return %2 : i32
 }
 
-/// This test is checking that identical commutative operation are gracefully
-/// handled but the CSE pass.
-// CHECK-LABEL: func @check_cummutative_cse
-func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
-  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
-  %1 = arith.addi %a, %b : i32
-  %2 = arith.addi %b, %a : i32
-  // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
-  %3 = arith.muli %1, %2 : i32
-  return %3 : i32
-}
-
 // Check that an operation with a single region can CSE.
 func.func @cse_single_block_ops(%a : tensor<?xf32>, %b : tensor<?xf32>)
   -> (tensor<?xf32>, tensor<?xf32>) {
@@ -448,8 +436,8 @@
 // CHECK: return %[[OP]], %[[OP]]
 func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>)
     -> (tensor<2xi1>, tensor<2xi1>) {
-  %false_2 = arith.constant false
-  %true_5 = arith.constant true
+  %false_2 = arith.constant false 
+  %true_5 = arith.constant true 
   %9 = test.cse_of_single_block_op inputs(%arg2) {
   ^bb0(%out: i1):
     %true_144 = arith.constant true