diff --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h
--- a/mlir/include/mlir/Transforms/CommutativityUtils.h
+++ b/mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -17,8 +17,124 @@
 
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <queue>
+
 namespace mlir {
 
+/// The possible "types" of ancestors. Here, an ancestor is an op or a block
+/// argument present in the backward slice of a value.
+enum AncestorType {
+  /// Pertains to a block argument.
+  BLOCK_ARGUMENT,
+
+  /// Pertains to a non-constant-like op.
+  NON_CONSTANT_OP,
+
+  /// Pertains to a constant-like op.
+  CONSTANT_OP
+};
+
+/// Stores the "key" associated with an ancestor.
+struct AncestorKey {
+  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
+  /// the ancestor.
+  AncestorType type;
+
+  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
+  /// `CONSTANT_OP`. Else, holds "".
+  StringRef opName;
+
+  /// Builds the key of ancestor `op`; a null `op` denotes a block argument.
+  AncestorKey(Operation *op);
+
+  /// Overloaded operator `<` for `AncestorKey`.
+  ///
+  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
+  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
+  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
+  /// ones are the ones with smaller op names (lexicographically).
+  ///
+  /// TODO: Include other information like attributes, value type, etc., to
+  /// enhance this comparison. For example, currently this comparison doesn't
+  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
+  /// `addi (in i64)`. Such an enhancement should only be done if the need
+  /// arises.
+  bool operator<(const AncestorKey &key) const;
+};
+
+/// Stores a commutative operand along with its BFS traversal information.
+struct CommutativeOperand {
+  /// Stores the operand.
+  Value operand;
+
+  /// Stores the queue of ancestors of the operand's BFS traversal at a
+  /// particular point in time.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Stores the list of ancestors that have been visited by the BFS traversal
+  /// at a particular point in time.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the operand's "key". This "key" is defined as a list of the
+  /// "AncestorKeys" associated with the ancestors of this operand, in a
+  /// breadth-first order.
+  ///
+  /// So, if an operand, say `A`, was produced as follows:
+  ///
+  ///                  `<block argument>`  `<block argument>`
+  ///                             \          /
+  ///                              \        /
+  ///                             `arith.subi`           `arith.constant`
+  ///                                       \            /
+  ///                                        `arith.addi`
+  ///                                              |
+  ///                                         returns `A`
+  ///
+  /// Then, the ancestors of `A`, in the breadth-first order are:
+  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
+  /// `<block argument>`.
+  ///
+  /// Thus, the "key" associated with operand `A` is:
+  /// {
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
+  ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
+  /// }
+  SmallVector<AncestorKey, 4> key;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" list (iff it is an op rather than a block argument).
+  void pushAncestor(Operation *op);
+
+  /// Refresh the key.
+  ///
+  /// Refreshing a key entails making it up-to-date with the operand's BFS
+  /// traversal that has happened till that point in time, i.e, appending the
+  /// existing key with the front ancestor's "AncestorKey". Note that a key
+  /// directly reflects the BFS and thus needs to be refreshed during the
+  /// progression of the traversal.
+  void refreshKey();
+
+  /// Pop the front ancestor, if any, from the queue and then push its adjacent
+  /// unvisited ancestors, if any, to the queue (this is the main body of the
+  /// BFS algorithm).
+  void popFrontAndPushAdjacentUnvisitedAncestors();
+
+  /// Comparator for two commutative operands: returns true iff the "key" of
+  /// `constCommOperandA` is smaller than the "key" of `constCommOperandB`,
+  /// i.e. either the first unequal pair of corresponding AncestorKeys has the
+  /// smaller one in `constCommOperandA`, or every pair is equal and
+  /// `constCommOperandA`'s "key" is shorter. Keys are generated lazily, so
+  /// this may advance the BFS state of both operands even though they are
+  /// passed by const reference.
+  static bool commutativeOperandComparator(
+      const std::unique_ptr<CommutativeOperand> &constCommOperandA,
+      const std::unique_ptr<CommutativeOperand> &constCommOperandB);
+};
+
 /// Populates the commutativity utility patterns.
 void populateCommutativityUtilsPatterns(RewritePatternSet &patterns);
 
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
--- a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -15,145 +15,96 @@
 
 #include "mlir/Transforms/CommutativityUtils.h"
 
-#include <queue>
-
 using namespace mlir;
 
-/// The possible "types" of ancestors. Here, an ancestor is an op or a block
-/// argument present in the backward slice of a value.
-enum AncestorType {
-  /// Pertains to a block argument.
-  BLOCK_ARGUMENT,
-
-  /// Pertains to a non-constant-like op.
-  NON_CONSTANT_OP,
-
-  /// Pertains to a constant-like op.
-  CONSTANT_OP
-};
-
-/// Stores the "key" associated with an ancestor.
-struct AncestorKey {
-  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
-  /// the ancestor.
-  AncestorType type;
-
-  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP` or
-  /// `CONSTANT_OP`. Else, holds "".
-  StringRef opName;
-
-  /// Constructor for `AncestorKey`.
-  AncestorKey(Operation *op) {
-    if (!op) {
-      type = BLOCK_ARGUMENT;
-    } else {
-      type =
-          op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
-      opName = op->getName().getStringRef();
-    }
+AncestorKey::AncestorKey(Operation *op) {
+  if (!op) {
+    type = BLOCK_ARGUMENT;
+  } else {
+    type =
+        op->hasTrait<OpTrait::ConstantLike>() ? CONSTANT_OP : NON_CONSTANT_OP;
+    opName = op->getName().getStringRef();
   }
+}
 
-  /// Overloaded operator `<` for `AncestorKey`.
-  ///
-  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest, those
-  /// of type `CONSTANT_OP`, the largest, and `NON_CONSTANT_OP` types come in
-  /// between. Within the types `NON_CONSTANT_OP` and `CONSTANT_OP`, the smaller
-  /// ones are the ones with smaller op names (lexicographically).
-  ///
-  /// TODO: Include other information like attributes, value type, etc., to
-  /// enhance this comparison. For example, currently this comparison doesn't
-  /// differentiate between `cmpi sle` and `cmpi sgt` or `addi (in i32)` and
-  /// `addi (in i64)`. Such an enhancement should only be done if the need
-  /// arises.
-  bool operator<(const AncestorKey &key) const {
-    return std::tie(type, opName) < std::tie(key.type, key.opName);
-  }
-};
-
-/// Stores a commutative operand along with its BFS traversal information.
-struct CommutativeOperand {
-  /// Stores the operand.
-  Value operand;
-
-  /// Stores the queue of ancestors of the operand's BFS traversal at a
-  /// particular point in time.
-  std::queue<Operation *> ancestorQueue;
-
-  /// Stores the list of ancestors that have been visited by the BFS traversal
-  /// at a particular point in time.
-  DenseSet<Operation *> visitedAncestors;
+bool AncestorKey::operator<(const AncestorKey &key) const {
+  return std::tie(type, opName) < std::tie(key.type, key.opName);
+}
 
-  /// Stores the operand's "key". This "key" is defined as a list of the
-  /// "AncestorKeys" associated with the ancestors of this operand, in a
-  /// breadth-first order.
-  ///
-  /// So, if an operand, say `A`, was produced as follows:
-  ///
-  ///                  `<block argument>`  `<block argument>`
-  ///                             \          /
-  ///                              \        /
-  ///                             `arith.subi`           `arith.constant`
-  ///                                       \            /
-  ///                                        `arith.addi`
-  ///                                              |
-  ///                                         returns `A`
-  ///
-  /// Then, the ancestors of `A`, in the breadth-first order are:
-  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
-  /// `<block argument>`.
-  ///
-  /// Thus, the "key" associated with operand `A` is:
-  /// {
-  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
-  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
-  ///  {type: `CONSTANT_OP`, opName: "arith.constant"},
-  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
-  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
-  /// }
-  SmallVector<AncestorKey, 4> key;
+void CommutativeOperand::pushAncestor(Operation *op) {
+  ancestorQueue.push(op);
+  if (op)
+    visitedAncestors.insert(op);
+}
 
-  /// Push an ancestor into the operand's BFS information structure. This
-  /// entails it being pushed into the queue (always) and inserted into the
-  /// "visited ancestors" list (iff it is an op rather than a block argument).
-  void pushAncestor(Operation *op) {
-    ancestorQueue.push(op);
-    if (op)
-      visitedAncestors.insert(op);
-  }
+void CommutativeOperand::refreshKey() {
+  if (ancestorQueue.empty())
+    return;
 
-  /// Refresh the key.
-  ///
-  /// Refreshing a key entails making it up-to-date with the operand's BFS
-  /// traversal that has happened till that point in time, i.e, appending the
-  /// existing key with the front ancestor's "AncestorKey". Note that a key
-  /// directly reflects the BFS and thus needs to be refreshed during the
-  /// progression of the traversal.
-  void refreshKey() {
-    if (ancestorQueue.empty())
-      return;
+  Operation *frontAncestor = ancestorQueue.front();
+  AncestorKey frontAncestorKey(frontAncestor);
+  key.push_back(frontAncestorKey);
+}
 
-    Operation *frontAncestor = ancestorQueue.front();
-    AncestorKey frontAncestorKey(frontAncestor);
-    key.push_back(frontAncestorKey);
+void CommutativeOperand::popFrontAndPushAdjacentUnvisitedAncestors() {
+  if (ancestorQueue.empty())
+    return;
+  Operation *frontAncestor = ancestorQueue.front();
+  ancestorQueue.pop();
+  if (!frontAncestor)
+    return;
+  for (Value operand : frontAncestor->getOperands()) {
+    Operation *operandDefOp = operand.getDefiningOp();
+    if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
+      pushAncestor(operandDefOp);
   }
+}
 
-  /// Pop the front ancestor, if any, from the queue and then push its adjacent
-  /// unvisited ancestors, if any, to the queue (this is the main body of the
-  /// BFS algorithm).
-  void popFrontAndPushAdjacentUnvisitedAncestors() {
-    if (ancestorQueue.empty())
-      return;
-    Operation *frontAncestor = ancestorQueue.front();
-    ancestorQueue.pop();
-    if (!frontAncestor)
-      return;
-    for (Value operand : frontAncestor->getOperands()) {
-      Operation *operandDefOp = operand.getDefiningOp();
-      if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
-        pushAncestor(operandDefOp);
+bool CommutativeOperand::commutativeOperandComparator(
+    const std::unique_ptr<CommutativeOperand> &constCommOperandA,
+    const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
+  if (constCommOperandA->operand == constCommOperandB->operand)
+    return false;
+
+  // Keys are built lazily, so both BFS states may be advanced below. A const
+  // `std::unique_ptr` still grants mutable access to its pointee.
+  CommutativeOperand &commOperandA = *constCommOperandA;
+  CommutativeOperand &commOperandB = *constCommOperandB;
+
+  // Iteratively perform the BFS's of both operands until an order among
+  // them can be determined.
+  unsigned keyIndex = 0;
+  while (true) {
+    // Lazily grow a key that is too short to be compared at `keyIndex`, one
+    // BFS step at a time, until it is long enough or its BFS is exhausted.
+    while (commOperandA.key.size() <= keyIndex &&
+           !commOperandA.ancestorQueue.empty()) {
+      commOperandA.popFrontAndPushAdjacentUnvisitedAncestors();
+      commOperandA.refreshKey();
     }
+    while (commOperandB.key.size() <= keyIndex &&
+           !commOperandB.ancestorQueue.empty()) {
+      commOperandB.popFrontAndPushAdjacentUnvisitedAncestors();
+      commOperandB.refreshKey();
+    }
+    // A key still too short for `keyIndex` is complete, and all earlier
+    // AncestorKey pairs were equal. A key that ended first is therefore a
+    // proper prefix of the other and orders first, while keys that ended
+    // together are equal and must compare false in both directions: that is
+    // the strict weak ordering `std::stable_sort` relies on.
+    bool exhaustedA = commOperandA.key.size() <= keyIndex;
+    bool exhaustedB = commOperandB.key.size() <= keyIndex;
+    if (exhaustedA || exhaustedB)
+      return exhaustedA && !exhaustedB;
+    // Both keys have an entry at `keyIndex`; the first unequal pair decides.
+    if (commOperandA.key[keyIndex] < commOperandB.key[keyIndex])
+      return true;
+    if (commOperandB.key[keyIndex] < commOperandA.key[keyIndex])
+      return false;
+    // Equal so far: compare the next ancestors in breadth-first order.
+    keyIndex++;
   }
-};
+}
 
 /// Sorts the operands of `op` in ascending order of the "key" associated with
 /// each operand iff `op` is commutative. This is a stable sort.
@@ -206,6 +157,7 @@
 /// 2. The key associated with %2 is:
 /// `{
 ///   {NON_CONSTANT_OP, "foo.mul"},
+///   {BLOCK_ARGUMENT, ""},
 ///   {BLOCK_ARGUMENT, ""}
 /// }`
 /// 3. The key associated with %3 is:
@@ -226,64 +178,13 @@
 /// }`
 ///
 /// Thus, the sorted `foo.commutative` is:
-/// %5 = foo.commutative %4, %3, %2, %1
-class SortCommutativeOperands : public RewritePattern {
-public:
+/// %5 = foo.commutative %4, %2, %3, %1
+struct SortCommutativeOperands final
+    : public OpTraitRewritePattern<OpTrait::IsCommutative> {
   SortCommutativeOperands(MLIRContext *context)
-      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
+      : OpTraitRewritePattern(context, /*benefit=*/5) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    // Custom comparator for two commutative operands, which returns true iff
-    // the "key" of `constCommOperandA` < the "key" of `constCommOperandB`,
-    // i.e.,
-    // 1. In the first unequal pair of corresponding AncestorKeys, the
-    // AncestorKey in `constCommOperandA` is smaller, or,
-    // 2. Both the AncestorKeys in every pair are the same and the size of
-    // `constCommOperandA`'s "key" is smaller.
-    auto commutativeOperandComparator =
-        [](const std::unique_ptr<CommutativeOperand> &constCommOperandA,
-           const std::unique_ptr<CommutativeOperand> &constCommOperandB) {
-          if (constCommOperandA->operand == constCommOperandB->operand)
-            return false;
-
-          auto &commOperandA =
-              const_cast<std::unique_ptr<CommutativeOperand> &>(
-                  constCommOperandA);
-          auto &commOperandB =
-              const_cast<std::unique_ptr<CommutativeOperand> &>(
-                  constCommOperandB);
-
-          // Iteratively perform the BFS's of both operands until an order among
-          // them can be determined.
-          unsigned keyIndex = 0;
-          while (true) {
-            if (commOperandA->key.size() <= keyIndex) {
-              if (commOperandA->ancestorQueue.empty())
-                return true;
-              commOperandA->popFrontAndPushAdjacentUnvisitedAncestors();
-              commOperandA->refreshKey();
-            }
-            if (commOperandB->key.size() <= keyIndex) {
-              if (commOperandB->ancestorQueue.empty())
-                return false;
-              commOperandB->popFrontAndPushAdjacentUnvisitedAncestors();
-              commOperandB->refreshKey();
-            }
-            if (commOperandA->ancestorQueue.empty() ||
-                commOperandB->ancestorQueue.empty())
-              return commOperandA->key.size() < commOperandB->key.size();
-            if (commOperandA->key[keyIndex] < commOperandB->key[keyIndex])
-              return true;
-            if (commOperandB->key[keyIndex] < commOperandA->key[keyIndex])
-              return false;
-            keyIndex++;
-          }
-        };
-
-    // If `op` is not commutative, do nothing.
-    if (!op->hasTrait<OpTrait::IsCommutative>())
-      return failure();
-
     // Populate the list of commutative operands.
     SmallVector<Value, 2> operands = op->getOperands();
     SmallVector<std::unique_ptr<CommutativeOperand>, 2> commOperands;
@@ -298,7 +199,7 @@
 
     // Sort the operands.
     std::stable_sort(commOperands.begin(), commOperands.end(),
-                     commutativeOperandComparator);
+                     CommutativeOperand::commutativeOperandComparator);
     SmallVector<Value, 2> sortedOperands;
     for (const std::unique_ptr<CommutativeOperand> &commOperand : commOperands)
       sortedOperands.push_back(commOperand->operand);
diff --git a/mlir/unittests/Transforms/CMakeLists.txt b/mlir/unittests/Transforms/CMakeLists.txt
--- a/mlir/unittests/Transforms/CMakeLists.txt
+++ b/mlir/unittests/Transforms/CMakeLists.txt
@@ -6,3 +6,5 @@
   PRIVATE
   MLIRParser
   MLIRTransforms)
+
+add_subdirectory(Utils)
diff --git a/mlir/unittests/Transforms/Utils/CMakeLists.txt b/mlir/unittests/Transforms/Utils/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/Utils/CMakeLists.txt
@@ -0,0 +1,6 @@
+add_mlir_unittest(MLIRTransformUtilsTests
+  CommutativityUtils.cpp
+)
+target_link_libraries(MLIRTransformUtilsTests
+  PRIVATE
+  MLIRTransformUtils)
diff --git a/mlir/unittests/Transforms/Utils/CommutativityUtils.cpp b/mlir/unittests/Transforms/Utils/CommutativityUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,119 @@
+//===- CommutativityUtils.cpp - Commutative Operand Comparison unit tests -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "llvm/Support/Casting.h"
+#include "gtest/gtest.h"
+#include <memory>
+#include <optional>
+
+using namespace mlir;
+
+/// Creates a detached op named `name`, with result types `types` and no
+/// operands, attributes, successors or regions. Unregistered dialects are
+/// allowed so that arbitrary names can be used. The caller must `destroy()`
+/// the returned op.
+static Operation *createOp(MLIRContext *context, StringRef name,
+                           TypeRange types = std::nullopt) {
+  context->allowUnregisteredDialects();
+  return Operation::create(
+      UnknownLoc::get(context), OperationName(name, context), types,
+      std::nullopt, std::nullopt, /*properties=*/nullptr, std::nullopt, 0);
+}
+
+/// Builds a `CommutativeOperand` for `operand` with a hand-made BFS state:
+/// one AncestorKey per entry of `keyAncestors`, and `queuedAncestors` in the
+/// ancestor queue. A nullptr ancestor stands for a block argument.
+static std::unique_ptr<CommutativeOperand>
+createCommOperand(Value operand, ArrayRef<Operation *> keyAncestors,
+                  ArrayRef<Operation *> queuedAncestors = {}) {
+  auto commOperand = std::make_unique<CommutativeOperand>();
+  commOperand->operand = operand;
+  for (Operation *ancestor : keyAncestors)
+    commOperand->key.push_back(AncestorKey(ancestor));
+  for (Operation *ancestor : queuedAncestors)
+    commOperand->ancestorQueue.push(ancestor);
+  return commOperand;
+}
+
+// Compares the operands %0 and %1 of
+//   %0 = dummy1(<block argument>, <block argument>)
+//   %1 = dummy2(<block argument>, <block argument>)
+//   %2 = <commutative op> %0, %1
+// using hand-built BFS states (the ops themselves have no operands): %0's
+// key is complete and its ancestorQueue is empty, while %1's key only has
+// its first element and its ancestorQueue still holds 2 block arguments.
+//
+// %0 is the smaller operand because "dummy1" is lexicographically smaller
+// than "dummy2", which is decided at the very first key entry.
+TEST(CommutativityUtilsTest, TestOneEmptyAncestorQueue) {
+  MLIRContext context;
+  Builder builder(&context);
+  Operation *dummy1 = createOp(&context, "dummy1", builder.getIntegerType(32));
+  Operation *dummy2 = createOp(&context, "dummy2", builder.getIntegerType(32));
+
+  std::unique_ptr<CommutativeOperand> commOperandA = createCommOperand(
+      dummy1->getOpResult(0), {dummy1, nullptr, nullptr});
+  std::unique_ptr<CommutativeOperand> commOperandB =
+      createCommOperand(dummy2->getOpResult(0), {dummy2}, {nullptr, nullptr});
+
+  EXPECT_TRUE(CommutativeOperand::commutativeOperandComparator(commOperandA,
+                                                               commOperandB));
+  EXPECT_FALSE(CommutativeOperand::commutativeOperandComparator(commOperandB,
+                                                                commOperandA));
+
+  dummy1->destroy();
+  dummy2->destroy();
+}
+
+// Two distinct operands whose complete keys are identical are equivalent, so
+// the comparator must return false in both directions; otherwise it is not a
+// strict weak ordering, which `std::stable_sort` requires.
+TEST(CommutativityUtilsTest, TestEqualCompleteKeys) {
+  MLIRContext context;
+  Builder builder(&context);
+  Operation *dummyA = createOp(&context, "dummy", builder.getIntegerType(32));
+  Operation *dummyB = createOp(&context, "dummy", builder.getIntegerType(32));
+
+  std::unique_ptr<CommutativeOperand> commOperandA =
+      createCommOperand(dummyA->getOpResult(0), {dummyA, nullptr});
+  std::unique_ptr<CommutativeOperand> commOperandB =
+      createCommOperand(dummyB->getOpResult(0), {dummyB, nullptr});
+
+  EXPECT_FALSE(CommutativeOperand::commutativeOperandComparator(commOperandA,
+                                                                commOperandB));
+  EXPECT_FALSE(CommutativeOperand::commutativeOperandComparator(commOperandB,
+                                                                commOperandA));
+
+  dummyA->destroy();
+  dummyB->destroy();
+}
+
+// When every AncestorKey pair matches, the operand whose complete key is a
+// proper prefix of the other's key is the smaller one.
+TEST(CommutativityUtilsTest, TestPrefixKeyIsSmaller) {
+  MLIRContext context;
+  Builder builder(&context);
+  Operation *dummyA = createOp(&context, "dummy", builder.getIntegerType(32));
+  Operation *dummyB = createOp(&context, "dummy", builder.getIntegerType(32));
+
+  std::unique_ptr<CommutativeOperand> commOperandA =
+      createCommOperand(dummyA->getOpResult(0), {dummyA});
+  std::unique_ptr<CommutativeOperand> commOperandB =
+      createCommOperand(dummyB->getOpResult(0), {dummyB, nullptr});
+
+  EXPECT_TRUE(CommutativeOperand::commutativeOperandComparator(commOperandA,
+                                                               commOperandB));
+  EXPECT_FALSE(CommutativeOperand::commutativeOperandComparator(commOperandB,
+                                                                commOperandA));
+
+  dummyA->destroy();
+  dummyB->destroy();
+}