diff --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,27 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares a function to populate the commutativity utility +// pattern. This function is intended to be used inside passes to simplify the +// matching of commutative operations by fixing the order of their operands. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +/// Populates the commutativity utility patterns. +void populateCommutativityUtilsPatterns(RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,317 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a commutativity utility pattern and a function to +// populate this pattern. The function is intended to be used inside passes to +// simplify the matching of commutative operations by fixing the order of their +// operands. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include + +using namespace mlir; + +/// The possible "types" of ancestors. Here, an ancestor is an op or a block +/// argument present in the backward slice of a value. +enum AncestorType { + /// Pertains to a block argument. + BLOCK_ARGUMENT, + + /// Pertains to a non-constant-like op. + NON_CONSTANT_OP, + + /// Pertains to a constant-like op. + CONSTANT_OP +}; + +/// Stores the "key" associated with an ancestor. +struct AncestorKey { + /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on + /// the ancestor. + AncestorType type; + + /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or + /// `CONSTANT_OP`. Else, holds "". + StringRef opName; + + /// Constructor for `AncestorKey`. + AncestorKey(Operation *op) { + if (!op) { + type = BLOCK_ARGUMENT; + } else { + type = + op->hasTrait() ? CONSTANT_OP : NON_CONSTANT_OP; + opName = op->getName().getStringRef(); + } + } + + /// Overloaded operator `<` for `AncestorKey`. + /// + /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those + /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in + /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller + /// ones are the ones with smaller op names (lexicographically). + /// + /// TODO: Include other information like attributes, value type, etc., to + /// enhance this comparison. For example, currently this comparison doesn't + /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and + /// `addi (in i64)`. Such an enhancement should only be done if the need + /// arises. + bool operator<(const AncestorKey &key) const { + return std::tie(type, opName) < std::tie(key.type, key.opName); + } +}; + +/// Stores a commutative operand along with its BFS traversal information. +struct CommutativeOperand { + /// Stores the operand. + Value operand; + + /// Stores the queue of ancestors of the operand's BFS traversal at a + /// particular point in time. + std::queue ancestorQueue; + + /// Stores the list of ancestors that have been visited by the BFS traversal + /// at a particular point in time. + DenseSet visitedAncestors; + + /// Stores the operand's "key". This "key" is defined as a list of the + /// "AncestorKeys" associated with the ancestors of this operand, in a + /// breadth-first order. + /// + /// So, if an operand, say `A`, was produced as follows: + /// + /// `` `` + /// \ / + /// \ / + /// `arith.subi` `arith.constant` + /// \ / + /// `arith.addi` + /// | + /// returns `A` + /// + /// Then, the ancestors of `A`, in the breadth-first order are: + /// `arith.addi`, `arith.subi`, `arith.constant`, ``, and + /// ``. + /// + /// Thus, the "key" associated with operand `A` is: + /// { + /// {type: `NON_CONSTANT_OP`, opName: "arith.addi"}, + /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, + /// {type: `CONSTANT_OP`, opName: "arith.constant"}, + /// {type: `BLOCK_ARGUMENT`, opName: ""}, + /// {type: `BLOCK_ARGUMENT`, opName: ""} + /// } + SmallVector key; + + /// Push an ancestor into the operand's BFS information structure. This + /// entails it being pushed into the queue (always) and inserted into the + /// "visited ancestors" list (iff it is an op rather than a block argument). + void pushAncestor(Operation *op) { + ancestorQueue.push(op); + if (op) + visitedAncestors.insert(op); + return; + } + + /// Refresh the key. + /// + /// Refreshing a key entails making it up-to-date with the operand's BFS + /// traversal that has happened till that point in time, i.e, appending the + /// existing key with the front ancestor's "AncestorKey". Note that a key + /// directly reflects the BFS and thus needs to be refreshed during the + /// progression of the traversal. + void refreshKey() { + if (ancestorQueue.empty()) + return; + + Operation *frontAncestor = ancestorQueue.front(); + AncestorKey frontAncestorKey(frontAncestor); + key.push_back(frontAncestorKey); + return; + } + + /// Pop the front ancestor, if any, from the queue and then push its adjacent + /// unvisited ancestors, if any, to the queue (this is the main body of the + /// BFS algorithm). + void popFrontAndPushAdjacentUnvisitedAncestors() { + if (ancestorQueue.empty()) + return; + Operation *frontAncestor = ancestorQueue.front(); + ancestorQueue.pop(); + if (!frontAncestor) + return; + for (Value operand : frontAncestor->getOperands()) { + Operation *operandDefOp = operand.getDefiningOp(); + if (!operandDefOp || !visitedAncestors.contains(operandDefOp)) + pushAncestor(operandDefOp); + } + return; + } +}; + +/// Sorts the operands of `op` in ascending order of the "key" associated with +/// each operand iff `op` is commutative. This is a stable sort. +/// +/// After the application of this pattern, since the commutative operands now +/// have a deterministic order in which they occur in an op, the matching of +/// large DAGs becomes much simpler, i.e., requires much less number of checks +/// to be written by a user in her/his pattern matching function. +/// +/// Some examples of such a sorting: +/// +/// Assume that the sorting is being applied to `foo.commutative`, which is a +/// commutative op. +/// +/// Example 1: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul , +/// %3 = foo.commutative %1, %2 +/// +/// Here, +/// 1. The key associated with %1 is: +/// `{ +/// {CONSTANT_OP, "foo.const"} +/// }` +/// 2. The key associated with %2 is: +/// `{ +/// {NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""} +/// }` +/// +/// The key of %2 < the key of %1 +/// Thus, the sorted `foo.commutative` is: +/// %3 = foo.commutative %2, %1 +/// +/// Example 2: +/// +/// %1 = foo.const 0 +/// %2 = foo.mul , +/// %3 = foo.mul %2, %1 +/// %4 = foo.add %2, %1 +/// %5 = foo.commutative %1, %2, %3, %4 +/// +/// Here, +/// 1. The key associated with %1 is: +/// `{ +/// {CONSTANT_OP, "foo.const"} +/// }` +/// 2. The key associated with %2 is: +/// `{ +/// {NON_CONSTANT_OP, "foo.mul"}, +/// {BLOCK_ARGUMENT, ""} +/// }` +/// 3. The key associated with %3 is: +/// `{ +/// {NON_CONSTANT_OP, "foo.mul"}, +/// {NON_CONSTANT_OP, "foo.mul"}, +/// {CONSTANT_OP, "foo.const"}, +/// {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""} +/// }` +/// 4. The key associated with %4 is: +/// `{ +/// {NON_CONSTANT_OP, "foo.add"}, +/// {NON_CONSTANT_OP, "foo.mul"}, +/// {CONSTANT_OP, "foo.const"}, +/// {BLOCK_ARGUMENT, ""}, +/// {BLOCK_ARGUMENT, ""} +/// }` +/// +/// Thus, the sorted `foo.commutative` is: +/// %5 = foo.commutative %4, %3, %2, %1 +class SortCommutativeOperands : public RewritePattern { +public: + SortCommutativeOperands(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Custom comparator for two commutative operands, which returns true iff + // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`, + // i.e., + // 1. In the first unequal pair of corresponding AncestorKeys, the + // AncestorKey in `constCommOperandA` is smaller, or, + // 2. Both the AncestorKeys in every pair are the same and the size of + // `constCommOperandA`'s "key" is smaller. + auto commutativeOperandComparator = + [](const std::unique_ptr &constCommOperandA, + const std::unique_ptr &constCommOperandB) { + if (constCommOperandA->operand == constCommOperandB->operand) + return false; + + auto &commOperandA = + const_cast &>( + constCommOperandA); + auto &commOperandB = + const_cast &>( + constCommOperandB); + + // Iteratively perform the BFS's of both operands until an order among + // them can be determined. + unsigned keyIndex = 0; + while (true) { + if (commOperandA->key.size() <= keyIndex) { + if (commOperandA->ancestorQueue.empty()) + return true; + commOperandA->popFrontAndPushAdjacentUnvisitedAncestors(); + commOperandA->refreshKey(); + } + if (commOperandB->key.size() <= keyIndex) { + if (commOperandB->ancestorQueue.empty()) + return false; + commOperandB->popFrontAndPushAdjacentUnvisitedAncestors(); + commOperandB->refreshKey(); + } + if (commOperandA->ancestorQueue.empty() || + commOperandB->ancestorQueue.empty()) + return commOperandA->key.size() < commOperandB->key.size(); + if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex]) + return true; + if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex]) + return false; + keyIndex++; + } + }; + + // If `op` is not commutative, do nothing. + if (!op->hasTrait()) + return failure(); + + // Populate the list of commutative operands. + SmallVector operands = op->getOperands(); + SmallVector, 2> commOperands; + for (Value operand : operands) { + std::unique_ptr commOperand = + std::make_unique(); + commOperand->operand = operand; + commOperand->pushAncestor(operand.getDefiningOp()); + commOperand->refreshKey(); + commOperands.push_back(std::move(commOperand)); + } + + // Sort the operands. + std::stable_sort(commOperands.begin(), commOperands.end(), + commutativeOperandComparator); + SmallVector sortedOperands; + for (const std::unique_ptr &commOperand : commOperands) + sortedOperands.push_back(commOperand->operand); + if (sortedOperands == operands) + return failure(); + rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); + return success(); + } +}; + +void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func.func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func.func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[ARITH_CONST]], %[[TEST_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1186,11 +1186,21 @@ let hasFolder = 1; } +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2); let results = (outs I32); diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,48 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the functionality of the commutativity utility pattern. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +namespace { + +struct CommutativityUtils + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + populateCommutativityUtilsPatterns(patterns); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerCommutativityUtils() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -57,6 +57,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -149,6 +150,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck();