diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -8,8 +8,10 @@
 
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/IR/Visitors.h"
+
+using namespace mlir;
 
-namespace mlir {
 namespace {
 
 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
@@ -18,26 +20,29 @@
 // Define a notion of function equivalence that allows for reuse. Ignore the
 // symbol name for this purpose.
 struct DuplicateFuncOpEquivalenceInfo
-    : public llvm::DenseMapInfo<func::FuncOp> {
-
-  static unsigned getHashValue(const func::FuncOp cFunc) {
-    if (!cFunc) {
-      return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
-    }
-
-    // Aggregate attributes, ignoring the symbol name.
-    llvm::hash_code hash = {};
-    func::FuncOp func = const_cast<func::FuncOp &>(cFunc);
-    StringAttr symNameAttrName = func.getSymNameAttrName();
-    for (NamedAttribute namedAttr : cFunc->getAttrs()) {
-      StringAttr attrName = namedAttr.getName();
-      if (attrName == symNameAttrName)
-        continue;
-      hash = llvm::hash_combine(hash, namedAttr);
-    }
+    : public llvm::DenseMapInfo<FunctionOpInterface> {
+
+  /// Returns true if `attr` holds the symbol name, which is ignored when
+  /// comparing function ops for equivalence.
+  static bool isSymbolNameAttr(NamedAttribute attr) {
+    return attr.getName().getValue() == SymbolTable::getSymbolAttrName();
+  }
+
+  static unsigned getHashValue(const FunctionOpInterface cFunc) {
+    if (!cFunc)
+      return DenseMapInfo<FunctionOpInterface>::getHashValue(cFunc);
+
+    // Base hash on function type and attributes, ignoring the symbol name.
+    // Attributes must be hashed as well because `isEqual` compares them.
+    FunctionOpInterface func = cFunc;
+    llvm::hash_code hash = llvm::hash_combine(func.getFunctionType());
+    for (NamedAttribute namedAttr : func->getAttrs()) {
+      if (!isSymbolNameAttr(namedAttr))
+        hash = llvm::hash_combine(hash, namedAttr);
+    }
 
     // Also hash the func body.
-    func.getBody().walk([&](Operation *op) {
+    func.getFunctionBody().walk([&](Operation *op) {
       hash = llvm::hash_combine(
           hash, OperationEquivalence::computeHash(
                     op, /*hashOperands=*/OperationEquivalence::ignoreHashValue,
@@ -48,35 +53,33 @@
     return hash;
   }
 
-  static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
-    if (cLhs == cRhs) {
+  static bool isEqual(const FunctionOpInterface cLhs,
+                      const FunctionOpInterface cRhs) {
+    if (cLhs == cRhs)
       return true;
-    }
     if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
-        cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+        cRhs == getTombstoneKey() || cRhs == getEmptyKey())
       return false;
-    }
 
-    // Check attributes equivalence, ignoring the symbol name.
-    if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) {
+    // Check function type equivalence.
+    FunctionOpInterface lhs = cLhs;
+    FunctionOpInterface rhs = cRhs;
+    if (lhs.getFunctionType() != rhs.getFunctionType())
       return false;
-    }
-    func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs);
-    StringAttr symNameAttrName = lhs.getSymNameAttrName();
-    for (NamedAttribute namedAttr : cLhs->getAttrs()) {
-      StringAttr attrName = namedAttr.getName();
-      if (attrName == symNameAttrName) {
-        continue;
-      }
-      if (namedAttr.getValue() != cRhs->getAttr(attrName)) {
-        return false;
-      }
-    }
+
+    // Check attribute equivalence, ignoring the symbol name. Attributes are
+    // kept sorted by name, so an element-wise comparison suffices.
+    auto isCompared = [](NamedAttribute attr) {
+      return !isSymbolNameAttr(attr);
+    };
+    if (!llvm::equal(llvm::make_filter_range(lhs->getAttrs(), isCompared),
+                     llvm::make_filter_range(rhs->getAttrs(), isCompared)))
+      return false;
 
     // Compare inner workings.
-    func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs);
     return OperationEquivalence::isRegionEquivalentTo(
-        &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations);
+        &lhs.getFunctionBody(), &rhs.getFunctionBody(),
+        OperationEquivalence::IgnoreLocations);
   }
 };
 
@@ -88,37 +91,58 @@
       DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
 
   void runOnOperation() override {
-    auto module = getOperation();
-
-    // Find unique representant per equivalent func ops.
-    DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
-    DenseMap<StringAttr, func::FuncOp> getRepresentant;
-    DenseSet<func::FuncOp> toBeErased;
-    module.walk([&](func::FuncOp f) {
-      auto [repr, inserted] = uniqueFuncOps.insert(f);
-      getRepresentant[f.getSymNameAttr()] = *repr;
-      if (!inserted) {
-        toBeErased.insert(f);
-      }
-    });
-
-    // Update call ops to call unique func op representants.
-    module.walk([&](func::CallOp callOp) {
-      func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
-      callOp.setCallee(callee.getSymName());
-    });
-
-    // Erase redundant func ops.
-    for (auto it : toBeErased) {
-      it.erase();
-    }
+    // Deduplicate within every symbol table op independently, so that calls
+    // are only ever redirected to functions of their own symbol table.
+    getOperation().walk([](Operation *root) {
+      if (root->hasTrait<OpTrait::SymbolTable>())
+        deduplicateInSymbolTable(root);
+    });
+  }
+
+private:
+  /// Redirects the `func.call` ops under `root` to a single representant
+  /// among the equivalent function ops defined directly in `root`. Inner
+  /// symbol tables are not entered; the pass visits them separately.
+  /// Duplicates are not erased: only `func.call` users are rewritten, so
+  /// other symbol users may still reference them.
+  static void deduplicateInSymbolTable(Operation *root) {
+    // Find unique representant per equivalent function ops. Only direct
+    // children of the root op are symbols of its symbol table.
+    DenseSet<FunctionOpInterface, DuplicateFuncOpEquivalenceInfo>
+        uniqueFuncOps;
+    DenseMap<StringAttr, FunctionOpInterface> getRepresentant;
+    for (Region &region : root->getRegions()) {
+      for (Block &block : region) {
+        for (auto f : block.getOps<FunctionOpInterface>()) {
+          FunctionOpInterface repr = *uniqueFuncOps.insert(f).first;
+          getRepresentant[f.getNameAttr()] = repr;
+        }
+      }
+    }
+
+    // Update call sites. The walk must be pre-order: only then is an inner
+    // symbol table op visited before its regions, so that returning
+    // `WalkResult::skip()` keeps its calls from being redirected here.
+    root->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
+      // Don't enter inner symbol tables.
+      if (op != root && op->hasTrait<OpTrait::SymbolTable>())
+        return WalkResult::skip();
+
+      if (auto callOp = llvm::dyn_cast<func::CallOp>(op)) {
+        // Leave calls to symbols that are not function ops of this table
+        // untouched rather than dereferencing a null representant.
+        if (FunctionOpInterface callee =
+                getRepresentant.lookup(callOp.getCalleeAttr().getAttr()))
+          callOp.setCallee(callee.getNameAttr());
+      }
+
+      return WalkResult::advance();
+    });
   }
 };
 
 } // namespace
 
-std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
+std::unique_ptr<Pass> func::createDuplicateFunctionEliminationPass() {
   return std::make_unique<DuplicateFunctionEliminationPass>();
 }
-
-} // namespace mlir
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -20,11 +20,11 @@
   return %2 : tensor<f32>
 }
 
-// CHECK:     @identity
-// CHECK-NOT: @also_identity
-// CHECK-NOT: @yet_another_identity
-// CHECK:     @user
-// CHECK-3:   call @identity
+// CHECK:   @identity
+// CHECK:   @also_identity
+// CHECK:   @yet_another_identity
+// CHECK:   @user
+// CHECK-3: call @identity
 
 // -----
 
@@ -56,12 +56,12 @@
   return %3 : f32
 }
 
-// CHECK:     @add_lr
-// CHECK-NOT: @also_add_lr
-// CHECK-NOT: @add_rl
-// CHECK-NOT: @also_add_rl
-// CHECK:     @user
-// CHECK-4:   call @add_lr
+// CHECK:   @add_lr
+// CHECK:   @also_add_lr
+// CHECK:   @add_rl
+// CHECK:   @also_add_rl
+// CHECK:   @user
+// CHECK-4: call @add_lr
 
 // -----
 
@@ -99,12 +99,12 @@
   return %2 : f32
 }
 
-// CHECK:     @ite
-// CHECK-NOT: @also_ite
-// CHECK:     @reverse_ite
-// CHECK:     @user
-// CHECK-2:   call @ite
-// CHECK:     call @reverse_ite
+// CHECK:   @ite
+// CHECK:   @also_ite
+// CHECK:   @reverse_ite
+// CHECK:   @user
+// CHECK-2: call @ite
+// CHECK:   call @reverse_ite
 
 // -----
 
@@ -359,9 +359,81 @@
   return %0, %1, %2 : f32, f32, f32
 }
 
-// CHECK:     @deep_tree
-// CHECK-NOT: @also_deep_tree
-// CHECK:     @reverse_deep_tree
-// CHECK:     @user
-// CHECK-2:   call @deep_tree
-// CHECK:     call @reverse_deep_tree
+// CHECK:   @deep_tree
+// CHECK:   @also_deep_tree
+// CHECK:   @reverse_deep_tree
+// CHECK:   @user
+// CHECK-2: call @deep_tree
+// CHECK:   call @reverse_deep_tree
+
+// -----
+
+module {
+
+  func.func @inner_identity(%arg0: tensor<f32>) -> tensor<f32> {
+    return %arg0 : tensor<f32>
+  }
+
+  func.func @also_inner_identity(%arg0: tensor<f32>) -> tensor<f32> {
+    return %arg0 : tensor<f32>
+  }
+
+  func.func @inner_user(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = call @inner_identity(%arg0) : (tensor<f32>) -> tensor<f32>
+    %1 = call @also_inner_identity(%0) : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+
+}
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = call @identity(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = call @also_identity(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK: module
+// CHECK: module
+// CHECK: @inner_identity
+// CHECK: @also_inner_identity
+// CHECK: @inner_user
+// CHECK: call @inner_identity
+// CHECK: call @inner_identity
+// CHECK: @identity
+// CHECK: @also_identity
+// CHECK: @user
+// CHECK: call @identity
+// CHECK: call @identity
+
+// -----
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @marked_identity(%arg0: tensor<f32>) -> tensor<f32>
+    attributes {custom_marker} {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = call @identity(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = call @marked_identity(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// Functions that differ in attributes other than the symbol name are not
+// equivalent.
+// CHECK: @identity
+// CHECK: @marked_identity
+// CHECK: @user
+// CHECK: call @identity
+// CHECK: call @marked_identity