diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -28,6 +28,7 @@
 // Passes
 //===----------------------------------------------------------------------===//
 
 #define GEN_PASS_DECL_CANONICALIZER
+#define GEN_PASS_DECL_COMMUTATIVEOPERANDSORT
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
@@ -61,6 +62,9 @@
                         ArrayRef<std::string> disabledPatterns = std::nullopt,
                         ArrayRef<std::string> enabledPatterns = std::nullopt);
 
+/// Creates a pass to deterministically sort commutative operation operands.
+std::unique_ptr<Pass> createCommutativeOperandSortPass();
+
 /// Creates a pass to perform control-flow sinking.
 std::unique_ptr<Pass> createControlFlowSinkPass();
 
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -45,6 +45,20 @@
   ] # RewritePassUtils.options;
 }
 
+def CommutativeOperandSort : Pass<"commutative-operand-sort"> {
+  let summary = "Sort operands of commutative operations deterministically";
+  let description = [{
+    This pass sorts the operands of commutative operations in a deterministic
+    order. This way, commutative operations with swapped operands become
+    identical and can be deduplicated by the CSE pass.
+
+    The sorting algorithm looks at the backward slice of each operand
+    to establish an ordering. See
+    mlir/lib/Transforms/Utils/CommutativityUtils.cpp for more details.
+  }];
+  let constructor = "mlir::createCommutativeOperandSortPass()";
+}
+
 def ControlFlowSink : Pass<"control-flow-sink"> {
   let summary = "Sink operations into conditional blocks";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CommutativeOperandSort.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CommutativeOperandSort.cpp b/mlir/lib/Transforms/CommutativeOperandSort.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/CommutativeOperandSort.cpp
@@ -0,0 +1,55 @@
+//===- CommutativeOperandSort.cpp - Commutative Operand Sorting Pass ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass sorts the operands of commutative operations in a
+// deterministic manner, so that operations which differ only in the order of
+// their operands become identical and can be deduplicated by CSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/CommutativityUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_COMMUTATIVEOPERANDSORT
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+struct CommutativeOperandSortPass
+    : public impl::CommutativeOperandSortBase<CommutativeOperandSortPass> {
+  /// Builds and freezes the sorting patterns once, instead of on every
+  /// `runOnOperation` call. Copies of the pass share the frozen set.
+  /// (Passes with no state to set up would not need this hook.)
+  LogicalResult initialize(MLIRContext *context) override {
+    RewritePatternSet owningPatterns(context);
+    populateCommutativityUtilsPatterns(owningPatterns);
+    patterns = FrozenRewritePatternSet(std::move(owningPatterns));
+    return success();
+  }
+
+  void runOnOperation() override {
+    // Sorting is a best-effort normalization (like canonicalization): if the
+    // greedy driver hits its iteration limit the IR is still valid, just not
+    // fully sorted, so non-convergence is not treated as a pass failure.
+    (void)applyPatternsAndFoldGreedily(getOperation(), patterns);
+  }
+
+  /// Operand-sorting patterns, populated in `initialize`.
+  FrozenRewritePatternSet patterns;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCommutativeOperandSortPass() {
+  return std::make_unique<CommutativeOperandSortPass>();
+}
diff --git a/mlir/test/Transforms/commutative-operand-sort.mlir b/mlir/test/Transforms/commutative-operand-sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/commutative-operand-sort.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -commutative-operand-sort -cse | FileCheck %s
+
+// CHECK-LABEL: func @check_commutative_cse
+func.func @check_commutative_cse(%a : i32, %b : i32) -> i32 {
+  %0 = arith.subi %a, %b : i32
+  %1 = arith.divsi %a, %b : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %2 = arith.addi %0, %1 : i32
+  %3 = arith.addi %1, %0 : i32
+  // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %4 = arith.muli %2, %3 : i32
+  return %4 : i32
+}
diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
--- a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
+++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp
@@ -36,7 +36,8 @@
     RewritePatternSet patterns(context);
     populateCommutativityUtilsPatterns(patterns);
 
-    (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
+    FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+    (void)applyPatternsAndFoldGreedily(func, frozenPatterns);
   }
 };
 } // namespace