diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
--- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp
@@ -9,7 +9,8 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/Passes.h"
 
-namespace mlir {
+using namespace mlir;
+
 namespace {
 
 #define GEN_PASS_DEF_DUPLICATEFUNCTIONELIMINATIONPASS
@@ -21,9 +22,8 @@
     : public llvm::DenseMapInfo<func::FuncOp> {
 
   static unsigned getHashValue(const func::FuncOp cFunc) {
-    if (!cFunc) {
+    if (!cFunc)
       return DenseMapInfo<func::FuncOp>::getHashValue(cFunc);
-    }
 
     // Aggregate attributes, ignoring the symbol name.
     llvm::hash_code hash = {};
@@ -49,28 +49,23 @@
   }
 
   static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) {
-    if (cLhs == cRhs) {
+    if (cLhs == cRhs)
       return true;
-    }
     if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() ||
-        cRhs == getTombstoneKey() || cRhs == getEmptyKey()) {
+        cRhs == getTombstoneKey() || cRhs == getEmptyKey())
       return false;
-    }
 
     // Check attributes equivalence, ignoring the symbol name.
-    if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) {
+    if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size())
       return false;
-    }
     func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs);
     StringAttr symNameAttrName = lhs.getSymNameAttrName();
     for (NamedAttribute namedAttr : cLhs->getAttrs()) {
       StringAttr attrName = namedAttr.getName();
-      if (attrName == symNameAttrName) {
+      if (attrName == symNameAttrName)
         continue;
-      }
-      if (namedAttr.getValue() != cRhs->getAttr(attrName)) {
+      if (namedAttr.getValue() != cRhs->getAttr(attrName))
         return false;
-      }
     }
 
     // Compare inner workings.
@@ -88,37 +83,56 @@
       DuplicateFunctionEliminationPass>::DuplicateFunctionEliminationPassBase;
 
   void runOnOperation() override {
-    auto module = getOperation();
-
-    // Find unique representant per equivalent func ops.
-    DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
-    DenseMap<StringAttr, func::FuncOp> getRepresentant;
-    DenseSet<func::FuncOp> toBeErased;
-    module.walk([&](func::FuncOp f) {
-      auto [repr, inserted] = uniqueFuncOps.insert(f);
-      getRepresentant[f.getSymNameAttr()] = *repr;
-      if (!inserted) {
-        toBeErased.insert(f);
-      }
-    });
-
-    // Update call ops to call unique func op representants.
-    module.walk([&](func::CallOp callOp) {
-      func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()];
-      callOp.setCallee(callee.getSymName());
-    });
-
-    // Erase redundant func ops.
-    for (auto it : toBeErased) {
-      it.erase();
-    }
+    // Deduplicate every module separately: func.call resolves its flat callee
+    // symbol in the nearest enclosing symbol table, so functions defined in
+    // different (nested) modules must never be merged with each other.
+    // Erasing inside the callback is safe because the post-order walk has
+    // already finished visiting everything nested in `module`.
+    getOperation().walk(
+        [](ModuleOp module) { eliminateDuplicateFunctions(module); });
   }
+
+private:
+  /// Deduplicates the func.func ops defined directly in `module`'s body: every
+  /// function structurally equivalent to an earlier one is replaced by that
+  /// earlier representant. Calls resolved through `module` are redirected to
+  /// the representant and the redundant functions are erased.
+  static void eliminateDuplicateFunctions(ModuleOp module) {
+    // Find unique representant per equivalent func ops.
+    DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps;
+    DenseMap<StringAttr, func::FuncOp> getRepresentant;
+    SmallVector<func::FuncOp> toBeErased;
+    for (func::FuncOp f : module.getOps<func::FuncOp>()) {
+      auto [repr, inserted] = uniqueFuncOps.insert(f);
+      getRepresentant[f.getSymNameAttr()] = *repr;
+      if (!inserted)
+        toBeErased.push_back(f);
+    }
+
+    // Update call ops to call unique func op representants. Only calls whose
+    // nearest symbol table is `module` resolve against the functions above;
+    // calls nested in another symbol table (e.g. an inner module) are skipped.
+    module.walk([&](func::CallOp callOp) {
+      if (SymbolTable::getNearestSymbolTable(callOp) != module.getOperation())
+        return;
+      func::FuncOp callee =
+          getRepresentant.lookup(callOp.getCalleeAttr().getAttr());
+      // Leave calls to symbols not defined in `module` untouched instead of
+      // dereferencing a null function.
+      if (!callee)
+        return;
+      callOp.setCallee(callee.getSymName());
+    });
+
+    // Erase redundant func ops; every call resolved through `module` now
+    // targets a representant.
+    for (func::FuncOp f : toBeErased)
+      f.erase();
+  }
 };
 
 } // namespace
 
 std::unique_ptr<Pass> mlir::func::createDuplicateFunctionEliminationPass() {
   return std::make_unique<DuplicateFunctionEliminationPass>();
 }
-
-} // namespace mlir
diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
--- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
+++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir
@@ -359,9 +359,56 @@
   return %0, %1, %2 : f32, f32, f32
 }
 
 // CHECK:     @deep_tree
 // CHECK-NOT: @also_deep_tree
 // CHECK:     @reverse_deep_tree
 // CHECK:     @user
 // CHECK-2:   call @deep_tree
 // CHECK:     call @reverse_deep_tree
+
+// -----
+
+module {
+
+  func.func @inner_identity(%arg0: tensor<f32>) -> tensor<f32> {
+    return %arg0 : tensor<f32>
+  }
+
+  func.func @also_inner_identity(%arg0: tensor<f32>) -> tensor<f32> {
+    return %arg0 : tensor<f32>
+  }
+
+  func.func @inner_user(%arg0: tensor<f32>) -> tensor<f32> {
+    %0 = call @inner_identity(%arg0) : (tensor<f32>) -> tensor<f32>
+    %1 = call @also_inner_identity(%0) : (tensor<f32>) -> tensor<f32>
+    return %1 : tensor<f32>
+  }
+
+}
+
+func.func @identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> {
+  return %arg0 : tensor<f32>
+}
+
+func.func @user(%arg0: tensor<f32>) -> tensor<f32> {
+  %0 = call @identity(%arg0) : (tensor<f32>) -> tensor<f32>
+  %1 = call @also_identity(%0) : (tensor<f32>) -> tensor<f32>
+  return %1 : tensor<f32>
+}
+
+// CHECK:     module
+// CHECK:     module
+// CHECK:     @inner_identity
+// CHECK-NOT: @also_inner_identity
+// CHECK:     @inner_user
+// CHECK:     call @inner_identity
+// CHECK:     call @inner_identity
+// CHECK:     @identity
+// CHECK-NOT: @also_identity
+// CHECK:     @user
+// CHECK:     call @identity
+// CHECK:     call @identity