diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -28,6 +28,7 @@
 // Passes
 //===----------------------------------------------------------------------===//
 
+#define GEN_PASS_DECL_COMMUTATIVEOPERANDSORT
 #define GEN_PASS_DECL_CANONICALIZER
 #define GEN_PASS_DECL_CONTROLFLOWSINK
 #define GEN_PASS_DECL_CSEPASS
@@ -44,6 +45,9 @@
 #define GEN_PASS_DECL_TOPOLOGICALSORT
 #include "mlir/Transforms/Passes.h.inc"
 
+/// Creates a pass that deterministically sorts commutative op operands.
+std::unique_ptr<Pass> createCommutativeOperandSortPass();
+
 /// Creates an instance of the Canonicalizer pass, configured with default
 /// settings (which can be overridden by pass options on the command line).
 std::unique_ptr<Pass> createCanonicalizerPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -16,6 +16,16 @@
 include "mlir/Pass/PassBase.td"
 include "mlir/Rewrite/PassUtil.td"
 
+def CommutativeOperandSort : Pass<"commutative-operand-sort"> {
+  let summary = "Sort operands of commutative operations deterministically";
+  let description = [{
+    This pass sorts the operands of commutative operations in a deterministic
+    order. This way commutative operations with swapped operands can be
+    treated as equal and removed by the CSE pass.
+  }];
+  let constructor = "mlir::createCommutativeOperandSortPass()";
+}
+
 def Canonicalizer : Pass<"canonicalize"> {
   let summary = "Canonicalize operations";
   let description = [{
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
 
 add_mlir_library(MLIRTransforms
   Canonicalizer.cpp
+  CommutativeOperandSort.cpp
   ControlFlowSink.cpp
   CSE.cpp
   GenerateRuntimeVerification.cpp
diff --git a/mlir/lib/Transforms/CommutativeOperandSort.cpp b/mlir/lib/Transforms/CommutativeOperandSort.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/CommutativeOperandSort.cpp
@@ -0,0 +1,50 @@
+//===- CommutativeOperandSort.cpp - Commutative Operand Sorting Pass ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass sorts operands of commutative operations in a
+// deterministic manner so they can be compared and removed in CSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/CommutativityUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_COMMUTATIVEOPERANDSORT
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+/// Pass that greedily applies the patterns added by
+/// populateCommutativityUtilsPatterns to the operation it runs on.
+struct CommutativeOperandSortPass
+    : public impl::CommutativeOperandSortBase<CommutativeOperandSortPass> {
+  CommutativeOperandSortPass() = default;
+
+  void runOnOperation() override {
+    // Collect the commutative operand sorting patterns.
+    RewritePatternSet patterns(&getContext());
+    populateCommutativityUtilsPatterns(patterns);
+
+    // Report pass failure if the greedy driver returns failure.
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
+// Create a Commutative operand sorting pass
+std::unique_ptr<Pass> mlir::createCommutativeOperandSortPass() {
+  return std::make_unique<CommutativeOperandSortPass>();
+}
diff --git a/mlir/test/Transforms/commutative-operand-sort.mlir b/mlir/test/Transforms/commutative-operand-sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/commutative-operand-sort.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -commutative-operand-sort -cse | FileCheck %s
+
+// CHECK-LABEL: func @check_commutative_cse
+func.func @check_commutative_cse(%a : i32, %b : i32) -> i32 {
+  %0 = arith.subi %a, %b : i32
+  %1 = arith.divsi %a, %b : i32
+  // CHECK: %[[ADD1:.*]] = arith.addi %{{.*}}, %{{.*}} : i32
+  %2 = arith.addi %0, %1 : i32
+  %3 = arith.addi %1, %0 : i32
+  // CHECK-NEXT: arith.muli %[[ADD1]], %[[ADD1]] : i32
+  %4 = arith.muli %2, %3 : i32
+  return %4 : i32
+}