diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h --- a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h +++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h @@ -28,7 +28,7 @@ /// combination process proceeds in 2 phases: /// /// (1) resolve conflicts between pairs of ops from different modules -/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO) +/// (2) deduplicate equivalent ops/sub-ops in the merged module. /// /// For the conflict resolution phase, the following rules are employed to /// resolve such conflicts: @@ -39,13 +39,22 @@ /// other symbol. /// - If none of the 2 conflicting ops are spv.func, then rename either. /// -/// In all cases, the references to the updated symbol are also updated to -/// reflect the change. +/// For deduplication, the following 3 cases are taken into consideration: +/// +/// - If 2 spv.globalVariable's have either the same descriptor set + binding +/// or the same built_in attribute value, then replace one of them using the +/// other. +/// - If 2 spv.specConstant's have the same spec_id attribute value, then +/// replace one of them using the other. +/// - If 2 spv.func's are identical, replace one of them using the other. +/// +/// In all cases, the references to the updated symbol (whether renamed or +/// deduplicated) are also updated to reflect the change. /// /// \param modules the list of modules to combine. Input modules are not /// modified. /// \param combinedMdouleBuilder an OpBuilder to be used for -/// building up the combined module. +/// building up the combined module. /// \param symbRenameListener a listener that gets called everytime a symbol in /// one of the input modules is renamed.
The arguments /// passed to the listener are: the input diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -12,10 +12,12 @@ #include "mlir/Dialect/SPIRV/ModuleCombiner.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/SymbolTable.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; @@ -59,6 +61,56 @@ return success(); } +template +static SymbolOpTy +emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp, + DenseMap &deduplicationMap) { + auto result = deduplicationMap.try_emplace(key, symbolOp); + + if (result.second) + return SymbolOpTy(); + + return result.first->second; +} + +/// Computes a hash code to represent the argument SymbolOpInterface based on +/// all the Op's attributes except for the symbol name. +/// +/// \return the hash code computed from the Op's attributes as described above. +/// +/// Note: We use the operation's name (not the symbol name) as part of the hash +/// computation. This prevents, for example, mistakenly considering a global +/// variable and a spec constant as duplicates because their descriptor set + +/// binding and spec_id, respectively, happen to hash to the same value.
+static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { + llvm::hash_code hashCode(0); + hashCode = llvm::hash_combine(symbolOp.getOperation()->getName()); + + for (auto attr : symbolOp.getOperation()->getAttrs()) { + if (attr.first == SymbolTable::getSymbolAttrName()) + continue; + hashCode = llvm::hash_combine(hashCode, attr); + } + + return hashCode; +} + +struct BlockEquivalenceData { + BlockEquivalenceData(Block *block); + + llvm::hash_code hash; +}; + +BlockEquivalenceData::BlockEquivalenceData(Block *block) : hash(0) { + // TODO: Properly handle operations with regions. + for (Operation &op : *block) + hash = llvm::hash_combine( + hash, (op.getNumRegions() == 0) + ? OperationEquivalence::computeHash( + &op, OperationEquivalence::Flags::IgnoreOperands) + : llvm::hash_code(0)); +} + namespace mlir { namespace spirv { @@ -174,6 +226,43 @@ combinedModuleBuilder.insert(op.clone()); } + // Deduplicate identical global variables, spec constants, and functions. + DenseMap hashToSymbolOp; + SmallVector eraseList; + + for (auto &op : combinedModule.getBlock().without_terminator()) { + llvm::hash_code hashCode(0); + SymbolOpInterface symbolOp = dyn_cast(op); + + if (!symbolOp) + continue; + + hashCode = computeHash(symbolOp); + + if (auto funcOp = dyn_cast(op)) + for (auto &blk : funcOp) + hashCode = + llvm::hash_combine(hashCode, BlockEquivalenceData(&blk).hash); + + SymbolOpInterface replacementSymOp = + emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp); + + if (!replacementSymOp) + continue; + + if (failed(SymbolTable::replaceAllSymbolUses( + symbolOp, replacementSymOp.getName(), combinedModule))) { + symbolOp.emitError("unable to update all symbol uses for ") + << symbolOp.getName() << " to " << replacementSymOp.getName(); + return nullptr; + } + + eraseList.push_back(symbolOp); + } + + for (auto symbolOp : eraseList) + symbolOp.erase(); + return combinedModule; } diff --git 
a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir --- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir @@ -39,10 +39,12 @@ // CHECK-NEXT: } // CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.FAdd // CHECK-NEXT: spv.ReturnValue // CHECK-NEXT: } // CHECK-NEXT: spv.func @foo_2 +// CHECK-NEXT: spv.ISub // CHECK-NEXT: spv.ReturnValue // CHECK-NEXT: } // CHECK-NEXT: } @@ -57,13 +59,15 @@ spv.module Logical GLSL450 { spv.func @foo(%arg0 : f32) -> f32 "None" { - spv.ReturnValue %arg0 : f32 + %0 = spv.FAdd %arg0, %arg0 : f32 + spv.ReturnValue %0 : f32 } } spv.module Logical GLSL450 { spv.func @foo(%arg0 : i32) -> i32 "None" { - spv.ReturnValue %arg0 : i32 + %0 = spv.ISub %arg0, %arg0 : i32 + spv.ReturnValue %0 : i32 } } } @@ -578,9 +582,9 @@ // CHECK: module { // CHECK-NEXT: spv.module Logical GLSL450 { -// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: spv.globalVariable @foo_1 bind(1, 0) -// CHECK-NEXT: spv.globalVariable @foo +// CHECK-NEXT: spv.globalVariable @foo bind(2, 0) // CHECK-NEXT: } module { @@ -589,7 +593,26 @@ } spv.module Logical GLSL450 { - spv.globalVariable @foo bind(1, 0) : !spv.ptr + spv.globalVariable @foo bind(2, 0) : !spv.ptr +} +} + +// ----- + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 built_in("GlobalInvocationId") + +// CHECK-NEXT: spv.globalVariable @foo built_in("LocalInvocationId") +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr, Input> +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr, Input> } } diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir 
b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir @@ -0,0 +1,244 @@ +// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s + +// Deduplicate 2 global variables with the same descriptor set and binding. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo + +// CHECK-NEXT: spv.func @use_foo +// CHECK-NEXT: spv._address_of @foo +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @use_bar +// CHECK-NEXT: spv._address_of @foo +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.FAdd +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr + + spv.func @use_foo() -> f32 "None" { + %0 = spv._address_of @foo : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} + +spv.module Logical GLSL450 { + spv.globalVariable @bar bind(1, 0) : !spv.ptr + + spv.func @use_bar() -> f32 "None" { + %0 = spv._address_of @bar : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + %2 = spv.FAdd %1, %1 : f32 + spv.ReturnValue %2 : f32 + } +} +} + +// ----- + +// Do not deduplicate 2 global variables with the same descriptor set and binding but different types.
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo bind(1, 0) + +// CHECK-NEXT: spv.globalVariable @bar bind(1, 0) + +// CHECK-NEXT: spv.func @use_bar +// CHECK-NEXT: spv._address_of @bar +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.globalVariable @bar bind(1, 0) : !spv.ptr + + spv.func @use_bar() -> f32 "None" { + %0 = spv._address_of @bar : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} +} + +// ----- + +// Deduplicate 2 global variables with the same built-in attribute. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo built_in("GlobalInvocationId") +// CHECK-NEXT: spv.func @use_bar +// CHECK-NEXT: spv._address_of @foo +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr, Input> +} + +spv.module Logical GLSL450 { + spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr, Input> + + spv.func @use_bar() -> vector<3xi32> "None" { + %0 = spv._address_of @bar : !spv.ptr, Input> + %1 = spv.Load "Input" %0 : vector<3xi32> + spv.ReturnValue %1 : vector<3xi32> + } +} +} + +// ----- + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @foo spec_id(5) + +// CHECK-NEXT: spv.func @use_foo() +// CHECK-NEXT: %0 = spv._reference_of @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @use_bar() +// CHECK-NEXT: %0 = spv._reference_of @foo +// CHECK-NEXT: spv.FAdd +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + 
spv.specConstant @foo spec_id(5) = 1. : f32 + + spv.func @use_foo() -> (f32) "None" { + %0 = spv._reference_of @foo : f32 + spv.ReturnValue %0 : f32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @bar spec_id(5) = 1. : f32 + + spv.func @use_bar() -> (f32) "None" { + %0 = spv._reference_of @bar : f32 + %1 = spv.FAdd %0, %0 : f32 + spv.ReturnValue %1 : f32 + } +} +} + +// ----- + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar spec_id(5) + +// CHECK-NEXT: spv.func @foo(%arg0: f32) +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32) +// CHECK-NEXT: spv._reference_of +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @baz(%arg0: i32) +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32) +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @baz_no_return_different_control +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @baz_no_return_another_control +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @kernel +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @kernel_different_attr +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar spec_id(5) = 1. 
: f32 + + spv.func @foo(%arg0: f32) -> (f32) "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.func @foo_duplicate(%arg0: f32) -> (f32) "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.func @foo_different_body(%arg0: f32) -> (f32) "None" { + %0 = spv._reference_of @bar : f32 + spv.ReturnValue %arg0 : f32 + } + + spv.func @baz(%arg0: i32) -> (i32) "None" { + spv.ReturnValue %arg0 : i32 + } + + spv.func @baz_no_return(%arg0: i32) "None" { + spv.Return + } + + spv.func @baz_no_return_duplicate(%arg0: i32) -> () "None" { + spv.Return + } + + spv.func @baz_no_return_different_control(%arg0: i32) -> () "Inline" { + spv.Return + } + + spv.func @baz_no_return_another_control(%arg0: i32) -> () "Inline|Pure" { + spv.Return + } + + spv.func @kernel( + %arg0: f32, + %arg1: !spv.ptr)>, CrossWorkgroup>) "None" + attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} { + spv.Return + } + + spv.func @kernel_different_attr( + %arg0: f32, + %arg1: !spv.ptr)>, CrossWorkgroup>) "None" + attributes {spv.entry_point_abi = {local_size = dense<[64, 1, 1]> : vector<3xi32>}} { + spv.Return + } +} +}