diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ #define MLIR_DIALECT_SPIRV_IR_SPIRVOPS_H_ +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td @@ -467,8 +467,9 @@ let builders = [ OpBuilder<(ins CArg<"Optional", "llvm::None">:$name)>, OpBuilder<(ins "spirv::AddressingModel":$addressing_model, - "spirv::MemoryModel":$memory_model, - CArg<"Optional", "llvm::None">:$name)> + "spirv::MemoryModel":$memory_model, + CArg<"Optional", "llvm::None">:$vce_triple, + CArg<"Optional", "llvm::None">:$name)> ]; // We need to ensure the block inside the region is properly terminated; diff --git a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h --- a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h +++ b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h @@ -22,53 +22,54 @@ namespace spirv { class ModuleOp; -/// To combine a number of MLIR SPIR-V modules, we move all the module-level ops +/// The listener function to receive symbol renaming events. +/// +/// `originalModule` is the input spirv::ModuleOp that contains the renamed +/// symbol. `oldSymbol` and `newSymbol` are the original and renamed symbol. +/// Note that it's the responsibility of the caller to properly retain the +/// storage underlying the passed StringRefs if the listener callback outlives +/// this function call. 
+using SymbolRenameListener = function_ref; + +/// Combines a list of SPIR-V `inputModules` into one. Returns the combined +/// module on success; returns a null module otherwise. +/// +/// \param inputModules the list of modules to combine. They won't be modified. +/// \param combinedModuleBuilder an OpBuilder for building the combined module. +/// \param symRenameListener a listener that gets called every time a symbol in +/// one of the input modules is renamed. +/// +/// To combine multiple SPIR-V modules, we move all the module-level ops /// from all the input modules into one big combined module. To that end, the /// combination process proceeds in 2 phases: /// -/// (1) resolve conflicts between pairs of ops from different modules -/// (2) deduplicate equivalent ops/sub-ops in the merged module. +/// 1. resolve conflicts between pairs of ops from different modules, +/// 2. deduplicate equivalent ops/sub-ops in the merged module. /// /// For the conflict resolution phase, the following rules are employed to /// resolve such conflicts: /// -/// - If 2 spv.func's have the same symbol name, then rename one of the +/// - If 2 spv.func's have the same symbol name, then rename one of the /// functions. -/// - If an spv.func and another op have the same symbol name, then rename the +/// - If an spv.func and another op have the same symbol name, then rename the /// other symbol. -/// - If none of the 2 conflicting ops are spv.func, then rename either. +/// - If none of the 2 conflicting ops are spv.func, then rename either. /// /// For deduplication, the following 3 cases are taken into consideration: /// -/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding +/// - If 2 spv.GlobalVariable's have either the same descriptor set + binding /// or the same build_in attribute value, then replace one of them using the /// other.
-/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then +/// - If 2 spv.SpecConstant's have the same spec_id attribute value, then /// replace one of them using the other. -/// - If 2 spv.func's are identical replace one of them using the other. +/// - Deduplicating functions are not supported right now. /// /// In all cases, the references to the updated symbol (whether renamed or /// deduplicated) are also updated to reflect the change. -/// -/// \param modules the list of modules to combine. Input modules are not -/// modified. -/// \param combinedMdouleBuilder an OpBuilder to be used for -// building up the combined module. -/// \param symbRenameListener a listener that gets called everytime a symbol in -/// one of the input modules is renamed. The arguments -/// passed to the listener are: the input -/// spirv::ModuleOp that contains the renamed symbol, -/// a StringRef to the old symbol name, and a -/// StringRef to the new symbol name. Note that it is -/// the responsibility of the caller to properly -/// retain the storage underlying the passed -/// StringRefs if the listener callback outlives this -/// function call. -/// -/// \return the combined module. -OwningOpRef -combine(MutableArrayRef modules, OpBuilder &combinedModuleBuilder, - function_ref symbRenameListener); +OwningOpRef combine(ArrayRef inputModules, + OpBuilder &combinedModuleBuilder, + SymbolRenameListener symRenameListener); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -310,7 +310,7 @@ // Add a keyword to the module name to avoid symbolic conflict. 
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str(); auto spvModule = rewriter.create( - moduleOp.getLoc(), addressingModel, memoryModel.getValue(), + moduleOp.getLoc(), addressingModel, memoryModel.getValue(), llvm::None, StringRef(spvModuleName)); // Move the region from the module op into the SPIR-V module. diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -2540,6 +2540,7 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, spirv::AddressingModel addressingModel, spirv::MemoryModel memoryModel, + Optional vceTriple, Optional name) { state.addAttribute( "addressing_model", @@ -2548,10 +2549,11 @@ static_cast(memoryModel))); OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); - if (name) { - state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), - builder.getStringAttr(*name)); - } + if (vceTriple) + state.addAttribute(getVCETripleAttrName(), *vceTriple); + if (name) + state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), + builder.getStringAttr(*name)); } static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) { diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -12,27 +12,33 @@ #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/SymbolTable.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include 
"llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" using namespace mlir; static constexpr unsigned maxFreeID = 1 << 20; +/// Returns an unused symbol name in `module` for `oldSymName` by appending an +/// increasing numeric suffix, starting after (and updating) `lastUsedID`. static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, - spirv::ModuleOp combinedModule) { + spirv::ModuleOp module) { SmallString<64> newSymName(oldSymName); newSymName.push_back('_'); while (lastUsedID < maxFreeID) { std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); - if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { + if (!SymbolTable::lookupSymbolIn(module, possible)) { newSymName += llvm::utostr(lastUsedID); break; } @@ -41,8 +47,8 @@ return newSymName; } -/// Check if a symbol with the same name as op already exists in source. If so, -/// rename op and update all its references in target. +/// Checks if a symbol with the same name as `op` already exists in `source`. +/// If so, renames `op` and updates all its references in `target`. static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, spirv::ModuleOp target, spirv::ModuleOp source, @@ -61,99 +67,67 @@ return success(); } -template -static SymbolOpTy -emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp, - DenseMap &deduplicationMap) { - auto result = deduplicationMap.try_emplace(key, symbolOp); - - if (result.second) - return SymbolOpTy(); - - return result.first->second; -} - -/// Computes a hash code to represent the argument SymbolOpInterface based on -/// all the Op's attributes except for the symbol name. -/// -/// \return the hash code computed from the Op's attributes as described above. +/// Computes a hash code to represent `symbolOp` based on all its attributes +/// except for the symbol name. /// /// Note: We use the operation's name (not the symbol name) as part of the hash /// computation.
This prevents, for example, mistakenly considering a global /// variable and a spec constant as duplicates because their descriptor set + /// binding and spec_id, respectively, happen to hash to the same value. static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { - llvm::hash_code hashCode(0); - hashCode = llvm::hash_combine(symbolOp->getName()); - - for (auto attr : symbolOp->getAttrs()) { - if (attr.first == SymbolTable::getSymbolAttrName()) - continue; - hashCode = llvm::hash_combine(hashCode, attr); - } - - return hashCode; -} - -/// Computes a hash code from the argument Block. -llvm::hash_code computeHash(Block *block) { - // TODO: Consider extracting BlockEquivalenceData into a common header and - // re-using it here. - llvm::hash_code hash(0); - - for (Operation &op : *block) { - // TODO: Properly handle operations with regions. - if (op.getNumRegions() > 0) - return 0; - - hash = llvm::hash_combine( - hash, OperationEquivalence::computeHash( - &op, OperationEquivalence::Flags::IgnoreOperands)); - } - - return hash; + auto range = + llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { + return attr.first != SymbolTable::getSymbolAttrName(); + }); + + return llvm::hash_combine( + symbolOp->getName(), + llvm::hash_combine_range(range.begin(), range.end())); } namespace mlir { namespace spirv { -// TODO Properly test symbol rename listener mechanism. 
- -OwningOpRef -combine(llvm::MutableArrayRef modules, - OpBuilder &combinedModuleBuilder, - llvm::function_ref - symRenameListener) { - unsigned lastUsedID = 0; - - if (modules.empty()) +OwningOpRef combine(ArrayRef inputModules, + OpBuilder &combinedModuleBuilder, + SymbolRenameListener symRenameListener) { + if (inputModules.empty()) return nullptr; - auto addressingModel = modules[0].addressing_model(); - auto memoryModel = modules[0].memory_model(); + spirv::ModuleOp firstModule = inputModules.front(); + auto addressingModel = firstModule.addressing_model(); + auto memoryModel = firstModule.memory_model(); + auto vceTriple = firstModule.vce_triple(); + + // First check whether there are conflicts between addressing/memory model. + // Return early if so. + for (auto module : inputModules) { + if (module.addressing_model() != addressingModel || + module.memory_model() != memoryModel || + module.vce_triple() != vceTriple) { + module.emitError("input modules differ in addressing model, memory " + "model, and/or VCE triple"); + return nullptr; + } + } auto combinedModule = combinedModuleBuilder.create( - modules[0].getLoc(), addressingModel, memoryModel); + firstModule.getLoc(), addressingModel, memoryModel, vceTriple); combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); // In some cases, a symbol in the (current state of the) combined module is - // renamed in order to maintain the conflicting symbol in the input module + // renamed in order to enable the conflicting symbol in the input module // being merged. For example, if the conflict is between a global variable in // the current combined module and a function in the input module, the global // variable is renamed. In order to notify listeners of the symbol updates in // such cases, we need to keep track of the module from which the renamed // symbol in the combined module originated. This map keeps such information. 
- DenseMap symNameToModuleMap; + llvm::StringMap symNameToModuleMap; - for (auto module : modules) { - if (module.addressing_model() != addressingModel || - module.memory_model() != memoryModel) { - module.emitError( - "input modules differ in addressing model and/or memory model"); - return nullptr; - } + unsigned lastUsedID = 0; - spirv::ModuleOp moduleClone = module.clone(); + for (auto inputModule : inputModules) { + spirv::ModuleOp moduleClone = inputModule.clone(); // In the combined module, rename all symbols that conflict with symbols // from the current input module. This renaming applies to all ops except @@ -161,65 +135,70 @@ // non-spv.func, we rename that symbol instead and maintain the spv.func in // the combined module name as it is. for (auto &op : *combinedModule.getBody()) { - if (auto symbolOp = dyn_cast(op)) { - StringRef oldSymName = symbolOp.getName(); + auto symbolOp = dyn_cast(op); + if (!symbolOp) + continue; - if (!isa(op) && - failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, - lastUsedID))) - return nullptr; + StringRef oldSymName = symbolOp.getName(); - StringRef newSymName = symbolOp.getName(); + if (!isa(op) && + failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, + lastUsedID))) + return nullptr; - if (symRenameListener && oldSymName != newSymName) { - spirv::ModuleOp originalModule = - symNameToModuleMap.lookup(oldSymName); + StringRef newSymName = symbolOp.getName(); - if (!originalModule) { - module.emitError("unable to find original ModuleOp for symbol ") - << oldSymName; - return nullptr; - } + if (symRenameListener && oldSymName != newSymName) { + spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName); - symRenameListener(originalModule, oldSymName, newSymName); - - // Since the symbol name is updated, there is no need to maintain the - // entry that associates the old symbol name with the original module. 
- symNameToModuleMap.erase(oldSymName); - // Instead, add a new entry to map the new symbol name to the original - // module in case it gets renamed again later. - symNameToModuleMap[newSymName] = originalModule; + if (!originalModule) { + inputModule.emitError( + "unable to find original spirv::ModuleOp for symbol ") + << oldSymName; + return nullptr; } + + symRenameListener(originalModule, oldSymName, newSymName); + + // Since the symbol name is updated, there is no need to maintain the + // entry that associates the old symbol name with the original module. + symNameToModuleMap.erase(oldSymName); + // Instead, add a new entry to map the new symbol name to the original + // module in case it gets renamed again later. + symNameToModuleMap[newSymName] = originalModule; } } // In the current input module, rename all symbols that conflict with // symbols from the combined module. This includes renaming spv.funcs. for (auto &op : *moduleClone.getBody()) { - if (auto symbolOp = dyn_cast(op)) { - StringRef oldSymName = symbolOp.getName(); + auto symbolOp = dyn_cast(op); + if (!symbolOp) + continue; - if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, - lastUsedID))) - return nullptr; + StringRef oldSymName = symbolOp.getName(); - StringRef newSymName = symbolOp.getName(); + if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, + lastUsedID))) + return nullptr; - if (symRenameListener && oldSymName != newSymName) { - symRenameListener(module, oldSymName, newSymName); + StringRef newSymName = symbolOp.getName(); - // Insert the module associated with the symbol name. 
- auto emplaceResult = - symNameToModuleMap.try_emplace(symbolOp.getName(), module); + if (symRenameListener) { + if (oldSymName != newSymName) + symRenameListener(inputModule, oldSymName, newSymName); - // If an entry with the same symbol name is already present, this must - // be a problem with the implementation, specially clean-up of the map - // while iterating over the combined module above. - if (!emplaceResult.second) { - module.emitError("did not expect to find an entry for symbol ") - << symbolOp.getName(); - return nullptr; - } + // Insert the module associated with the symbol name. + auto emplaceResult = + symNameToModuleMap.try_emplace(newSymName, inputModule); + + // If an entry with the same symbol name is already present, this must + // be a problem with the implementation, specially clean-up of the map + // while iterating over the combined module above. + if (!emplaceResult.second) { + inputModule.emitError("did not expect to find an entry for symbol ") + << symbolOp.getName(); + return nullptr; } } } @@ -234,30 +213,26 @@ SmallVector eraseList; for (auto &op : *combinedModule.getBody()) { - llvm::hash_code hashCode(0); SymbolOpInterface symbolOp = dyn_cast(op); - if (!symbolOp) continue; - hashCode = computeHash(symbolOp); - - // A 0 hash code means the op is not suitable for deduplication and should - // be skipped. An example of this is when a function has ops with regions - // which are not properly supported yet. - if (!hashCode) + // Do not support ops with operands or results. + // Global variables, spec constants, and functions won't have + // operands/results, but just for safety here. + if (op.getNumOperands() != 0 || op.getNumResults() != 0) continue; - if (auto funcOp = dyn_cast(op)) - for (auto &blk : funcOp) - hashCode = llvm::hash_combine(hashCode, computeHash(&blk)); - - SymbolOpInterface replacementSymOp = - emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp); + // Deduplicating functions are not supported yet. 
+ if (isa(op)) + continue; - if (!replacementSymOp) + auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp); + if (result.second) continue; + SymbolOpInterface replacementSymOp = result.first->second; + if (failed(SymbolTable::replaceAllSymbolUses( symbolOp, replacementSymOp.getName(), combinedModule))) { symbolOp.emitError("unable to update all symbol uses for ") diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir --- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir @@ -1,9 +1,19 @@ // RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s +// Combine modules without the same symbols + // CHECK: module { // CHECK-NEXT: spv.module Logical GLSL450 { // CHECK-NEXT: spv.SpecConstant @m1_sc +// CHECK-NEXT: spv.GlobalVariable @m1_gv bind(1, 0) +// CHECK-NEXT: spv.func @no_op +// CHECK-NEXT: spv.Return +// CHECK-NEXT: } +// CHECK-NEXT: spv.EntryPoint "GLCompute" @no_op +// CHECK-NEXT: spv.ExecutionMode @no_op "LocalSize", 32, 1, 1 + // CHECK-NEXT: spv.SpecConstant @m2_sc +// CHECK-NEXT: spv.GlobalVariable @m2_gv bind(0, 1) // CHECK-NEXT: spv.func @variable_init_spec_constant // CHECK-NEXT: spv.mlir.referenceof @m2_sc // CHECK-NEXT: spv.Variable init @@ -15,10 +25,17 @@ module { spv.module Logical GLSL450 { spv.SpecConstant @m1_sc = 42.42 : f32 + spv.GlobalVariable @m1_gv bind(1, 0): !spv.ptr + spv.func @no_op() -> () "None" { + spv.Return + } + spv.EntryPoint "GLCompute" @no_op + spv.ExecutionMode @no_op "LocalSize", 32, 1, 1 } spv.module Logical GLSL450 { spv.SpecConstant @m2_sc = 42 : i32 + spv.GlobalVariable @m2_gv bind(0, 1): !spv.ptr spv.func @variable_init_spec_constant() -> () "None" { %0 = spv.mlir.referenceof @m2_sc : i32 %1 = spv.Variable init(%0) : !spv.ptr @@ -33,7 +50,7 @@ spv.module Physical64 GLSL450 { } -// expected-error @+1 {{input modules 
differ in addressing model and/or memory model}} +// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}} spv.module Logical GLSL450 { } } @@ -44,7 +61,19 @@ spv.module Logical Simple { } -// expected-error @+1 {{input modules differ in addressing model and/or memory model}} +// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}} +spv.module Logical GLSL450 { +} +} + +// ----- + +module { spv.module Logical GLSL450 { } + +// expected-error @+1 {{input modules differ in addressing model, memory model, and/or VCE triple}} +spv.module Logical GLSL450 requires #spv.vce { +} } + diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir --- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict-resolution.mlir @@ -215,7 +215,7 @@ spv.func @foo(%arg0 : i32) -> i32 "None" { spv.ReturnValue %arg0 : i32 } - + spv.EntryPoint "GLCompute" @foo spv.ExecutionMode @foo "ContractionOff" } @@ -383,7 +383,7 @@ spv.SpecConstant @foo = -5 : i32 spv.func @bar() -> i32 "None" { - %0 = spv.mlir.referenceof @foo : i32 + %0 = spv.mlir.referenceof @foo : i32 spv.ReturnValue %0 : i32 } } diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir --- a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/deduplication.mlir @@ -21,7 +21,6 @@ // CHECK-NEXT: } // CHECK-NEXT: } -module { spv.module Logical GLSL450 { spv.GlobalVariable @foo bind(1, 0) : !spv.ptr @@ -42,7 +41,6 @@ spv.ReturnValue %2 : f32 } } -} // ----- @@ -62,7 +60,6 @@ // CHECK-NEXT: } // CHECK-NEXT: } -module { spv.module Logical GLSL450 { spv.GlobalVariable @foo bind(1, 0) : !spv.ptr } @@ -76,7 +73,6 @@ 
spv.ReturnValue %1 : f32 } } -} // ----- @@ -93,7 +89,6 @@ // CHECK-NEXT: } // CHECK-NEXT: } -module { spv.module Logical GLSL450 { spv.GlobalVariable @foo built_in("GlobalInvocationId") : !spv.ptr, Input> } @@ -107,10 +102,11 @@ spv.ReturnValue %1 : vector<3xi32> } } -} // ----- +// Deduplicate 2 spec constants with the same spec ID. + // CHECK: module { // CHECK-NEXT: spv.module Logical GLSL450 { // CHECK-NEXT: spv.SpecConstant @foo spec_id(5) @@ -128,7 +124,6 @@ // CHECK-NEXT: } // CHECK-NEXT: } -module { spv.module Logical GLSL450 { spv.SpecConstant @foo spec_id(5) = 1. : f32 @@ -147,48 +142,82 @@ spv.ReturnValue %1 : f32 } } + +// ----- + +// Don't deduplicate functions with similar ops but different operands. + +// CHECK: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) +// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG1]] : f32 +// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG2]] : f32 +// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: spv.func @foo_1(%[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32) +// CHECK-NEXT: %[[ADD:.+]] = spv.FAdd %[[ARG0]], %[[ARG2]] : f32 +// CHECK-NEXT: %[[MUL:.+]] = spv.FMul %[[ADD]], %[[ARG1]] : f32 +// CHECK-NEXT: spv.ReturnValue %[[MUL]] : f32 +// CHECK-NEXT: } +// CHECK-NEXT: } + +spv.module Logical GLSL450 { + spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" { + %add = spv.FAdd %a, %b: f32 + %mul = spv.FMul %add, %c: f32 + spv.ReturnValue %mul: f32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%a: f32, %b: f32, %c: f32) -> f32 "None" { + %add = spv.FAdd %a, %c: f32 + %mul = spv.FMul %add, %b: f32 + spv.ReturnValue %mul: f32 + } } // ----- -// CHECK: module { -// CHECK-NEXT: spv.module Logical GLSL450 { -// CHECK-NEXT: spv.SpecConstant @bar spec_id(5) +// TODO: re-enable this test once we have better function deduplication. 
-// CHECK-NEXT: spv.func @foo(%arg0: f32) -// CHECK-NEXT: spv.ReturnValue -// CHECK-NEXT: } +// XXXXX: module { +// XXXXX-NEXT: spv.module Logical GLSL450 { +// XXXXX-NEXT: spv.SpecConstant @bar spec_id(5) -// CHECK-NEXT: spv.func @foo_different_body(%arg0: f32) -// CHECK-NEXT: spv.mlir.referenceof -// CHECK-NEXT: spv.ReturnValue -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @foo(%arg0: f32) +// XXXXX-NEXT: spv.ReturnValue +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @baz(%arg0: i32) -// CHECK-NEXT: spv.ReturnValue -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @foo_different_body(%arg0: f32) +// XXXXX-NEXT: spv.mlir.referenceof +// XXXXX-NEXT: spv.ReturnValue +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @baz_no_return(%arg0: i32) -// CHECK-NEXT: spv.Return -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @baz(%arg0: i32) +// XXXXX-NEXT: spv.ReturnValue +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @baz_no_return_different_control -// CHECK-NEXT: spv.Return -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @baz_no_return(%arg0: i32) +// XXXXX-NEXT: spv.Return +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @baz_no_return_another_control -// CHECK-NEXT: spv.Return -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @baz_no_return_different_control +// XXXXX-NEXT: spv.Return +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @kernel -// CHECK-NEXT: spv.Return -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @baz_no_return_another_control +// XXXXX-NEXT: spv.Return +// XXXXX-NEXT: } -// CHECK-NEXT: spv.func @kernel_different_attr -// CHECK-NEXT: spv.Return -// CHECK-NEXT: } -// CHECK-NEXT: } -// CHECK-NEXT: } +// XXXXX-NEXT: spv.func @kernel +// XXXXX-NEXT: spv.Return +// XXXXX-NEXT: } + +// XXXXX-NEXT: spv.func @kernel_different_attr +// XXXXX-NEXT: spv.Return +// XXXXX-NEXT: } +// XXXXX-NEXT: } +// XXXXX-NEXT: } module { spv.module Logical GLSL450 { diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir new file 
mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/symbol-rename-listener.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s + +module { +spv.module @Module1 Logical GLSL450 { + spv.GlobalVariable @foo bind(1, 0) : !spv.ptr + spv.func @bar() -> () "None" { + spv.Return + } + spv.func @baz() -> () "None" { + spv.Return + } + + spv.SpecConstant @sc = -5 : i32 +} + +spv.module @Module2 Logical GLSL450 { + spv.func @foo() -> () "None" { + spv.Return + } + + spv.GlobalVariable @bar bind(1, 0) : !spv.ptr + + spv.func @baz() -> () "None" { + spv.Return + } + + spv.SpecConstant @sc = -5 : i32 +} + +spv.module @Module3 Logical GLSL450 { + spv.func @foo() -> () "None" { + spv.Return + } + + spv.GlobalVariable @bar bind(1, 0) : !spv.ptr + + spv.func @baz() -> () "None" { + spv.Return + } + + spv.SpecConstant @sc = -5 : i32 +} +} + +// CHECK: [Module1] foo -> foo_1 +// CHECK: [Module1] sc -> sc_2 + +// CHECK: [Module2] bar -> bar_3 +// CHECK: [Module2] baz -> baz_4 +// CHECK: [Module2] sc -> sc_5 + +// CHECK: [Module3] foo -> foo_6 +// CHECK: [Module3] bar -> bar_7 +// CHECK: [Module3] baz -> baz_8 diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp @@ -37,7 +37,14 @@ auto modules = llvm::to_vector<4>(getOperation().getOps()); OpBuilder combinedModuleBuilder(modules[0]); - combinedModule = spirv::combine(modules, combinedModuleBuilder, nullptr); + + auto listener = [](spirv::ModuleOp originalModule, StringRef oldSymbol, + StringRef newSymbol) { + llvm::outs() << "[" << originalModule.getName() << "] " << oldSymbol + << " -> " << newSymbol << "\n"; + }; + + combinedModule = spirv::combine(modules, combinedModuleBuilder, listener); for (spirv::ModuleOp module : modules) module.erase();