diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
--- a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -28,7 +28,7 @@
 /// combination process proceeds in 2 phases:
 ///
 /// (1) resolve conflicts between pairs of ops from different modules
-/// (2) deduplicate equivalent ops/sub-ops in the merged module. (TODO)
+/// (2) deduplicate equivalent ops/sub-ops in the merged module.
 ///
 /// For the conflict resolution phase, the following rules are employed to
 /// resolve such conflicts:
@@ -39,13 +39,22 @@
 /// other symbol.
 /// - If none of the 2 conflicting ops are spv.func, then rename either.
 ///
-/// In all cases, the references to the updated symbol are also updated to
-/// reflect the change.
+/// For deduplication, the following 3 cases are taken into consideration:
+///
+/// - If 2 spv.globalVariable's have either the same descriptor set + binding
+/// or the same built_in attribute value, then replace one of them using the
+/// other.
+/// - If 2 spv.specConstant's have the same spec_id attribute value, then
+/// replace one of them using the other.
+/// - TODO: If 2 spv.func's are identical replace one of them using the other.
+///
+/// In all cases, the references to the updated symbol (whether renamed or
+/// deduplicated) are also updated to reflect the change.
 ///
 /// \param modules the list of modules to combine. Input modules are not
 /// modified.
-/// \param combinedMdouleBuilder an OpBuilder to be used for
+/// \param combinedModuleBuilder an OpBuilder to be used for
 /// building up the combined module.
 /// \param symbRenameListener a listener that gets called everytime a symbol in
 /// one of the input modules is renamed.
The arguments
 /// passed to the listener are: the input
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
--- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/SPIRV/ModuleCombiner.h"
 
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/SymbolTable.h"
@@ -59,6 +60,110 @@
   return success();
 }
 
+/// Records `symbolOp` in `deduplicationMap` under `key` unless `key` is
+/// already present. Returns the op previously recorded for `key` (the op that
+/// `symbolOp` should be replaced with), or a null op if `symbolOp` was newly
+/// recorded.
+template <typename KeyTy, typename SymbolOpTy>
+static SymbolOpTy
+emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
+                              DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
+  auto result = deduplicationMap.try_emplace(key, symbolOp);
+  if (result.second)
+    return SymbolOpTy();
+  return result.first->second;
+}
+
+/// Redirects all uses of `duplicateOp`'s symbol inside `combinedModule` to
+/// `replacementOp`'s symbol, then erases `duplicateOp`. Emits an error on
+/// `duplicateOp` and returns failure if not all uses could be updated.
+template <typename SymbolOpTy>
+static LogicalResult replaceSymbolUsesAndErase(SymbolOpTy duplicateOp,
+                                               SymbolOpTy replacementOp,
+                                               spirv::ModuleOp combinedModule) {
+  if (failed(SymbolTable::replaceAllSymbolUses(
+          duplicateOp, replacementOp.getName(), combinedModule)))
+    return duplicateOp.emitError("unable to update all symbol uses for ")
+           << duplicateOp.getName() << " to " << replacementOp.getName();
+
+  duplicateOp.erase();
+  return success();
+}
+
+/// Deduplicates spv.globalVariable ops in `combinedModule` that share either
+/// the same (descriptor set, binding) pair or the same built_in attribute.
+/// The first variable visited for a key is kept; each later one has its uses
+/// redirected to the kept variable and is then erased. Variable types are not
+/// compared.
+static LogicalResult
+deduplicateGlobalVariables(spirv::ModuleOp combinedModule) {
+  DenseMap<std::pair<int64_t, int64_t>, spirv::GlobalVariableOp>
+      descriptorToGlobalVarOpMap;
+  DenseMap<StringRef, spirv::GlobalVariableOp> builtInToGlobalVarOpMap;
+
+  WalkResult result =
+      combinedModule.walk([&](spirv::GlobalVariableOp globalVarOp) {
+        spirv::GlobalVariableOp replacementSymOp;
+
+        IntegerAttr descriptorSet = globalVarOp.getAttrOfType<IntegerAttr>(
+            spirv::SPIRVDialect::getAttributeName(
+                spirv::Decoration::DescriptorSet));
+        IntegerAttr binding = globalVarOp.getAttrOfType<IntegerAttr>(
+            spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding));
+
+        // The descriptor set and binding are separate attributes, so either
+        // may be missing; only a complete pair forms a deduplication key.
+        if (descriptorSet && binding) {
+          replacementSymOp = emplaceOrGetReplacementSymbol(
+              std::make_pair(descriptorSet.getInt(), binding.getInt()),
+              globalVarOp, descriptorToGlobalVarOpMap);
+        } else if (StringAttr builtIn = globalVarOp.getAttrOfType<StringAttr>(
+                       spirv::SPIRVDialect::getAttributeName(
+                           spirv::Decoration::BuiltIn))) {
+          replacementSymOp = emplaceOrGetReplacementSymbol(
+              builtIn.getValue(), globalVarOp, builtInToGlobalVarOpMap);
+        }
+
+        // No earlier variable shares this key (or there is no key): keep it.
+        if (!replacementSymOp)
+          return WalkResult::advance();
+
+        // There is already a global variable with either the same descriptor
+        // set + binding or the same built_in attribute. Deduplicate the current
+        // global variable.
+        return WalkResult(replaceSymbolUsesAndErase(
+            globalVarOp, replacementSymOp, combinedModule));
+      });
+
+  return failure(result.wasInterrupted());
+}
+
+/// Deduplicates spv.specConstant ops in `combinedModule` that share the same
+/// spec_id attribute. The first op visited for a spec_id is kept; each later
+/// one has its uses redirected to the kept op and is then erased. Spec
+/// constants without a spec_id are left untouched.
+static LogicalResult deduplicateSpecConstants(spirv::ModuleOp combinedModule) {
+  DenseMap<int64_t, spirv::SpecConstantOp> specIdToSpecConstOpMap;
+
+  WalkResult result =
+      combinedModule.walk([&](spirv::SpecConstantOp specConstOp) {
+        IntegerAttr specId = specConstOp.getAttrOfType<IntegerAttr>(
+            spirv::SPIRVDialect::getAttributeName(spirv::Decoration::SpecId));
+        if (!specId)
+          return WalkResult::advance();
+
+        spirv::SpecConstantOp replacementSymOp = emplaceOrGetReplacementSymbol(
+            specId.getInt(), specConstOp, specIdToSpecConstOpMap);
+        if (!replacementSymOp)
+          return WalkResult::advance();
+
+        return WalkResult(replaceSymbolUsesAndErase(
+            specConstOp, replacementSymOp, combinedModule));
+      });
+
+  return failure(result.wasInterrupted());
+}
+
 namespace mlir {
 namespace spirv {
 
@@ -174,6 +279,10 @@
       combinedModuleBuilder.insert(op.clone());
     }
 
+  if (failed(deduplicateGlobalVariables(combinedModule)) ||
+      failed(deduplicateSpecConstants(combinedModule)))
+    return nullptr;
+
   return combinedModule;
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
--- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir
@@ -578,9 +578,9 @@
 
 // CHECK: module {
 // CHECK-NEXT: spv.module Logical GLSL450 {
-// CHECK-NEXT: spv.globalVariable @foo_1
+// CHECK-NEXT: spv.globalVariable @foo_1 bind(1, 0)
 
-// CHECK-NEXT: spv.globalVariable @foo
+// CHECK-NEXT: spv.globalVariable @foo bind(2, 0)
 // CHECK-NEXT: }
 
 module {
@@ -589,7 +589,26 @@
 }
 
 spv.module Logical GLSL450 {
-  spv.globalVariable @foo bind(1, 0) : !spv.ptr
+  spv.globalVariable @foo bind(2, 0) : !spv.ptr
+}
+}
+
+// -----
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo_1 built_in("GlobalInvocationId")
+
+// CHECK-NEXT: spv.globalVariable @foo built_in("LocalInvocationId")
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("LocalInvocationId") : !spv.ptr, Input>
 }
 }
 
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication_basic.mlir
@@ -0,0 +1,133 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// Deduplicate 2 global variables with the same descriptor set and binding.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo
+
+// CHECK-NEXT: spv.func @use_foo
+// CHECK-NEXT: spv._address_of @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @use_bar
+// CHECK-NEXT: spv._address_of @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr
+
+  spv.func @use_foo() -> f32 "None" {
+    %0 = spv._address_of @foo : !spv.ptr
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @bar bind(1, 0) : !spv.ptr
+
+  spv.func @use_bar() -> f32 "None" {
+    %0 = spv._address_of @bar : !spv.ptr
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// Deduplicate 2 global variables with the same descriptor set and binding but different types.
+// Types are not compared, so the mismatch is reported on the rewritten spv._address_of below.
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo bind(1, 0) : !spv.ptr
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @bar bind(1, 0) : !spv.ptr
+
+  spv.func @use_bar() -> f32 "None" {
+    // expected-error @+1 {{result type mismatch with the referenced global variable's type}}
+    %0 = spv._address_of @bar : !spv.ptr
+    %1 = spv.Load "Input" %0 : f32
+    spv.ReturnValue %1 : f32
+  }
+}
+}
+
+// -----
+
+// Deduplicate 2 global variables with the same built-in attribute.
+
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.globalVariable @foo built_in("GlobalInvocationId")
+// CHECK-NEXT: spv.func @use_bar
+// CHECK-NEXT: spv._address_of @foo
+// CHECK-NEXT: spv.Load
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.globalVariable @foo built_in("GlobalInvocationId") : !spv.ptr, Input>
+}
+
+spv.module Logical GLSL450 {
+  spv.globalVariable @bar built_in("GlobalInvocationId") : !spv.ptr, Input>
+
+  spv.func @use_bar() -> vector<3xi32> "None" {
+    %0 = spv._address_of @bar : !spv.ptr, Input>
+    %1 = spv.Load "Input" %0 : vector<3xi32>
+    spv.ReturnValue %1 : vector<3xi32>
+  }
+}
+}
+
+// -----
+// Deduplicate 2 spec constants with the same spec_id.
+// CHECK: module {
+// CHECK-NEXT: spv.module Logical GLSL450 {
+// CHECK-NEXT: spv.specConstant @foo spec_id(5)
+
+// CHECK-NEXT: spv.func @use_foo()
+// CHECK-NEXT: %0 = spv._reference_of @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+
+// CHECK-NEXT: spv.func @use_bar()
+// CHECK-NEXT: %0 = spv._reference_of @foo
+// CHECK-NEXT: spv.ReturnValue
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @foo spec_id(5) = 1. : f32
+
+  spv.func @use_foo() -> (f32) "None" {
+    %0 = spv._reference_of @foo : f32
+    spv.ReturnValue %0 : f32
+  }
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @bar spec_id(5) = 1. : f32
+
+  spv.func @use_bar() -> (f32) "None" {
+    %0 = spv._reference_of @bar : f32
+    spv.ReturnValue %0 : f32
+  }
+}
+}