Review notes for this patch (text before the first `diff --git` header is
ignored by `git apply`):
- With `visitedAncestors` gone, `popFrontAndPushAdjacentUnvisitedAncestors`
  enqueues an ancestor's defining ops once per operand use, so an ancestor
  reachable along k use-def paths is visited k times. On DAGs that reuse
  values (e.g. chains of `%n = op %m, %m`) the number of queued ancestors
  grows exponentially with depth whenever the comparator must look that deep.
- NOTE(review): MLIR graph regions permit use-def cycles. Without a visited
  set, the BFS queue of an operand on such a cycle never drains, and the
  comparator branches visible in this patch only return on a key mismatch or
  once both queues are empty, so the sort may not terminate there. Confirm
  the pattern cannot run on graph regions, or bound the traversal depth.
- The helper's name still says "Unvisited"; consider renaming it (and any doc
  comment outside these hunks) in a follow-up.
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp --- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -79,10 +79,6 @@ /// particular point in time. std::queue<Operation *> ancestorQueue; - /// Stores the list of ancestors that have been visited by the BFS traversal - /// at a particular point in time. - DenseSet<Operation *> visitedAncestors; - /// Stores the operand's "key". This "key" is defined as a list of the /// "AncestorKeys" associated with the ancestors of this operand, in a /// breadth-first order. @@ -115,11 +111,7 @@ /// Push an ancestor into the operand's BFS information structure. This - /// entails it being pushed into the queue (always) and inserted into the - /// "visited ancestors" list (iff it is an op rather than a block argument). - void pushAncestor(Operation *op) { - ancestorQueue.push(op); - if (op) - visitedAncestors.insert(op); - } + /// entails it being pushed into the queue. Ancestors are not de-duplicated, + /// so one reachable along several use-def paths is pushed once per path. + void pushAncestor(Operation *op) { ancestorQueue.push(op); } /// Refresh the key. /// @@ -149,8 +141,7 @@ return; for (Value operand : frontAncestor->getOperands()) { Operation *operandDefOp = operand.getDefiningOp(); - if (!operandDefOp || !visitedAncestors.contains(operandDefOp)) - pushAncestor(operandDefOp); + pushAncestor(operandDefOp); } } }; @@ -206,6 +197,7 @@ /// 2. The key associated with %2 is: /// `{ /// {NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// 3.
The key associated with %3 is: @@ -226,11 +218,11 @@ /// }` /// /// Thus, the sorted `foo.commutative` is: -/// %5 = foo.commutative %4, %3, %2, %1 -class SortCommutativeOperands : public RewritePattern { -public: +/// %5 = foo.commutative %4, %2, %3, %1 +struct SortCommutativeOperands final + : public OpTraitRewritePattern<OpTrait::IsCommutative> { SortCommutativeOperands(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} + : OpTraitRewritePattern(context, /*benefit=*/5) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Custom comparator for two commutative operands, which returns true iff @@ -269,21 +261,25 @@ commOperandB->popFrontAndPushAdjacentUnvisitedAncestors(); commOperandB->refreshKey(); } - if (commOperandA->ancestorQueue.empty() || - commOperandB->ancestorQueue.empty()) - return commOperandA->key.size() < commOperandB->key.size(); - if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) - return true; - if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) - return false; + // Try comparing the keys at the current keyIndex + if (keyIndex < commOperandA->key.size() && + keyIndex < commOperandB->key.size()) { + if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) + return true; + if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) + return false; + } else { // keyIndex exceeds one or both key sizes. + // Every key position both operands share compared equal, so order + // by key size -- but only once both keys are fully generated, i.e. + // both ancestorQueues are empty. + if (commOperandA->ancestorQueue.empty() && + commOperandB->ancestorQueue.empty()) + return commOperandA->key.size() < commOperandB->key.size(); + } keyIndex++; } }; - // If `op` is not commutative, do nothing. - if (!op->hasTrait<OpTrait::IsCommutative>()) - return failure(); - // Populate the list of commutative operands.
SmallVector<Value, 2> operands = op->getOperands(); SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands; diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir --- a/mlir/test/Transforms/test-commutativity-utils.mlir +++ b/mlir/test/Transforms/test-commutativity-utils.mlir @@ -1,5 +1,41 @@ // RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s +// CHECK-LABEL: @example1_test +func.func @example1_test(%arg0 : i32, %arg1 : i32) -> (i32) { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 2 : i32 + + // CHECK-NEXT: %[[ARITH_MULI:.*]] = arith.muli + %1 = arith.muli %arg1, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative2"(%[[ARITH_MULI]], %[[ARITH_CONST]]) + %result = "test.op_commutative2"(%0, %1): (i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @example2_test +func.func @example2_test(%arg0 : i32, %arg1 : i32) -> (i32) { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 2 : i32 + + // CHECK-NEXT: %[[ARITH_MULI1:.*]] = arith.muli + %1 = arith.muli %arg1, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MULI2:.*]] = arith.muli + %2 = arith.muli %1, %0 : i32 + + // CHECK-NEXT: %[[ARITH_ADDI:.*]] = arith.addi + %3 = arith.addi %1, %0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADDI]], %[[ARITH_MULI1]], %[[ARITH_MULI2]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + // CHECK-LABEL: @test_small_pattern_1 func.func @test_small_pattern_1(%arg0 : i32) -> i32 { // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant @@ -114,3 +150,75 @@ // CHECK-NEXT: return %[[RESULT]] return %result : i32 } + +// CHECK-LABEL: @check_commutative_small_similar_ancestor_tree +func.func @check_commutative_small_similar_ancestor_tree(%arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK-NEXT: arith.addi + %0 =
arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: arith.subi + %1 = arith.subi %arg0, %arg1 : i32 + + // CHECK-NEXT: %[[VAL1:.*]] = arith.divsi + %2 = arith.divsi %0, %0 : i32 + + // CHECK-NEXT: %[[VAL2:.*]] = arith.divsi + %3 = arith.divsi %1, %1 : i32 + + // CHECK-NEXT: %[[VAL3:.*]] = arith.divsi + %4 = arith.divsi %0, %1 : i32 + + // CHECK-NEXT: %[[RESULT1:.*]] = "test.op_commutative3"(%[[VAL1]], %[[VAL3]], %[[VAL2]]) + %result1 = "test.op_commutative3"(%2, %3, %4): (i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[RESULT2:.*]] = "test.op_commutative3"(%[[VAL1]], %[[VAL3]], %[[VAL2]]) + %result2 = "test.op_commutative3"(%4, %2, %3): (i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT1]], %[[RESULT2]] + return %result1, %result2 : i32, i32 +} + +// CHECK-LABEL: @check_commutative_large_similar_ancestor_tree +func.func @check_commutative_large_similar_ancestor_tree(%arg0 : i32, %arg1 : i32) -> (i32, i32) { + // CHECK-NEXT: arith.addi + %0 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: arith.subi + %1 = arith.subi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.muli + %2 = arith.muli %0, %0 : i32 + + // CHECK-NEXT: arith.muli + %3 = arith.muli %1, %1 : i32 + + // CHECK-NEXT: arith.muli + %4 = arith.muli %0, %1 : i32 + + // CHECK-NEXT: arith.divsi + %5 = arith.divsi %2, %3 : i32 + + // CHECK-NEXT: arith.divsi + %6 = arith.divsi %3, %4 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %2, %4 : i32 + + // CHECK-NEXT: %[[VAL1:.*]] = arith.subi + %8 = arith.subi %5, %6 : i32 + + // CHECK-NEXT: %[[VAL2:.*]] = arith.subi + %9 = arith.subi %5, %7 : i32 + + // CHECK-NEXT: %[[VAL3:.*]] = arith.subi + %10 = arith.subi %6, %7 : i32 + + // CHECK-NEXT: %[[RESULT1:.*]] = "test.op_commutative3"(%[[VAL2]], %[[VAL1]], %[[VAL3]]) + %result1 = "test.op_commutative3"(%8, %9, %10): (i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[RESULT2:.*]] = "test.op_commutative3"(%[[VAL2]], %[[VAL1]], %[[VAL3]]) + %result2 = "test.op_commutative3"(%10, %8, %9): (i32, i32, i32) -> i32 + + //
CHECK-NEXT: return %[[RESULT1]], %[[RESULT2]] + return %result1, %result2 : i32, i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1278,6 +1278,11 @@ let results = (outs I32); } +def TestCommutative3Op : TEST_Op<"op_commutative3", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3); + let results = (outs I32); +} + def TestIdempotentTraitOp : TEST_Op<"op_idempotent_trait", [SameOperandsAndResultType, NoMemoryEffect, Idempotent]> {