diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h
@@ -32,6 +32,9 @@
 /// Creates an instance of func bufferization pass.
 std::unique_ptr<Pass> createFuncBufferizePass();
 
+/// Creates a pass that merges functions differing only in their symbol name.
+std::unique_ptr<Pass> createDuplicateFunctionEliminationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.td
@@ -40,4 +40,16 @@
                            "memref::MemRefDialect"];
 }
 
+def DuplicateFunctionEliminationPass : Pass<"duplicate-function-elimination",
+    "ModuleOp"> {
+  let summary = "Deduplicate functions";
+  let description = [{
+    Deduplicate functions that are equivalent in all aspects but their symbol
+    name. The pass chooses one representative per equivalence class, erases
+    the remainder, and updates all symbol uses accordingly. Declarations
+    without a body are never merged.
+  }];
+  let constructor = "mlir::func::createDuplicateFunctionEliminationPass()";
+}
+
 #endif // MLIR_DIALECT_FUNC_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Func/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRFuncTransforms
   DecomposeCallGraphTypes.cpp
+  DuplicateFunctionElimination.cpp
   FuncBufferize.cpp
   FuncConversions.cpp
 
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -0,0 +1,146 @@
+//===- DuplicateFunctionElimination.cpp - Duplicate function elimination --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/IR/SymbolTable.h"
+
+#include <utility>
+
+namespace mlir {
+
+#define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
+#include "mlir/Dialect/Func/Transforms/Passes.h.inc"
+
+namespace {
+
+/// DenseMap traits under which two `func.func` ops are the same key when they
+/// agree on every attribute except the symbol name and have equivalent bodies
+/// (locations are ignored). A declaration without a body is only ever equal
+/// to itself.
+struct DuplicateFuncOpEquivalenceInfo
+    : public llvm::DenseMapInfo<func::FuncOp> {
+
+  /// Hashes all attributes but the symbol name together with every operation
+  /// nested in the body. Operand and result values are left out of the hash;
+  /// `isEqual` performs the precise comparison.
+  static unsigned getHashValue(const func::FuncOp cFunc) {
+    if (!cFunc) {
+      return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
+    }
+
+    // Aggregate attributes, ignoring the symbol name.
+    llvm::hash_code hash = {};
+    func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
+    StringAttr symNameAttrName = func.getSymNameAttrName();
+    for (NamedAttribute namedAttr : cFunc->getAttrs()) {
+      StringAttr attrName = namedAttr.getName();
+      if (attrName == symNameAttrName)
+        continue;
+      hash = llvm::hash_combine(hash, namedAttr);
+    }
+
+    // Also hash the func body.
+    func.getBody().walk([&](Operation *op) {
+      hash = llvm::hash_combine(
+          hash, OperationEquivalence::computeHash(
+                    op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
+                    /*hashResults=*/OperationEquivalence::ignoreHashValue,
+                    OperationEquivalence::IgnoreLocations));
+    });
+
+    return hash;
+  }
+
+  /// Returns true if both keys denote the same op, or two function
+  /// definitions that differ at most in their symbol name and locations.
+  static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
+    if (cLhs == cRhs) {
+      return true;
+    }
+    if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
+        cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+      return false;
+    }
+
+    // Distinct declarations stand for different external definitions even if
+    // their signatures match, so they must never be merged. Without this
+    // check their empty bodies would compare as equivalent.
+    func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs);
+    func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs);
+    if (lhs.isExternal() || rhs.isExternal()) {
+      return false;
+    }
+
+    // Check attributes equivalence, ignoring the symbol name.
+    if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) {
+      return false;
+    }
+    StringAttr symNameAttrName = lhs.getSymNameAttrName();
+    for (NamedAttribute namedAttr : cLhs->getAttrs()) {
+      StringAttr attrName = namedAttr.getName();
+      if (attrName == symNameAttrName) {
+        continue;
+      }
+      if (namedAttr.getValue() != cRhs->getAttr(attrName)) {
+        return false;
+      }
+    }
+
+    // Compare inner workings.
+    return OperationEquivalence::isRegionEquivalentTo(
+        &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
+  }
+};
+
+/// Keeps one representant per class of equivalent functions held directly by
+/// the module, redirects all symbol uses of the other members to it, and
+/// erases them.
+struct DuplicateFunctionEliminationPass
+    : public impl::DuplicateFunctionEliminationPassBase<
+          DuplicateFunctionEliminationPass> {
+
+  using DuplicateFunctionEliminationPassBase<
+      DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
+
+  void runOnOperation() override {
+    ModuleOp module = getOperation();
+
+    // Find a unique representant per class of equivalent func ops. Only
+    // functions held directly by this module are considered: functions in
+    // nested symbol tables live in another scope and must not be merged with,
+    // or redirected to, functions of this one.
+    DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
+    SmallVector<std::pair<func::FuncOp, func::FuncOp>> duplicates;
+    for (func::FuncOp f : module.getOps<func::FuncOp>()) {
+      auto [repr, inserted] = uniqueFuncOps.insert(f);
+      if (!inserted) {
+        duplicates.emplace_back(f, *repr);
+      }
+    }
+
+    // Redirect every symbol use (not only `func.call`) to the representant
+    // before erasing a duplicate, so that no dangling reference remains.
+    for (auto [duplicate, representant] : duplicates) {
+      if (failed(SymbolTable::replaceAllSymbolUses(
+              duplicate, representant.getSymNameAttr(), module))) {
+        // Some use could not be rewritten; keep the duplicate alive.
+        continue;
+      }
+      duplicate.erase();
+    }
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
+  return std::make_unique<DuplicateFunctionEliminationPass>();
+}
+
+} // namespace mlir
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -0,0 +1,408 @@
+// RUN: mlir-opt %s --split-input-file --duplicate-function-elimination | \
+// RUN: FileCheck %s
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = call @identity(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = call @also_identity(%0) : (tensor<f32>) -> tensor<f32>
+  %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// CHECK: @identity
+// CHECK: @user
+// CHECK: call @identity
+// CHECK: call @identity
+// CHECK: call @identity
+
+// -----
+
+func.func @add_lr(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg0, %arg1 : f32
+  return %0 : f32
+}
+
+func.func @also_add_lr(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg0, %arg1 : f32
+  return %0 : f32
+}
+
+func.func @add_rl(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg1, %arg0 : f32
+  return %0 : f32
+}
+
+func.func @also_add_rl(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = arith.addf %arg1, %arg0 : f32
+  return %0 : f32
+}
+
+func.func @user(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = call @add_lr(%arg0, %arg1) : (f32, f32) -> f32
+  %1 = call @also_add_lr(%arg0, %arg1) : (f32, f32) -> f32
+  %2 = call @add_rl(%0, %1) : (f32, f32) -> f32
+  %3 = call @also_add_rl(%arg0, %2) : (f32, f32) -> f32
+  return %3 : f32
+}
+
+// CHECK: @add_lr
+// CHECK: @user
+// CHECK: call @add_lr
+// CHECK: call @add_lr
+// CHECK: call @add_lr
+// CHECK: call @add_lr
+
+// -----
+
+func.func @ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %then : f32
+  } else {
+    scf.yield %else : f32
+  }
+  return %0 : f32
+}
+
+func.func @also_ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %then : f32
+  } else {
+    scf.yield %else : f32
+  }
+  return %0 : f32
+}
+
+func.func @reverse_ite(%pred: i1, %then: f32, %else: f32) -> f32 {
+  %0 = scf.if %pred -> f32 {
+    scf.yield %else : f32
+  } else {
+    scf.yield %then : f32
+  }
+  return %0 : f32
+}
+
+func.func @user(%pred : i1, %arg0: f32, %arg1: f32) -> f32 {
+  %0 = call @also_ite(%pred, %arg0, %arg1) : (i1, f32, f32) -> f32
+  %1 = call @ite(%pred, %arg0, %arg1) : (i1, f32, f32) -> f32
+  %2 = call @reverse_ite(%pred, %0, %1) : (i1, f32, f32) -> f32
+  return %2 : f32
+}
+
+// CHECK: @ite
+// CHECK: @reverse_ite
+// CHECK: @user
+// CHECK: call @ite
+// CHECK: call @ite
+// CHECK: call @reverse_ite
+
+// -----
+
+func.func @deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32, %odd: f32)
+    -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+func.func @also_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
+    %odd: f32) -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+func.func @reverse_deep_tree(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %even: f32,
+    %odd: f32) -> f32 {
+  %0 = scf.if %p0 -> f32 {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  } else {
+    %1 = scf.if %p1 -> f32 {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    } else {
+      %2 = scf.if %p2 -> f32 {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %odd : f32
+        } else {
+          scf.yield %even : f32
+        }
+        scf.yield %3 : f32
+      } else {
+        %3 = scf.if %p3 -> f32 {
+          scf.yield %even : f32
+        } else {
+          scf.yield %odd : f32
+        }
+        scf.yield %3 : f32
+      }
+      scf.yield %2 : f32
+    }
+    scf.yield %1 : f32
+  }
+  return %0 : f32
+}
+
+func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32)
+    -> (f32, f32, f32) {
+  %0 = call @deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  %1 = call @also_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  %2 = call @reverse_deep_tree(%p0, %p1, %p2, %p3, %odd, %even)
+      : (i1, i1, i1, i1, f32, f32) -> f32
+  return %0, %1, %2 : f32, f32, f32
+}
+
+// CHECK: @deep_tree
+// CHECK: @reverse_deep_tree
+// CHECK: @user
+// CHECK: call @deep_tree
+// CHECK: call @deep_tree
+// CHECK: call @reverse_deep_tree
+
+// -----
+
+// Declarations are never merged, even with identical signatures.
+
+func.func private @ext_a(f32) -> f32
+func.func private @ext_b(f32) -> f32
+
+func.func @user(%arg0: f32) -> f32 {
+  %0 = call @ext_a(%arg0) : (f32) -> f32
+  %1 = call @ext_b(%0) : (f32) -> f32
+  return %1 : f32
+}
+
+// CHECK: @ext_a
+// CHECK: @ext_b
+// CHECK: @user
+// CHECK: call @ext_a
+// CHECK: call @ext_b
+
+// -----
+
+// Symbol uses other than calls are redirected as well.
+
+func.func @identity(%arg0: f32) -> f32 {
+  return %arg0 : f32
+}
+
+func.func @also_identity(%arg0: f32) -> f32 {
+  return %arg0 : f32
+}
+
+func.func @user() -> ((f32) -> f32) {
+  %0 = func.constant @also_identity : (f32) -> f32
+  return %0 : (f32) -> f32
+}
+
+// CHECK: @identity
+// CHECK-NOT: @also_identity
+// CHECK: @user
+// CHECK: constant @identity