diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -134,20 +134,26 @@
   SmallVector<Attribute, 8> operandConstants;
   SmallVector<OpFoldResult, 8> foldResults;
 
+  // If this is a commutative operation, move constants to be trailing
+  // operands. The partition is stable: non-constant operands keep their
+  // relative order, and so do constant operands. Note that the operands are
+  // reordered in place even when the fold attempted below fails.
+  if (op->getNumOperands() >= 2 && op->isCommutative()) {
+    auto isNonConstant = [](OpOperand &operand) {
+      Attribute constant;
+      matchPattern(operand.get(), m_Constant(&constant));
+      return !constant;
+    };
+    std::stable_partition(op->getOpOperands().begin(),
+                          op->getOpOperands().end(), isNonConstant);
+  }
+
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   operandConstants.assign(op->getNumOperands(), Attribute());
   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
 
-  // If this is a commutative binary operation with a constant on the left
-  // side move it to the right side.
-  if (operandConstants.size() == 2 && operandConstants[0] &&
-      !operandConstants[1] && op->isCommutative()) {
-    std::swap(op->getOpOperand(0), op->getOpOperand(1));
-    std::swap(operandConstants[0], operandConstants[1]);
-  }
-
   // Attempt to constant fold the operation.
   if (failed(op->fold(operandConstants, foldResults)))
     return failure();
diff --git a/mlir/test/Transforms/test-canonicalize.mlir b/mlir/test/Transforms/test-canonicalize.mlir
--- a/mlir/test/Transforms/test-canonicalize.mlir
+++ b/mlir/test/Transforms/test-canonicalize.mlir
@@ -35,3 +35,20 @@
   %0, %1 = "test.op_with_variadic_results_and_folder"(%arg0, %arg1) : (i32, i32) -> (i32, i32)
   return %0, %1 : i32, i32
 }
+
+// Constant operands of a commutative op are moved to the end; both the
+// non-constant and the constant operands keep their relative order.
+// CHECK-LABEL: func @test_commutative_multi
+func @test_commutative_multi(%arg0: i32, %arg1: i32) -> (i32, i32) {
+  // CHECK: %c42_i32 = constant 42 : i32
+  %c42_i32 = constant 42 : i32
+  // CHECK: %c43_i32 = constant 43 : i32
+  %c43_i32 = constant 43 : i32
+  // CHECK-NEXT: %0 = "test.op_commutative"(%arg0, %arg1, %c42_i32, %c43_i32) : (i32, i32, i32, i32) -> i32
+  %y = "test.op_commutative"(%c42_i32, %arg0, %arg1, %c43_i32) : (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: %1 = "test.op_commutative"(%arg0, %arg1, %c42_i32, %c43_i32) : (i32, i32, i32, i32) -> i32
+  %z = "test.op_commutative"(%arg0, %c42_i32, %c43_i32, %arg1) : (i32, i32, i32, i32) -> i32
+  // CHECK-NEXT: return %0, %1
+  return %y, %z : i32, i32
+}
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -639,6 +639,11 @@
   let hasFolder = 1;
 }
 
+def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
+  let results = (outs I32);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns (Symbol Binding)