diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp --- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -206,6 +206,7 @@ /// 2. The key associated with %2 is: /// `{ /// {NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, /// {BLOCK_ARGUMENT, ""} /// }` /// 3. The key associated with %3 is: @@ -226,11 +227,11 @@ /// }` /// /// Thus, the sorted `foo.commutative` is: -/// %5 = foo.commutative %4, %3, %2, %1 -class SortCommutativeOperands : public RewritePattern { -public: +/// %5 = foo.commutative %4, %2, %3, %1 +struct SortCommutativeOperands final + : public OpTraitRewritePattern { SortCommutativeOperands(MLIRContext *context) - : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} + : OpTraitRewritePattern(context, /*benefit=*/5) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Custom comparator for two commutative operands, which returns true iff @@ -258,8 +259,12 @@ unsigned keyIndex = 0; while (true) { if (commOperandA->key.size() <= keyIndex) { + // Comparator must return false for equal elements + // B is only larger if its key size is larger than the current + // index or its ancestor queue is not empty if (commOperandA->ancestorQueue.empty()) - return true; + return commOperandB->key.size() > keyIndex || + !commOperandB->ancestorQueue.empty(); commOperandA->popFrontAndPushAdjacentUnvisitedAncestors(); commOperandA->refreshKey(); } @@ -280,10 +285,6 @@ } }; - // If `op` is not commutative, do nothing. - if (!op->hasTrait()) - return failure(); - // Populate the list of commutative operands. SmallVector operands = op->getOperands(); SmallVector, 2> commOperands; diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir --- a/mlir/test/Transforms/test-commutativity-utils.mlir +++ b/mlir/test/Transforms/test-commutativity-utils.mlir @@ -114,3 +114,26 @@ // CHECK-NEXT: return %[[RESULT]] return %result : i32 } + +// CHECK-LABEL: @test_equal_ancestor_trees +func.func @test_equal_ancestor_trees(%arg0 : i32, %arg1: i32) -> i32 { + // CHECK-NEXT: arith.addi + %0 = arith.addi %arg1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %1 = arith.addi %arg0, %arg0 : i32 + + // CHECK: %[[ARITH_MULI1:.*]] = arith.muli + %2 = arith.muli %0, %0 : i32 + + // CHECK-NEXT: %[[ARITH_MULI2:.*]] = arith.muli + %3 = arith.muli %1, %1 : i32 + + // Without additional logic to differentiate between block arguments, %2 and %3 are treated + // as equivalent, so the sorting logic keeps the original order + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative2"(%[[ARITH_MULI1]], %[[ARITH_MULI2]]) + %result = "test.op_commutative2"(%2, %3): (i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +}