diff --git a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt --- a/clang/docs/tools/clang-formatted-files.txt +++ b/clang/docs/tools/clang-formatted-files.txt @@ -7887,6 +7887,7 @@ mlir/include/mlir/Tools/PDLL/ODS/Dialect.h mlir/include/mlir/Tools/PDLL/ODS/Operation.h mlir/include/mlir/Tools/PDLL/Parser/Parser.h +mlir/include/mlir/Transforms/CommutativityUtils.h mlir/include/mlir/Transforms/ControlFlowSinkUtils.h mlir/include/mlir/Transforms/DialectConversion.h mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -8447,6 +8448,7 @@ mlir/lib/Transforms/StripDebugInfo.cpp mlir/lib/Transforms/SymbolDCE.cpp mlir/lib/Transforms/SymbolPrivatize.cpp +mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/lib/Transforms/Utils/FoldUtils.cpp diff --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,334 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares some helper functions and a commutativity utility +// (a templated op rewrite pattern). The latter is intended to be used inside +// passes to simplify the matching of commutative operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +#include "mlir/Transforms/DialectConversion.h" +#include + +namespace mlir { + +/// Enumerates the kinds of value producers: a block argument or an operation. +enum BlockArgumentOrOpType { + /// Pertains to a block argument. + BLOCK_ARGUMENT, + /// Pertains to a non-constant-like operation. + NON_CONSTANT_OP, + /// Pertains to a constant-like operation. + CONSTANT_OP +}; + +/// Stores the "key" associated with a block argument or an operation. +struct BlockArgumentOrOpKey { + /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on + /// the block argument/operation. + BlockArgumentOrOpType type; + /// Holds the full op name iff the `type` is `NON_CONSTANT_OP` (else unset). + StringRef opName; + + /// Defines a strict ordering on keys via the overloaded operator `<`. + /// `BlockArgumentOrOpKey1` is considered < `BlockArgumentOrOpKey2` iff: + /// 1. The `type` of `BlockArgumentOrOpKey1` is `BLOCK_ARGUMENT` and that of + /// `BlockArgumentOrOpKey2` isn't, + /// 2. The `type` of `BlockArgumentOrOpKey1` is `NON_CONSTANT_OP` and that + /// of `BlockArgumentOrOpKey2` is `CONSTANT_OP`, or + /// 3. Both have the same `type` and the `opName` of `BlockArgumentOrOpKey1` + /// is lexicographically smaller than that of `BlockArgumentOrOpKey2`. + bool operator<(const BlockArgumentOrOpKey &key) const { + if ((type == BLOCK_ARGUMENT && key.type != BLOCK_ARGUMENT) || + (type == NON_CONSTANT_OP && key.type == CONSTANT_OP)) + return true; + if ((key.type == BLOCK_ARGUMENT && type != BLOCK_ARGUMENT) || + (key.type == NON_CONSTANT_OP && type == CONSTANT_OP)) + return false; + return opName < key.opName; + } +}; + +/// Stores the BFS traversal information of an operand. +struct OperandBFS { + /// Stores the queue of ancestors of the BFS traversal of an operand at a + /// particular point in time. 
+ std::queue ancestorQueue; + + /// Stores the set of visited ancestors of the BFS traversal of an operand at + /// a particular point in time. + DenseSet visitedAncestors; + + /// Stores the "key" associated with an operand. This "key" is defined as the + /// list of the "BlockArgumentOrOpKeys" associated with the block arguments + /// and operations present in the "backward slice" of this operand, in a + /// breadth-first order. + /// + /// So, if an operand, say `A`, was produced as follows: + /// + /// `block argument` `block argument` + /// \ / + /// \ / + /// `arith.subi` `arith.constant` + /// \ / + /// `arith.addi` + /// | + /// returns `A` + /// + /// Then, the block arguments and operations present in the backward slice of + /// `A`, in the breadth-first order are: + /// `arith.addi`, `arith.subi`, `arith.constant`, `block argument`, and + /// `block argument`. + /// + /// Now, the "BlockArgumentOrOpKey" associated with: + /// 1. A block argument is {type: `BLOCK_ARGUMENT`, opName: null}. + /// 2. A non-constant-like op, for example, `arith.addi`, is {type: + /// `NON_CONSTANT_OP`, opName: "arith.addi"}. + /// 3. A constant-like op, for example, `arith.constant`, is {type: + /// `CONSTANT_OP`, opName: null}. + /// + /// Thus, the "key" associated with operand `A` is: + /// {{type: `NON_CONSTANT_OP`, opName: "arith.addi"}, + /// {type: `NON_CONSTANT_OP`, opName: "arith.subi"}, + /// {type: `CONSTANT_OP`, opName: null}, + /// {type: `BLOCK_ARGUMENT`, opName: null}, + /// {type: `BLOCK_ARGUMENT`, opName: null}}. + SmallVector key; + + /// Stores true iff the operand has already been assigned a sorted position. + bool isAssignedSortedPosition = false; + + /// Push an ancestor into the operand's BFS information structure. This + /// entails it being pushed into the queue (always) and inserted into the + /// "visited ancestors" set (iff it is not null, i.e., corresponds to an op + /// rather than a block argument). 
+ void pushAncestor(Operation *ancestor) { + ancestorQueue.push(ancestor); + if (ancestor) + visitedAncestors.insert(ancestor); + return; + } + + /// Pop the ancestor from the front of the queue. + void popAncestor() { + assert(!ancestorQueue.empty() && + "to pop the ancestor from the front of the queue, the ancestor " + "queue should be non-empty"); + ancestorQueue.pop(); + return; + } + + /// Return the front ancestor of the queue (null for a block argument). + Operation *frontAncestor() { + assert(!ancestorQueue.empty() && + "to access the ancestor at the front of the queue, the ancestor " + "queue should be non-empty"); + return ancestorQueue.front(); + } +}; + +/// Returns true iff at least one unassigned operand exists. An unassigned +/// operand refers to one which has not been assigned a sorted position yet. +bool hasAtLeastOneUnassignedOperand(SmallVector bfsOfOperands); + +/// Returns: +/// -1 if `firstKey` < `secondKey`, +/// 0 if `firstKey` == `secondKey`, and +/// 1 if `firstKey` > `secondKey`. +/// +/// Note that: +/// +/// (A) `firstKey` == `secondKey` iff: +/// Both these keys, each of which is a list, have the same size and the two +/// elements in each pair of corresponding elements among them are the +/// same. +/// +/// (B) `firstKey` < `secondKey` iff: +/// 1. In the first unequal pair of corresponding elements among them, +/// `firstKey`'s element is smaller, or +/// 2. Both the elements in every pair of corresponding elements are the same +/// in both keys and the size of `firstKey` is smaller. +/// +/// (C) The `secondKey` < `firstKey` condition is defined likewise. +int compareKeys(SmallVector firstKey, + SmallVector secondKey); + +/// Goes through all the unassigned operands of `bfsOfOperands` and: +/// 1. Stores the indices of the ones with the smallest key in +/// `smallestKeyIndices`, +/// 2. Stores the indices of the ones with the largest key in +/// `largestKeyIndices`, +/// 3. 
Sets `hasASingleOperandWithSmallestKey` to true if exactly one of them +/// has the smallest key (and to false otherwise), AND, +/// 4. Sets `hasASingleOperandWithLargestKey` to true if exactly one of them has +/// the largest key (and to false otherwise). +void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + SmallVector bfsOfOperands, + DenseSet &smallestKeyIndices, + DenseSet &largestKeyIndices, + bool &hasASingleOperandWithSmallestKey, + bool &hasASingleOperandWithLargestKey); + +/// Update the key associated with each unassigned operand in `bfsOfOperands`. +/// Updating a key entails making it up-to-date with its associated operand's +/// BFS traversal that has happened till that point in time, i.e., extending the +/// existing key with the current front ancestor's "key". Note that a key +/// directly reflects the BFS and thus needs to be updated during the +/// progression of the traversal. +void updateKeys(SmallVector bfsOfOperands); + +/// If `keyIndices` contains `indexOfOperand` and either `isTheOnlyKey` is true +/// or the ancestor queue of `bfsOfOperand` is empty, assign the sorted position +/// `positionToAssign` to the operand of `op` at index `indexOfOperand`, and +/// return true. Else, return false. +bool assignSortedPositionTo(OperandBFS *bfsOfOperand, unsigned indexOfOperand, + DenseSet keyIndices, bool isTheOnlyKey, + SmallVector &sortedOperands, + unsigned positionToAssign, Operation *op); + +/// For each unassigned operand in `bfsOfOperands`, pop the front ancestor from +/// its queue, if any, and then push its adjacent unvisited ancestors, if any, +/// to the queue (this is the main body of the BFS algorithm). +void popFrontAndPushAdjacentUnvisitedAncestors( + SmallVector bfsOfOperands); + +/// Sorts the operands of `op` in ascending order of the "key" associated with +/// each operand iff `op` is commutative. 
+/// +/// After the application of this pattern, since the commutative operands now +/// have a deterministic order in which they occur in an operation, the matching +/// of DAGs becomes much simpler, i.e., requires far fewer checks to be written +/// by users in their pattern-matching functions. +/// +/// The pattern reports failure, leaving `op` untouched, when `op` is not +/// commutative or when its operands are already in sorted order. +template +class SortCommutativeOperands : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // If `op` is not commutative, do nothing. + if (!op->template hasTrait()) + return failure(); + + // `bfsOfOperands` stores the BFS traversal information of each operand of + // `op`. For each operand, this information comprises a queue of ancestors + // being visited during the BFS (at a particular point in time), a set of + // visited ancestors (at a particular point in time), its associated key (at + // a particular point in time), and whether or not the operand has been + // assigned a sorted position yet. + SmallVector bfsOfOperands; + + // Initially, each operand's ancestor queue contains the op defining it + // (which is considered its first ancestor). Thus, it acts as the starting + // point for that operand's BFS traversal. The `OperandBFS` objects are + // heap-allocated here and released once the sorted order is computed. + for (Value operand : op->getOperands()) { + OperandBFS *bfsOfOperand = new OperandBFS(); + bfsOfOperand->pushAncestor(operand.getDefiningOp()); + bfsOfOperands.push_back(bfsOfOperand); + } + + // Since none of the operands have been assigned a sorted position yet, the + // smallest unassigned position is set as zero and the largest one is set as + // the number of operands in `op` minus one (N - 1). This is because each + // operand will be assigned a sorted position between 0 and (N - 1), both + // inclusive. 
+ unsigned numOperands = op->getNumOperands(); + unsigned smallestUnassignedPosition = 0; + unsigned largestUnassignedPosition = numOperands - 1; + + // `sortedOperands` will store the list of `op`'s operands in sorted order. + // At first, all elements in it are initialized as null. + SmallVector sortedOperands; + for (unsigned i = 0; i < numOperands; i++) + sortedOperands.push_back(nullptr); + + // We perform the BFS traversals of all operands in parallel until each of + // them is assigned a sorted position. During the traversals, we try to + // assign a sorted position to an operand as soon as it is possible (based + // on a comparison of its traversal with the other traversals at that + // particular point in time). + while (hasAtLeastOneUnassignedOperand(bfsOfOperands)) { + // Update the keys corresponding to all unassigned operands. + updateKeys(bfsOfOperands); + + // Stores the indices of the unassigned operands whose key is the + // smallest. + DenseSet smallestKeyIndices; + // Stores the indices of the unassigned operands whose key is the largest. + DenseSet largestKeyIndices; + + // Stores true iff there is a single unassigned operand that has the + // smallest key. + bool hasASingleOperandWithSmallestKey = false; + // Stores true iff there is a single unassigned operand that has the + // largest key. + bool hasASingleOperandWithLargestKey = false; + + getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + bfsOfOperands, smallestKeyIndices, largestKeyIndices, + hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey); + + // Go through each of the unassigned operands with the smallest key and + // try to assign it a sorted position if possible (ensuring stable + // sorting). + for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + // If an unassigned operand has the smallest key and: + // 1. 
It is the only operand with the smallest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `smallestUnassignedPosition` (which will + // be its new position in the rearranged `op`). + if (assignSortedPositionTo( + bfsOfOperand, /*indexOfOperand=*/indexedBfsOfOperand.index(), + /*keyIndices=*/smallestKeyIndices, + /*isTheOnlyKey=*/hasASingleOperandWithSmallestKey, + /*sortedOperands=*/sortedOperands, + /*positionToAssign=*/smallestUnassignedPosition, /*op=*/op)) + smallestUnassignedPosition++; + } + // Go through each of the unassigned operands with the largest key and try + // to assign it a sorted position if possible (ensuring stable sorting). + for (auto indexedBfsOfOperand : + llvm::enumerate(llvm::reverse(bfsOfOperands))) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + // If an unassigned operand has the largest key and: + // 1. It is the only operand with the largest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `largestUnassignedPosition` (which will + // be its new position in the rearranged `op`). Since the operands are + // enumerated in reverse here, the enumeration index is mapped back to + // the operand's original index. + if (assignSortedPositionTo( + bfsOfOperand, /*indexOfOperand=*/numOperands - + indexedBfsOfOperand.index() - 1, + /*keyIndices=*/largestKeyIndices, + /*isTheOnlyKey=*/hasASingleOperandWithLargestKey, + /*sortedOperands=*/sortedOperands, + /*positionToAssign=*/largestUnassignedPosition, /*op=*/op)) + largestUnassignedPosition--; + } + + // For each operand in `bfsOfOperands`, pop the front ancestor from the + // queue and push its adjacent unvisited ancestors into the queue. 
+ popFrontAndPushAdjacentUnvisitedAncestors(bfsOfOperands); + } + + // The traversal state was heap-allocated above and is not needed past this + // point; release it so that applying the pattern does not leak memory. + for (OperandBFS *bfsOfOperand : bfsOfOperands) + delete bfsOfOperand; + + // If the operands are already in sorted order, there is nothing to rewrite. + // Report failure so that the rewrite driver is not told that the IR changed + // when it did not. + if (llvm::equal(op->getOperands(), sortedOperands)) + return failure(); + + rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); + return success(); + } +}; + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,231 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements some helper functions that are used by a commutativity +// utility (a templated op rewrite pattern). The latter is intended to be used +// inside passes to simplify the matching of commutative operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#define DEBUG_TYPE "commutativity-utils" + +using namespace mlir; + +/// Returns true iff at least one unassigned operand exists. An unassigned +/// operand refers to one which has not been assigned a sorted position yet. 
+bool mlir::hasAtLeastOneUnassignedOperand( + SmallVector bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (!bfsOfOperand->isAssignedSortedPosition) + return true; + } + return false; +} + +/// Returns: +/// -1 if `firstKey` < `secondKey`, +/// 0 if `firstKey` == `secondKey`, and +/// 1 if `firstKey` > `secondKey`. +/// +/// Note that: +/// +/// (A) `firstKey` == `secondKey` iff: +/// Both these keys, each of which is a list, have the same size and the two +/// elements in each pair of corresponding elements among them are the +/// same. +/// +/// (B) `firstKey` < `secondKey` iff: +/// 1. In the first unequal pair of corresponding elements among them, +/// `firstKey`'s element is smaller, or +/// 2. Both the elements in every pair of corresponding elements are the same +/// in both keys and the size of `firstKey` is smaller. +/// +/// (C) The `secondKey` < `firstKey` condition is defined likewise. +int mlir::compareKeys(SmallVector firstKey, + SmallVector secondKey) { + unsigned firstKeySize = firstKey.size(); + unsigned secondKeySize = secondKey.size(); + unsigned smallestSize = firstKeySize; + if (secondKeySize < smallestSize) + smallestSize = secondKeySize; + + for (unsigned i = 0; i < smallestSize; i++) { + if (firstKey[i] < secondKey[i]) + return -1; + if (secondKey[i] < firstKey[i]) + return 1; + } + + if (firstKeySize == secondKeySize) + return 0; + if (firstKeySize < secondKeySize) + return -1; + return 1; +} + +/// Goes through all the unassigned operands of `bfsOfOperands` and: +/// 1. Stores the indices of the ones with the smallest key in +/// `smallestKeyIndices`, +/// 2. Stores the indices of the ones with the largest key in +/// `largestKeyIndices`, +/// 3. Sets `hasASingleOperandWithSmallestKey` to true if exactly one of them +/// has the smallest key (and to false otherwise), AND, +/// 4. Sets `hasASingleOperandWithLargestKey` to true if exactly one of them has +/// the largest key (and to false otherwise). 
+void mlir::getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + SmallVector bfsOfOperands, + DenseSet &smallestKeyIndices, + DenseSet &largestKeyIndices, + bool &hasASingleOperandWithSmallestKey, + bool &hasASingleOperandWithLargestKey) { + bool foundAnUnassignedOperand = false; + + // Compute the smallest and largest keys present among the unassigned operands + // of `bfsOfOperands`. + SmallVector smallestKey, largestKey; + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + SmallVector currentKey = bfsOfOperand->key; + if (!foundAnUnassignedOperand) { + foundAnUnassignedOperand = true; + smallestKey = currentKey; + largestKey = currentKey; + continue; + } + if (compareKeys(smallestKey, currentKey) == 1) + smallestKey = currentKey; + if (compareKeys(largestKey, currentKey) == -1) + largestKey = currentKey; + } + + // If there is no unassigned operand, clear both flags (the index sets are + // left untouched) and return. + if (!foundAnUnassignedOperand) { + hasASingleOperandWithSmallestKey = false; + hasASingleOperandWithLargestKey = false; + return; + } + + // Populate `smallestKeyIndices` and `largestKeyIndices` and set + // `hasASingleOperandWithSmallestKey` and `hasASingleOperandWithLargestKey` + // accordingly. 
+ bool smallestKeyFound = false; + bool largestKeyFound = false; + hasASingleOperandWithSmallestKey = true; + hasASingleOperandWithLargestKey = true; + for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + unsigned index = indexedBfsOfOperand.index(); + SmallVector currentKey = bfsOfOperand->key; + + if (compareKeys(smallestKey, currentKey) == 0) { + smallestKeyIndices.insert(index); + if (smallestKeyFound) + hasASingleOperandWithSmallestKey = false; + smallestKeyFound = true; + } + + if (compareKeys(largestKey, currentKey) == 0) { + largestKeyIndices.insert(index); + if (largestKeyFound) + hasASingleOperandWithLargestKey = false; + largestKeyFound = true; + } + } + return; +} + +/// Update the key associated with each unassigned operand in `bfsOfOperands`. +/// Updating a key entails making it up-to-date with its associated operand's +/// BFS traversal that has happened till that point in time, i.e., extending the +/// existing key with the current front ancestor's "key". Note that a key +/// directly reflects the BFS and thus needs to be updated during the +/// progression of the traversal. +void mlir::updateKeys(SmallVector bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition || + bfsOfOperand->ancestorQueue.empty()) + continue; + + Operation *frontAncestor = bfsOfOperand->frontAncestor(); + if (!frontAncestor) { + // When the front ancestor is a block argument, we extend the old key + // with an element whose `type` is `BLOCK_ARGUMENT` and `opName` is null, + // which is the key associated with a block argument. 
+ mlir::BlockArgumentOrOpKey blockArgumentOrOpKey; + blockArgumentOrOpKey.type = BLOCK_ARGUMENT; + bfsOfOperand->key.push_back(blockArgumentOrOpKey); + } else if (frontAncestor->hasTrait()) { + // When the front ancestor is a constant-like operation, we extend the old + // key with an element whose `type` is `CONSTANT_OP` and `opName` is null, + // which is the key associated with a constant-like operation. + mlir::BlockArgumentOrOpKey blockArgumentOrOpKey; + blockArgumentOrOpKey.type = CONSTANT_OP; + bfsOfOperand->key.push_back(blockArgumentOrOpKey); + } else { + // When the front ancestor is a non-constant-like operation, we extend the + // old key with an element whose `type` is `NON_CONSTANT_OP` and `opName` + // is its full op name, which is the key associated with a + // non-constant-like operation. + mlir::BlockArgumentOrOpKey blockArgumentOrOpKey; + blockArgumentOrOpKey.type = NON_CONSTANT_OP; + blockArgumentOrOpKey.opName = frontAncestor->getName().getStringRef(); + bfsOfOperand->key.push_back(blockArgumentOrOpKey); + } + } + return; +} + +/// If `keyIndices` contains `indexOfOperand` and either `isTheOnlyKey` is true +/// or the ancestor queue of `bfsOfOperand` is empty, assign the sorted position +/// `positionToAssign` to the operand of `op` at index `indexOfOperand`, and +/// return true. Else, return false. 
+bool mlir::assignSortedPositionTo(OperandBFS *bfsOfOperand, + unsigned indexOfOperand, + DenseSet keyIndices, + bool isTheOnlyKey, + SmallVector &sortedOperands, + unsigned positionToAssign, Operation *op) { + if (keyIndices.contains(indexOfOperand) && + (isTheOnlyKey || bfsOfOperand->ancestorQueue.empty())) { + bfsOfOperand->isAssignedSortedPosition = true; + sortedOperands[positionToAssign] = op->getOperand(indexOfOperand); + return true; + } + return false; +} + +/// For each unassigned operand in `bfsOfOperands`, pop the front ancestor from +/// its queue, if any, and then push that ancestor's unvisited defining ops (and +/// block arguments, as null) to the queue (the main body of the BFS algorithm). +void mlir::popFrontAndPushAdjacentUnvisitedAncestors( + SmallVector bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition || + bfsOfOperand->ancestorQueue.empty()) + continue; + Operation *frontAncestor = bfsOfOperand->frontAncestor(); + bfsOfOperand->popAncestor(); + if (!frontAncestor) + continue; + for (Value operand : frontAncestor->getOperands()) { + Operation *thisOperandDefOp = operand.getDefiningOp(); + if (!thisOperandDefOp || + !bfsOfOperand->visitedAncestors.contains(thisOperandDefOp)) + bfsOfOperand->pushAncestor(thisOperandDefOp); + } + } + return; +} diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func.func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // 
CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func.func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[TEST_CONST]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : 
i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1172,11 +1172,21 @@ let hasFolder = 1; } +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins 
I32:$op1, I32:$op2); let results = (outs I32); diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,56 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the functionality of the commutativity utility. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace arith; +using namespace test; + +namespace { + +/// Test pass that adds `SortCommutativeOperands` patterns to a pattern set and +/// applies them greedily to the operation the pass runs on. +struct CommutativityUtils + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + /// Populates the pattern set and applies it greedily; the result reported + /// by the greedy driver is deliberately ignored. + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + patterns.add, + mlir::SortCommutativeOperands, + mlir::SortCommutativeOperands, + mlir::SortCommutativeOperands, + mlir::SortCommutativeOperands>( + context); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +/// Registers the test pass with the pass registry (called from mlir-opt). +void registerCommutativityUtils() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -57,6 +57,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -150,6 +151,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck();