diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
--- a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -46,7 +46,7 @@
 ///   other.
 /// - If 2 spv.specConstant's have the same spec_id attribute value, then
 ///   replace one of them using the other.
-/// - TODO: If 2 spv.func's are identical replace one of them using the other.
+/// - If 2 spv.func's are identical replace one of them using the other.
 ///
 /// In all cases, the references to the updated symbol (whether renamed or
 /// deduplicated) are also updated to reflect the change.
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringExtras.h"

 using namespace mlir;
@@ -155,6 +156,83 @@

   return result == WalkResult::interrupt() ? failure() : success();
 }
+/// Accumulates a hash over the operations of a block; used as a structural
+/// fingerprint when looking for duplicate functions.
+///
+/// NOTE(review): every operation is hashed with
+/// `OperationEquivalence::Flags::IgnoreOperands`, so the values an operation
+/// consumes do not influence the result -- confirm this is acceptable.
+struct BlockEquivalenceData {
+  BlockEquivalenceData(Block *block);
+
+  /// Hash of the block's operations, combined in iteration order.
+  llvm::hash_code hash;
+};
+
+BlockEquivalenceData::BlockEquivalenceData(Block *block) : hash(0) {
+  // Fold each operation's hash into the running block hash.
+  for (Operation &op : *block) {
+    hash = llvm::hash_combine(
+        hash, OperationEquivalence::computeHash(
+                  &op, OperationEquivalence::Flags::IgnoreOperands));
+  }
+}
+
+/// Deduplicates the spv.func ops nested in `combinedModule`.
+///
+/// A hash is computed for each function from its attributes (the symbol name
+/// attribute is skipped, so functions differing only by name hash alike) and
+/// from the operations of each of its blocks. When
+/// `emplaceOrGetReplacementSymbolName` returns a replacement for that hash,
+/// all symbol uses of the function are redirected to the replacement and the
+/// function is erased.
+///
+/// Returns failure if the symbol uses of a duplicate cannot all be updated.
+///
+/// NOTE(review): functions are matched by hash value only; no structural
+/// comparison is made before a function is erased -- confirm that a hash
+/// collision cannot merge two different functions.
+static LogicalResult deduplicateFunctions(spirv::ModuleOp combinedModule) {
+  DenseMap<llvm::hash_code, spirv::FuncOp> hashToFunctionMap;
+
+  WalkResult result = combinedModule.walk([&](spirv::FuncOp func) {
+    llvm::hash_code funcHash(0);
+    // Hash every attribute except the symbol name.
+    for (auto attr : func.getOperation()->getAttrs()) {
+      if (attr.first == SymbolTable::getSymbolAttrName())
+        continue;
+      funcHash = llvm::hash_combine(funcHash, attr);
+    }
+
+    // Mix in the hash of every block's operations.
+    for (auto &blk : func) {
+      funcHash = llvm::hash_combine(funcHash, BlockEquivalenceData(&blk).hash);
+    }
+
+    // NOTE(review): presumably yields the function already recorded for this
+    // hash, or a null op when `func` is the first one seen; the helper is
+    // defined elsewhere -- confirm.
+    spirv::FuncOp replacementSymOp =
+        emplaceOrGetReplacementSymbolName(funcHash, func, hashToFunctionMap);
+
+    if (!replacementSymOp)
+      return WalkResult::advance();
+
+    // Redirect all references to the duplicate before it is removed.
+    if (failed(SymbolTable::replaceAllSymbolUses(
+            func, replacementSymOp.getName(), combinedModule)))
+      return WalkResult(func.emitError("unable to update all symbol uses for ")
+                        << func.sym_name() << " to "
+                        << replacementSymOp.getName());
+
+    func.erase();
+
+    return WalkResult::advance();
+  });
+
+  return result == WalkResult::interrupt() ? failure() : success();
+}
+
 namespace mlir {
 namespace spirv {

@@ -271,7 +349,8 @@
   }

   if (failed(deduplicateGlobalVariables(combinedModule)) ||
-      failed(deduplicateSpecConstants(combinedModule)))
+      failed(deduplicateSpecConstants(combinedModule)) ||
+      failed(deduplicateFunctions(combinedModule)))
     return nullptr;

   return combinedModule;
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
@@ -39,10 +39,12 @@
 // CHECK-NEXT: }

 // CHECK-NEXT: spv.func @foo_1
+// CHECK-NEXT: spv.FAdd
 // CHECK-NEXT: spv.ReturnValue
 // CHECK-NEXT: }

 // CHECK-NEXT: spv.func @foo_2
+// CHECK-NEXT: spv.ISub
 // CHECK-NEXT: spv.ReturnValue
 // CHECK-NEXT: }
 // CHECK-NEXT: }
@@ -57,13 +59,15 @@

 spv.module Logical GLSL450 {
   spv.func @foo(%arg0 : f32) -> f32 "None" {
-    spv.ReturnValue %arg0 : f32
+    %0 = spv.FAdd %arg0, %arg0 : f32
+    spv.ReturnValue %0 : f32
   }
 }

 spv.module Logical GLSL450 {
   spv.func @foo(%arg0 : i32) -> i32 "None" {
-    spv.ReturnValue %arg0 : i32
+    %0 = spv.ISub %arg0, %arg0 : i32
+    spv.ReturnValue %0 : i32
   }
 }
 }
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
@@ -15,6 +15,7 @@
 // CHECK-NEXT: spv.func @use_bar
 // CHECK-NEXT: spv._address_of @foo
 // CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.FAdd
 // CHECK-NEXT: spv.ReturnValue
 // CHECK-NEXT: }
 // CHECK-NEXT: }
@@ -37,7 +38,8 @@
   spv.func @use_bar() -> f32 "None" {
     %0 = spv._address_of @bar : !spv.ptr<f32, Input>
     %1 = spv.Load "Input" %0 : f32
-    spv.ReturnValue %1 : f32
+    %2 = spv.FAdd %1, %1 : f32
+    spv.ReturnValue %2 : f32
   }
 }
 }
@@ -107,6 +109,7 @@

 // CHECK-NEXT: spv.func @use_bar()
 // CHECK-NEXT: %0 = spv._reference_of @foo
+// CHECK-NEXT: spv.FAdd
 // CHECK-NEXT: spv.ReturnValue
 // CHECK-NEXT: }
 // CHECK-NEXT: }
@@ -127,7 +130,107 @@

   spv.func @use_bar() -> (f32) "None" {
     %0 = spv._reference_of @bar : f32
-    spv.ReturnValue %0 : f32
+    %1 = spv.FAdd %0, %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// Deduplicate functions that are identical apart from their name; functions
+// differing in body, signature, function control or attributes are kept.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @bar spec_id(5)
+
+// CHECK-NEXT: spv.func @foo(%arg0: f32)
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32)
+// CHECK-NEXT: spv._reference_of
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz(%arg0: i32)
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32)
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return_different_control
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @baz_no_return_another_control
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @kernel
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @kernel_different_attr
+// CHECK-NEXT: spv.Return
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @bar spec_id(5) = 1. : f32
+
+  spv.func @foo(%arg0: f32) -> (f32) "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" {
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @foo_different_body(%arg0: f32) -> (f32) "None" {
+    %0 = spv._reference_of @bar : f32
+    spv.ReturnValue %arg0 : f32
+  }
+
+  spv.func @baz(%arg0: i32) -> (i32) "None" {
+    spv.ReturnValue %arg0 : i32
+  }
+
+  spv.func @baz_no_return(%arg0: i32) "None" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_different_control(%arg0: i32) -> () "Inline" {
+    spv.Return
+  }
+
+  spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" {
+    spv.Return
+  }
+
+  // NOTE(review): the pointee type of %arg1 in the two kernels below was
+  // reconstructed from a garbled source; confirm it against the original test.
+  spv.func @kernel(
+    %arg0: f32,
+    %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+  attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    spv.Return
+  }
+
+  spv.func @kernel_different_attr(
+    %arg0: f32,
+    %arg1: !spv.ptr<!spv.struct<(!spv.array<12 x f32>)>, CrossWorkgroup>) "None"
+  attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} {
+    spv.Return
   }
 }
 }