diff --git a/clang/docs/tools/clang-formatted-files.txt b/clang/docs/tools/clang-formatted-files.txt --- a/clang/docs/tools/clang-formatted-files.txt +++ b/clang/docs/tools/clang-formatted-files.txt @@ -7888,6 +7888,7 @@ mlir/include/mlir/Tools/PDLL/ODS/Dialect.h mlir/include/mlir/Tools/PDLL/ODS/Operation.h mlir/include/mlir/Tools/PDLL/Parser/Parser.h +mlir/include/mlir/Transforms/CommutativityUtils.h mlir/include/mlir/Transforms/ControlFlowSinkUtils.h mlir/include/mlir/Transforms/DialectConversion.h mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -8448,6 +8449,7 @@ mlir/lib/Transforms/StripDebugInfo.cpp mlir/lib/Transforms/SymbolDCE.cpp mlir/lib/Transforms/SymbolPrivatize.cpp +mlir/lib/Transforms/Utils/CommutativityUtils.cpp mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/lib/Transforms/Utils/FoldUtils.cpp diff --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/CommutativityUtils.h @@ -0,0 +1,28 @@ +//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file declares a utility that is intended to be used inside a pass +// or an individual pattern to simplify the matching of commutative operations. +// Note that this utility can also be used inside PDL patterns in conjunction +// with the `pdl.apply_native_rewrite` op. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H +#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H + +namespace mlir { + +class Operation; +class PatternRewriter; + +void sortCommutativeOperands(Operation *op, PatternRewriter &rewriter); + +} // namespace mlir + +#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt --- a/mlir/lib/Transforms/Utils/CMakeLists.txt +++ b/mlir/lib/Transforms/Utils/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_library(MLIRTransformUtils + CommutativityUtils.cpp ControlFlowSinkUtils.cpp DialectConversion.cpp FoldUtils.cpp diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp @@ -0,0 +1,399 @@ +//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a utility that is intended to be used inside a pass or +// an individual pattern to simplify the matching of commutative operations. +// Note that this utility can also be used inside PDL patterns in conjunction +// with the `pdl.apply_native_rewrite` op. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "mlir/IR/PatternMatch.h" +#include <queue> + +#define DEBUG_TYPE "commutativity-utils" + +using namespace mlir; + +/// Stores the BFS traversal information of an operand.
+struct OperandBFS {
+  /// Queue of ancestors still to be visited by the BFS traversal of an operand
+  /// at a particular point in time. A null entry stands for a block argument.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Set of ancestors (ops only, never null) that have already been pushed
+  /// into the queue by the BFS traversal of an operand at a particular point
+  /// in time.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the key corresponding to the BFS traversal of an operand at a
+  /// particular point in time.
+  /// Some examples:
+  /// 1. If the BFS has seen `arith.addi`,
+  ///    then,
+  ///    the key will store the string:
+  ///    "1arith.addi".
+  /// 2. If the BFS has seen `arg5`,
+  ///    then,
+  ///    the key will store the string:
+  ///    "2".
+  /// 3. If the BFS has seen `arith.constant`,
+  ///    then,
+  ///    the key will store the string:
+  ///    "3arith.constant".
+  /// 4. If the BFS has seen `arith.addi`, `test.constant`, `scf.if`, `tf.Add`,
+  ///    `arith.constant`, and `arg5` (in BFS order),
+  ///    then,
+  ///    the key will store the string:
+  ///    "1arith.addi3test.constant1scf.if1tf.Add3arith.constant2".
+  ///
+  /// Such a definition of "key" will allow the ascending order of keys of
+  /// different operands to be such that the (1) ones defined by
+  /// non-constant-like ops come first, followed by (2) block arguments, which
+  /// are finally followed by the (3) ones defined by constant-like ops. In
+  /// addition to this, within the categories (1) and (3), the order of
+  /// operands is alphabetical w.r.t. the dialect name and op name.
+  ///
+  /// Further, as an example to demonstrate the comparison of keys, note that
+  /// if we have the following commutative op (foo.op):
+  ///   e = foo.div f, g
+  ///   c = foo.constant
+  ///   b = foo.add e, d
+  ///   a = foo.add c, d
+  ///   s = foo.op a, b,
+  /// then,
+  /// the key associated with operand `a` will be "1foo.add3foo.constant", and,
+  /// the key associated with operand `b` will be "1foo.add1foo.div",
+  /// and thus,
+  /// key of `a` > key of `b`,
+  ///
+  /// which means that a "sorted" foo.op would look like:
+  ///   s = foo.op b, a (instead of a, b).
+  std::string key = "";
+
+  /// Stores true iff the operand has been assigned a sorted position yet.
+  bool isAssignedSortedPosition = false;
+
+  /// Push an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" set (iff it is not null, i.e., corresponds to an op
+  /// rather than a block argument).
+  void pushAncestor(Operation *ancestor) {
+    ancestorQueue.push(ancestor);
+    if (ancestor)
+      visitedAncestors.insert(ancestor);
+  }
+
+  /// Pop the ancestor from the front of the queue. The queue must be
+  /// non-empty.
+  void popAncestor() {
+    assert(!ancestorQueue.empty() &&
+           "to pop the ancestor from the front of the queue, the ancestor "
+           "queue should be non-empty");
+    ancestorQueue.pop();
+  }
+
+  /// Return the ancestor at the front of the queue (null for a block
+  /// argument). The queue must be non-empty.
+  Operation *frontAncestor() const {
+    assert(!ancestorQueue.empty() &&
+           "to access the ancestor at the front of the queue, the ancestor "
+           "queue should be non-empty");
+    return ancestorQueue.front();
+  }
+};
+
+/// Returns true iff at least one unassigned operand exists. An unassigned
+/// operand refers to one which has not been assigned a sorted position yet.
+static bool +hasAtLeastOneUnassignedOperand(SmallVector bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (!bfsOfOperand->isAssignedSortedPosition) + return true; + } + return false; +} + +/// Goes through all the unassigned operands of `bfsOfOperands` and: +/// 1. Stores the indices of the ones with the smallest key in +/// `smallestKeyIndices`, +/// 2. Stores the indices of the ones with the largest key in +/// `largestKeyIndices`, +/// 3. Sets `hasASingleOperandWithSmallestKey` as true if exactly one of them +/// has the smallest key (and as false otherwise), AND, +/// 4. Sets `hasASingleOperandWithLargestKey` as true if exactly one of them has +/// the largest key (and as false otherwise). +static void getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + SmallVector bfsOfOperands, + DenseSet &smallestKeyIndices, + DenseSet &largestKeyIndices, + bool &hasASingleOperandWithSmallestKey, + bool &hasASingleOperandWithLargestKey) { + bool foundAnUnassignedOperand = false; + + // Compute the smallest and largest keys present among the unassigned operands + // of `bfsOfOperands`. + std::string smallestKey, largestKey; + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + std::string currentKey = bfsOfOperand->key; + if (!foundAnUnassignedOperand) { + foundAnUnassignedOperand = true; + smallestKey = currentKey; + largestKey = currentKey; + continue; + } + if (smallestKey > currentKey) + smallestKey = currentKey; + if (largestKey < currentKey) + largestKey = currentKey; + } + + // If there is no unassigned operand, assign the necessary values to the input + // arguments and return. + if (!foundAnUnassignedOperand) { + hasASingleOperandWithSmallestKey = false; + hasASingleOperandWithLargestKey = false; + return; + } + + // Populate `smallestKeyIndices` and `largestKeyIndices` and set + // `hasASingleOperandWithSmallestKey` and `hasASingleOperandWithLargestKey` + // accordingly. 
+ bool smallestKeyFound = false; + bool largestKeyFound = false; + hasASingleOperandWithSmallestKey = true; + hasASingleOperandWithLargestKey = true; + for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + unsigned index = indexedBfsOfOperand.index(); + std::string currentKey = bfsOfOperand->key; + + if (smallestKey == currentKey) { + smallestKeyIndices.insert(index); + if (smallestKeyFound) + hasASingleOperandWithSmallestKey = false; + smallestKeyFound = true; + } + + if (largestKey == currentKey) { + largestKeyIndices.insert(index); + if (largestKeyFound) + hasASingleOperandWithLargestKey = false; + largestKeyFound = true; + } + } + return; +} + +/// Update the key associated with each unassigned operand in `bfsOfOperands`. +/// Updating a key entails making it up-to-date with its associated operand's +/// BFS traversal that has happened till that point in time. Note that a key +/// directly reflects the BFS and thus needs to be updated after every change in +/// the BFS queue, as the traversal happens. +static void updateKeys(SmallVector bfsOfOperands) { + for (OperandBFS *bfsOfOperand : bfsOfOperands) { + if (bfsOfOperand->isAssignedSortedPosition || + bfsOfOperand->ancestorQueue.empty()) + continue; + + Operation *frontAncestor = bfsOfOperand->frontAncestor(); + if (!frontAncestor) { + // When the front ancestor is a block argument, we concatenate the old key + // with such a value that allows its corresponding operand to be + // positioned between operands defined by non-constant-like and + // constant-like operations. 
+ bfsOfOperand->key = (Twine(bfsOfOperand->key) + Twine("2")).str(); + } else if (frontAncestor->hasTrait()) { + // When the front ancestor is a constant-like operation, we concatenate + // the old key with such a value that allows its corresponding operand to + // be positioned after operands defined by non-constant-like operations or + // block arguments (while maintaining that among constant-like operations, + // the corresponding operands are positioned alphabetically). + bfsOfOperand->key = (Twine(bfsOfOperand->key) + Twine("3") + + std::string(frontAncestor->getName().getStringRef())) + .str(); + } else { + // When the front ancestor is a non-constant-like operation, we + // concatenate the old key with such a value that allows its corresponding + // operand to be positioned before block arguments or operands defined by + // constant-like operations (while maintaining that among + // non-constant-like operations, the corresponding operands are positioned + // alphabetically). + bfsOfOperand->key = (Twine(bfsOfOperand->key) + Twine("1") + + std::string(frontAncestor->getName().getStringRef())) + .str(); + } + } + return; +} + +/// Rewrite `op`, i.e., rearrange its operands in a "sorted" order. +/// The operands of an op are considered to be "sorted" iff: +/// 1. The op is not commutative, OR, +/// 2. It is commutative and its operands are in ascending order of the "keys" +/// associated with them. +/// +/// Note that `operandDefOps` stores the list of ops defining its operands (in +/// the order in which they appear in `op`). If an operand is a block argument, +/// the op defining it stores null. +static void +rewriteCommutativeOperands(Operation *op, + SmallVector operandDefOps, + PatternRewriter &rewriter) { + // If `op` is not commutative, do nothing. + if (!op->hasTrait()) + return; + + // `bfsOfOperands` stores the BFS traversal information of each operand of + // `op`. 
For each operand, this information comprises a queue of ancestors + // being visited during the BFS (at a particular point in time), a list of + // visited ancestors (at a particular point in time), its associated key (at a + // particular point in time), and whether or not the operand has been assigned + // a sorted position yet. + SmallVector bfsOfOperands; + + // Initially, each operand's ancestor queue contains the op defining it (which + // is considered its first ancestor). Thus, it acts as the starting point for + // that operand's BFS traversal. + for (Operation *operandDefOp : operandDefOps) { + OperandBFS *bfsOfOperand = new OperandBFS(); + bfsOfOperand->pushAncestor(operandDefOp); + bfsOfOperands.push_back(bfsOfOperand); + } + + // Since none of the operands have been assigned a sorted position yet, the + // smallest unassigned position is set as zero and the largest one is set as + // the number of operands in `op` minus one (N - 1). This is because each + // operand will be assigned a sorted position between 0 and (N - 1), both + // inclusive. + unsigned numOperands = op->getNumOperands(); + unsigned smallestUnassignedPosition = 0; + unsigned largestUnassignedPosition = numOperands - 1; + + // `sortedOperands` will store the list of `op`'s operands in sorted order. + // At first, all elements in it are initialized as null. + SmallVector sortedOperands; + while (numOperands) { + sortedOperands.push_back(nullptr); + numOperands--; + } + + // We perform the BFS traversals of all operands parallelly until each of them + // is assigned a sorted position. During the traversals, we try to assign a + // sorted position to an operand as soon as it is possible (based on a + // comparision of its traversal with the other traversals at that particular + // point in time). + while (hasAtLeastOneUnassignedOperand(bfsOfOperands)) { + // Update the keys corresponding to all unassigned operands. 
+ updateKeys(bfsOfOperands); + + // Stores the indices of the unassigned operands whose key is the smallest. + DenseSet smallestKeyIndices; + // Stores the indices of the unassigned operands whose key is the largest. + DenseSet largestKeyIndices; + + // Stores true iff there is a single unassigned operand that has the + // smallest key. + bool hasASingleOperandWithSmallestKey; + // Stores true iff there is a single unassigned operand that has the largest + // key. + bool hasASingleOperandWithLargestKey; + + getIndicesOfUnassignedOperandsWithSmallestAndLargestKeys( + bfsOfOperands, smallestKeyIndices, largestKeyIndices, + hasASingleOperandWithSmallestKey, hasASingleOperandWithLargestKey); + + // Go through each of the unassigned operands and try to assign it a sorted + // position if possible. + for (auto indexedBfsOfOperand : llvm::enumerate(bfsOfOperands)) { + OperandBFS *bfsOfOperand = indexedBfsOfOperand.value(); + if (bfsOfOperand->isAssignedSortedPosition) + continue; + + unsigned index = indexedBfsOfOperand.index(); + + // If an unassigned operand has the smallest key and: + // 1. It is the only operand with the smallest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `smallestUnassignedPosition` (which will + // be its new position in the rearranged `op`). + // + // Likewise, + // + // If an unassigned operand has the largest key and: + // 1. It is the only operand with the largest key, OR, + // 2. Its BFS is complete, + // then, + // this operand is assigned the `largestUnassignedPosition` (which will be + // its new position in the rearranged `op`). 
+ if (smallestKeyIndices.contains(index) && + (hasASingleOperandWithSmallestKey || + bfsOfOperand->ancestorQueue.empty())) { + bfsOfOperand->isAssignedSortedPosition = true; + sortedOperands[smallestUnassignedPosition] = op->getOperand(index); + smallestUnassignedPosition++; + } else if (largestKeyIndices.contains(index) && + (hasASingleOperandWithLargestKey || + bfsOfOperand->ancestorQueue.empty())) { + bfsOfOperand->isAssignedSortedPosition = true; + sortedOperands[largestUnassignedPosition] = op->getOperand(index); + largestUnassignedPosition--; + } + + // Pop the front ancestor from the queue, if any, and then push its + // adjascent unvisited ancestors, if any, to the queue (the main body of + // the BFS algorithm). + if (bfsOfOperand->ancestorQueue.empty()) + continue; + Operation *frontAncestor = bfsOfOperand->frontAncestor(); + bfsOfOperand->popAncestor(); + if (!frontAncestor) + continue; + for (Value operand : frontAncestor->getOperands()) { + Operation *thisOperandDefOp = operand.getDefiningOp(); + if (!thisOperandDefOp || + !bfsOfOperand->visitedAncestors.contains(thisOperandDefOp)) + bfsOfOperand->pushAncestor(thisOperandDefOp); + } + } + } + rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); }); +} + +/// Sorts `op`. +/// "Sorting" `op` means to "sort" the ops defining each of its operands +/// followed by rearranging its operands in the "sorted" order. Before the +/// rearrangement, it is important to sort the ops defining its operands so that +/// the rearrangement is deterministic. In other words, if these ops were not +/// sorted, the rearrangement would be non-deterministic and would thus make +/// this utility useless. +void mlir::sortCommutativeOperands(Operation *op, PatternRewriter &rewriter) { + assert(op && "the input argument `op` must not be null"); + + // Before the operands of `op` are rearranged, the operations defining the + // operands of `op` are sorted. 
+ SmallVector operandDefOps; + for (Value operand : op->getOperands()) { + Operation *operandDefOp = operand.getDefiningOp(); + operandDefOps.push_back(operandDefOp); + if (operandDefOp) + sortCommutativeOperands(operandDefOp, rewriter); + } + + // Now, rewrite `op`, i.e, rearrange its operands in a "sorted" order. + rewriteCommutativeOperands(op, operandDefOps, rewriter); + return; +} diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-commutativity-utils.mlir @@ -0,0 +1,116 @@ +// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s + +// CHECK-LABEL: @test_small_pattern_1 +func @test_small_pattern_1(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %0 = arith.constant 45 : i32 + + // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi" + %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli + %3 = arith.muli %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_small_pattern_2 +// CHECK-SAME: (%[[ARG0:.*]]: i32 +func @test_small_pattern_2(%arg0 : i32) -> i32 { + // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant" + %0 = "test.constant"() {value = 0 : i32} : () -> i32 + + // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant + %1 = arith.constant 0 : i32 + + // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi + %2 = arith.addi %arg0, %arg0 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARG0]], %[[ARITH_CONST]], %[[TEST_CONST]]) + %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) 
-> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} + +// CHECK-LABEL: @test_large_pattern +func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 { + // CHECK-NEXT: arith.divsi + %0 = arith.divsi %arg0, %arg1 : i32 + + // CHECK-NEXT: arith.divsi + %1 = arith.divsi %0, %arg0 : i32 + + // CHECK-NEXT: arith.divsi + %2 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: arith.addi + %3 = arith.addi %1, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %4 = arith.subi %2, %3 : i32 + + // CHECK-NEXT: "test.addi" + %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi + %6 = arith.divsi %4, %5 : i32 + + // CHECK-NEXT: arith.divsi + %7 = arith.divsi %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL8:.*]] = arith.muli + %8 = arith.muli %1, %arg1 : i32 + + // CHECK-NEXT: %[[VAL9:.*]] = arith.subi + %9 = arith.subi %7, %8 : i32 + + // CHECK-NEXT: "test.addi" + %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi + %11 = arith.divsi %9, %10 : i32 + + // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi + %12 = arith.divsi %6, %arg1 : i32 + + // CHECK-NEXT: arith.subi + %13 = arith.subi %arg1, %arg0 : i32 + + // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]]) + %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi + %15 = arith.divsi %13, %14 : i32 + + // CHECK-NEXT: %[[VAL16:.*]] = arith.addi + %16 = arith.addi %2, %15 : i32 + + // CHECK-NEXT: arith.subi + %17 = arith.subi %16, %arg1 : i32 + + // CHECK-NEXT: "test.addi" + %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi + %19 = arith.divsi %17, %18 : i32 + + // CHECK-NEXT: "test.addi" + %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32 + + // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi + %21 = arith.divsi %17, %20 : i32 + + // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], 
%[[VAL21]], %[[VAL19]], %[[VAL19]], %[[VAL6]], %[[VAL11]], %[[VAL15]]) + %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %19, %21): (i32, i32, i32, i32, i32, i32, i32) -> i32 + + // CHECK-NEXT: return %[[RESULT]] + return %result : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1101,11 +1101,21 @@ let hasFolder = 1; } +def TestAddIOp : TEST_Op<"addi"> { + let arguments = (ins I32:$op1, I32:$op2); + let results = (outs I32); +} + def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4); let results = (outs I32); } +def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> { + let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7); + let results = (outs I32); +} + def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> { let arguments = (ins I32:$op1, I32:$op2); let results = (outs I32); diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestTransforms + TestCommutativityUtils.cpp TestConstantFold.cpp TestControlFlowSink.cpp TestInlining.cpp diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -0,0 +1,67 @@ +//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass tests the functionality of the commutativity utility. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/CommutativityUtils.h" + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace test; + +namespace { + +struct SmallPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TestCommutativeOp testCommOp, + PatternRewriter &rewriter) const override { + sortCommutativeOperands(testCommOp.getOperation(), rewriter); + return success(); + } +}; + +struct LargePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TestLargeCommutativeOp testLargeCommOp, + PatternRewriter &rewriter) const override { + sortCommutativeOperands(testLargeCommOp.getOperation(), rewriter); + return success(); + } +}; + +struct CommutativityUtils + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils) + + StringRef getArgument() const final { return "test-commutativity-utils"; } + StringRef getDescription() const final { + return "Test the functionality of the commutativity utility"; + } + + void runOnOperation() override { + auto func = getOperation(); + auto *context = &getContext(); + + RewritePatternSet patterns(context); + patterns.add(context); + + (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerCommutativityUtils() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ 
b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -56,6 +56,7 @@ void registerVectorizerTestPass(); namespace test { +void registerCommutativityUtils(); void registerConvertCallOpPass(); void registerInliner(); void registerMemRefBoundCheck(); @@ -146,6 +147,7 @@ registerVectorizerTestPass(); registerTosaTestQuantUtilAPIPass(); + mlir::test::registerCommutativityUtils(); mlir::test::registerConvertCallOpPass(); mlir::test::registerInliner(); mlir::test::registerMemRefBoundCheck();