diff --git a/mlir/include/mlir/Transforms/CommutativityUtils.h b/mlir/include/mlir/Transforms/CommutativityUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Transforms/CommutativityUtils.h
@@ -0,0 +1,30 @@
+//===- CommutativityUtils.h - Commutativity utilities -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file declares a function to populate the commutativity utility
+// pattern. This function is intended to be used inside passes to simplify the
+// matching of commutative operations by fixing the order of their operands.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+#define MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
+
+#include "mlir/IR/PatternMatch.h"
+
+namespace mlir {
+
+/// Populates `patterns` with a pattern that sorts the operands of every op
+/// with the `IsCommutative` trait into a deterministic order, so that the
+/// matching of commutative ops does not depend on the order in which their
+/// operands were originally written.
+void populateCommutativityUtilsPatterns(RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_COMMUTATIVITYUTILS_H
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_library(MLIRTransformUtils
+  CommutativityUtils.cpp
   ControlFlowSinkUtils.cpp
   DialectConversion.cpp
   FoldUtils.cpp
diff --git a/mlir/lib/Transforms/Utils/CommutativityUtils.cpp b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/Utils/CommutativityUtils.cpp
@@ -0,0 +1,329 @@
+//===- CommutativityUtils.cpp - Commutativity utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a commutativity utility pattern and a function to
+// populate this pattern. The function is intended to be used inside passes to
+// simplify the matching of commutative operations by fixing the order of their
+// operands.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallVector.h"
+
+#include <algorithm>
+#include <memory>
+#include <queue>
+#include <tuple>
+
+using namespace mlir;
+
+namespace {
+
+/// The possible "types" of ancestors. Here, an ancestor is an op or a block
+/// argument present in the backward slice of a value.
+///
+/// The enumerators are listed in ascending order: `AncestorKey::operator<`
+/// relies on this declaration order.
+enum AncestorType {
+  /// Pertains to a block argument.
+  BLOCK_ARGUMENT,
+
+  /// Pertains to a non-constant-like op.
+  NON_CONSTANT_OP,
+
+  /// Pertains to a constant-like op.
+  CONSTANT_OP
+};
+
+/// Stores the "key" associated with an ancestor.
+struct AncestorKey {
+  /// Holds `BLOCK_ARGUMENT`, `NON_CONSTANT_OP`, or `CONSTANT_OP`, depending on
+  /// the ancestor.
+  AncestorType type;
+
+  /// Holds the op name of the ancestor if its `type` is `NON_CONSTANT_OP`.
+  /// Else, holds "".
+  StringRef opName;
+
+  /// Constructs the key of the ancestor `op`; a null `op` denotes a block
+  /// argument.
+  explicit AncestorKey(Operation *op) {
+    if (!op) {
+      type = BLOCK_ARGUMENT;
+    } else if (!op->hasTrait<OpTrait::ConstantLike>()) {
+      type = NON_CONSTANT_OP;
+      opName = op->getName().getStringRef();
+    } else {
+      type = CONSTANT_OP;
+    }
+  }
+
+  /// Overloaded operator `<` for `AncestorKey`.
+  ///
+  /// AncestorKeys of type `BLOCK_ARGUMENT` are considered the smallest and
+  /// those of type `CONSTANT_OP`, the largest. `NON_CONSTANT_OP` types come in
+  /// between, with the smaller ones being the ones with smaller op names
+  /// (lexicographically). As `opName` is "" for the other two types, this is
+  /// exactly the lexicographic order of the (`type`, `opName`) pairs.
+  bool operator<(const AncestorKey &key) const {
+    return std::tie(type, opName) < std::tie(key.type, key.opName);
+  }
+};
+
+/// Stores the BFS traversal information of an operand.
+struct OperandBFS {
+  /// Stores the queue of ancestors of the operand's BFS traversal at a
+  /// particular point in time. A null entry denotes a block argument.
+  std::queue<Operation *> ancestorQueue;
+
+  /// Stores the list of ancestors that have been visited by the BFS traversal
+  /// at a particular point in time.
+  DenseSet<Operation *> visitedAncestors;
+
+  /// Stores the operand's "key". This "key" is defined as a list of the
+  /// "AncestorKeys" associated with the ancestors of this operand, in a
+  /// breadth-first order.
+  ///
+  /// So, if an operand, say `A`, was produced as follows:
+  ///
+  /// `<block argument>`  `<block argument>`
+  ///             \          /
+  ///              \        /
+  ///             `arith.subi`           `arith.constant`
+  ///                       \            /
+  ///                        `arith.addi`
+  ///                             |
+  ///                        returns `A`
+  ///
+  /// Then, the ancestors of `A`, in the breadth-first order are:
+  /// `arith.addi`, `arith.subi`, `arith.constant`, `<block argument>`, and
+  /// `<block argument>`.
+  ///
+  /// Thus, the "key" associated with operand `A` is:
+  /// {
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.addi"},
+  ///  {type: `NON_CONSTANT_OP`, opName: "arith.subi"},
+  ///  {type: `CONSTANT_OP`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""},
+  ///  {type: `BLOCK_ARGUMENT`, opName: ""}
+  /// }
+  ///
+  /// The key is built lazily: it holds the AncestorKeys of the ancestors
+  /// popped so far plus, if the queue is not empty, the one of its front.
+  SmallVector<AncestorKey, 4> key;
+
+  /// Pushes an ancestor into the operand's BFS information structure. This
+  /// entails it being pushed into the queue (always) and inserted into the
+  /// "visited ancestors" list (iff it is an op rather than a block argument).
+  void pushAncestor(Operation *op) {
+    ancestorQueue.push(op);
+    if (op)
+      visitedAncestors.insert(op);
+  }
+
+  /// Refreshes the key.
+  ///
+  /// Refreshing a key entails making it up-to-date with the operand's BFS
+  /// traversal that has happened till that point in time, i.e, appending the
+  /// existing key with the front ancestor's "AncestorKey". Note that a key
+  /// directly reflects the BFS and thus needs to be refreshed during the
+  /// progression of the traversal.
+  void refreshKey() {
+    if (ancestorQueue.empty())
+      return;
+    key.push_back(AncestorKey(ancestorQueue.front()));
+  }
+
+  /// Pops the front ancestor, if any, from the queue and then pushes its
+  /// adjacent unvisited ancestors, if any, to the queue (this is the main body
+  /// of the BFS algorithm).
+  void popFrontAndPushAdjacentUnvisitedAncestors() {
+    if (ancestorQueue.empty())
+      return;
+    Operation *frontAncestor = ancestorQueue.front();
+    ancestorQueue.pop();
+    // A block argument has no ancestors of its own.
+    if (!frontAncestor)
+      return;
+    for (Value operand : frontAncestor->getOperands()) {
+      Operation *operandDefOp = operand.getDefiningOp();
+      // Block arguments are never marked as visited: each use of one adds a
+      // `BLOCK_ARGUMENT` entry to the key.
+      if (!operandDefOp || !visitedAncestors.contains(operandDefOp))
+        pushAncestor(operandDefOp);
+    }
+  }
+
+  /// Advances the BFS traversal until the key has an entry at `index` or the
+  /// traversal is exhausted. Returns true iff `key[index]` exists afterwards.
+  bool extendKeyTo(unsigned index) {
+    while (key.size() <= index && !ancestorQueue.empty()) {
+      popFrontAndPushAdjacentUnvisitedAncestors();
+      refreshKey();
+    }
+    return key.size() > index;
+  }
+};
+
+/// Sorts the operands of `op` in ascending order of the "key" associated with
+/// each operand iff `op` is commutative.
+///
+/// After the application of this pattern, since the commutative operands now
+/// have a deterministic order in which they occur in an op, the matching of
+/// large DAGs becomes much simpler, i.e., requires much less number of checks
+/// to be written by a user in her/his pattern matching function.
+///
+/// Some examples of such a sorting:
+///
+/// Assume that the sorting is being applied to `foo.commutative`, which is a
+/// commutative op.
+///
+/// Example 1:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.commutative %1, %2
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, ""}
+///     }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///     }`
+///
+/// The key of %2 < the key of %1
+/// Thus, the sorted `foo.commutative` is:
+/// %3 = foo.commutative %2, %1
+///
+/// Example 2:
+///
+/// %1 = foo.const 0
+/// %2 = foo.mul <block argument>, <block argument>
+/// %3 = foo.mul %2, %1
+/// %4 = foo.add %2, %1
+/// %5 = foo.commutative %1, %2, %3, %4
+///
+/// Here,
+/// 1. The key associated with %1 is:
+///     `{
+///       {CONSTANT_OP, ""}
+///     }`
+/// 2. The key associated with %2 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///     }`
+/// 3. The key associated with %3 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, ""},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///     }`
+/// 4. The key associated with %4 is:
+///     `{
+///       {NON_CONSTANT_OP, "foo.add"},
+///       {NON_CONSTANT_OP, "foo.mul"},
+///       {CONSTANT_OP, ""},
+///       {BLOCK_ARGUMENT, ""},
+///       {BLOCK_ARGUMENT, ""}
+///     }`
+///
+/// The keys of %2 and %3 first differ at their second AncestorKey, where
+/// `BLOCK_ARGUMENT` < `NON_CONSTANT_OP`. So, the key of %4 < the key of %2 <
+/// the key of %3 < the key of %1.
+/// Thus, the sorted `foo.commutative` is:
+/// %5 = foo.commutative %4, %2, %3, %1
+class SortCommutativeOperands : public RewritePattern {
+public:
+  SortCommutativeOperands(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/5, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    // If `op` is not commutative, do nothing.
+    if (!op->hasTrait<OpTrait::IsCommutative>())
+      return failure();
+
+    // Owns the BFS traversal information of every distinct operand. It is
+    // shared by all the comparisons made while sorting, so that every key is
+    // computed (lazily) at most once, and it is released on return.
+    DenseMap<Value, std::unique_ptr<OperandBFS>> operandToItsBFSMap;
+    SmallVector<Value, 2> operands = op->getOperands();
+    for (Value operand : operands) {
+      std::unique_ptr<OperandBFS> &bfs = operandToItsBFSMap[operand];
+      // The same value may appear several times among the operands.
+      if (bfs)
+        continue;
+      bfs = std::make_unique<OperandBFS>();
+      bfs->pushAncestor(operand.getDefiningOp());
+      bfs->refreshKey();
+    }
+
+    auto commutativeOperandComparator = [&](Value operandA, Value operandB) {
+      if (operandA == operandB)
+        return false;
+      return isKeySmaller(*operandToItsBFSMap.find(operandA)->second,
+                          *operandToItsBFSMap.find(operandB)->second);
+    };
+
+    // Sort the operands. The sort is stable, so operands with equal keys keep
+    // their relative order.
+    SmallVector<Value, 2> sortedOperands = operands;
+    std::stable_sort(sortedOperands.begin(), sortedOperands.end(),
+                     commutativeOperandComparator);
+    if (sortedOperands == operands)
+      return failure();
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(sortedOperands); });
+    return success();
+  }
+
+private:
+  /// Returns true iff the "key" of `bfsA` < the "key" of `bfsB`, i.e.,
+  ///   1. In the first unequal pair of corresponding AncestorKeys, the
+  ///      AncestorKey of `bfsA` is smaller, or,
+  ///   2. Both the AncestorKeys in every pair are the same and the "key" of
+  ///      `bfsA` is shorter.
+  ///
+  /// Each BFS is only advanced as far as needed to decide. Since the BFS
+  /// information is shared across comparisons, a key may already be partially
+  /// or fully computed on entry; hence, `key[keyIndex]` is always compared
+  /// before the exhaustion of a traversal is taken into account.
+  static bool isKeySmaller(OperandBFS &bfsA, OperandBFS &bfsB) {
+    for (unsigned keyIndex = 0;; ++keyIndex) {
+      bool hasEntryA = bfsA.extendKeyTo(keyIndex);
+      bool hasEntryB = bfsB.extendKeyTo(keyIndex);
+      // A key that ends first is a proper prefix of the other one.
+      if (!hasEntryA || !hasEntryB)
+        return !hasEntryA && hasEntryB;
+      if (bfsA.key[keyIndex] < bfsB.key[keyIndex])
+        return true;
+      if (bfsB.key[keyIndex] < bfsA.key[keyIndex])
+        return false;
+    }
+  }
+};
+
+} // namespace
+
+void mlir::populateCommutativityUtilsPatterns(RewritePatternSet &patterns) {
+  patterns.add<SortCommutativeOperands>(patterns.getContext());
+}
diff --git a/mlir/test/Transforms/test-commutativity-utils.mlir b/mlir/test/Transforms/test-commutativity-utils.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-commutativity-utils.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt %s -test-commutativity-utils | FileCheck %s
+
+// CHECK-LABEL: @test_small_pattern_1
+func.func @test_small_pattern_1(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %0 = arith.constant 45 : i32
+
+  // CHECK-NEXT: %[[TEST_ADD:.*]] = "test.addi"
+  %1 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[ARITH_MUL:.*]] = arith.muli
+  %3 = arith.muli %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARITH_ADD]], %[[ARITH_MUL]], %[[TEST_ADD]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %3): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// CHECK-LABEL: @test_small_pattern_2
+// CHECK-SAME: (%[[ARG0:.*]]: i32
+func.func @test_small_pattern_2(%arg0 : i32) -> i32 {
+  // CHECK-NEXT: %[[TEST_CONST:.*]] = "test.constant"
+  %0 = "test.constant"() {value = 0 : i32} : () -> i32
+
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %1 = arith.constant 0 : i32
+
+  // CHECK-NEXT: %[[ARITH_ADD:.*]] = arith.addi
+  %2 = arith.addi %arg0, %arg0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARITH_ADD]], %[[TEST_CONST]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %1, %2, %arg0): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// CHECK-LABEL: @test_large_pattern
+func.func @test_large_pattern(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK-NEXT: arith.divsi
+  %0 = arith.divsi %arg0, %arg1 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %1 = arith.divsi %0, %arg0 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %2 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.addi
+  %3 = arith.addi %1, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %4 = arith.subi %2, %3 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %5 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL6:.*]] = arith.divsi
+  %6 = arith.divsi %4, %5 : i32
+
+  // CHECK-NEXT: arith.divsi
+  %7 = arith.divsi %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL8:.*]] = arith.muli
+  %8 = arith.muli %1, %arg1 : i32
+
+  // CHECK-NEXT: %[[VAL9:.*]] = arith.subi
+  %9 = arith.subi %7, %8 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %10 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL11:.*]] = arith.divsi
+  %11 = arith.divsi %9, %10 : i32
+
+  // CHECK-NEXT: %[[VAL12:.*]] = arith.divsi
+  %12 = arith.divsi %6, %arg1 : i32
+
+  // CHECK-NEXT: arith.subi
+  %13 = arith.subi %arg1, %arg0 : i32
+
+  // CHECK-NEXT: "test.op_commutative"(%[[VAL12]], %[[VAL12]], %[[VAL8]], %[[VAL9]])
+  %14 = "test.op_commutative"(%12, %9, %12, %8): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL15:.*]] = arith.divsi
+  %15 = arith.divsi %13, %14 : i32
+
+  // CHECK-NEXT: %[[VAL16:.*]] = arith.addi
+  %16 = arith.addi %2, %15 : i32
+
+  // CHECK-NEXT: arith.subi
+  %17 = arith.subi %16, %arg1 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %18 = "test.addi"(%arg0, %arg0): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL19:.*]] = arith.divsi
+  %19 = arith.divsi %17, %18 : i32
+
+  // CHECK-NEXT: "test.addi"
+  %20 = "test.addi"(%arg0, %16): (i32, i32) -> i32
+
+  // CHECK-NEXT: %[[VAL21:.*]] = arith.divsi
+  %21 = arith.divsi %17, %20 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_large_commutative"(%[[VAL16]], %[[VAL19]], %[[VAL19]], %[[VAL21]], %[[VAL6]], %[[VAL11]], %[[VAL15]])
+  %result = "test.op_large_commutative"(%16, %6, %11, %15, %19, %21, %19): (i32, i32, i32, i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
+
+// Block arguments must precede constants even after the BFS traversals of the
+// block arguments have been exhausted by earlier comparisons.
+// CHECK-LABEL: @test_block_arguments_before_constants
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
+func.func @test_block_arguments_before_constants(%arg0 : i32, %arg1 : i32) -> i32 {
+  // CHECK-NEXT: %[[ARITH_CONST:.*]] = arith.constant
+  %0 = arith.constant 0 : i32
+
+  // CHECK-NEXT: %[[RESULT:.*]] = "test.op_commutative"(%[[ARG0]], %[[ARG1]], %[[ARG0]], %[[ARITH_CONST]])
+  %result = "test.op_commutative"(%0, %arg0, %arg1, %arg0): (i32, i32, i32, i32) -> i32
+
+  // CHECK-NEXT: return %[[RESULT]]
+  return %result : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1180,11 +1180,21 @@
   let hasFolder = 1;
 }
 
+def TestAddIOp : TEST_Op<"addi"> {
+  let arguments = (ins I32:$op1, I32:$op2);
+  let results = (outs I32);
+}
+
 def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
   let results = (outs I32);
 }
 
+def TestLargeCommutativeOp : TEST_Op<"op_large_commutative", [Commutative]> {
+  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4, I32:$op5, I32:$op6, I32:$op7);
+  let results = (outs I32);
+}
+
 def TestCommutative2Op : TEST_Op<"op_commutative2", [Commutative]> {
   let arguments = (ins I32:$op1, I32:$op2);
   let results = (outs I32);
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestTransforms
+  TestCommutativityUtils.cpp
   TestConstantFold.cpp
   TestControlFlowSink.cpp
   TestInlining.cpp
diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -0,0 +1,49 @@
+//===- TestCommutativityUtils.cpp - Pass to test the commutativity utility-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tests the functionality of the commutativity utility pattern.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/CommutativityUtils.h"
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+
+struct CommutativityUtils
+    : public PassWrapper<CommutativityUtils, OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CommutativityUtils)
+
+  StringRef getArgument() const final { return "test-commutativity-utils"; }
+  StringRef getDescription() const final {
+    return "Test the functionality of the commutativity utility";
+  }
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    RewritePatternSet patterns(context);
+    populateCommutativityUtilsPatterns(patterns);
+
+    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerCommutativityUtils() { PassRegistration<CommutativityUtils>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -57,6 +57,7 @@
 void registerVectorizerTestPass();
 
 namespace test {
+void registerCommutativityUtils();
 void registerConvertCallOpPass();
 void registerInliner();
 void registerMemRefBoundCheck();
@@ -153,6 +154,7 @@
   registerVectorizerTestPass();
   registerTosaTestQuantUtilAPIPass();
 
+  mlir::test::registerCommutativityUtils();
   mlir::test::registerConvertCallOpPass();
   mlir::test::registerInliner();
   mlir::test::registerMemRefBoundCheck();