diff --git a/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/ModuleCombiner.h
@@ -0,0 +1,100 @@
+//===- ModuleCombiner.h - MLIR SPIR-V Module Combiner -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the entry point to the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+#define MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
+
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+class OpBuilder;
+
+namespace spirv {
+class ModuleOp;
+
+/// To combine a number of MLIR spv modules, we copy all the module-level ops
+/// from all the input modules into one big combined module. To that end, the
+/// combination process can proceed in 2 phases:
+///
+/// (1) resolving conflicts between pairs of ops from different modules
+/// (2) deduplicating equivalent ops/sub-ops in the merged module. (TODO)
+///
+/// For the conflict resolution phase, the following rules are employed to
+/// resolve such conflicts:
+///
+/// =========================================================================
+/// FuncOp vs. FuncOp
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename one of the functions and update its references.
+/// =========================================================================
+///
+/// =========================================================================
+/// FuncOp vs. GlobalVariableOp
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename the global variable and update references to the renamed symbol.
+/// =========================================================================
+///
+/// =========================================================================
+/// FuncOp vs. SpecConstantOp
+/// FuncOp vs. SpecConstantCompositeOp
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename the spec constant and update references to the renamed symbol.
+/// =========================================================================
+///
+/// =========================================================================
+/// GlobalVariableOp vs. GlobalVariableOp
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename either of the global variables and update references to it.
+/// =========================================================================
+///
+/// =========================================================================
+/// GlobalVariableOp vs. SpecConstantOp
+/// GlobalVariableOp vs. SpecConstantCompositeOp
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename the global variable and update its references.
+/// =========================================================================
+///
+/// =========================================================================
+/// 2 spec constants (scalar or composite)
+/// Conflict: Same symbol name
+/// -------------------------------------------------------------------------
+/// Rename either of the constants and update its references.
+/// =========================================================================
+///
+/// =========================================================================
+/// EntryPointOp vs. EntryPointOp
+/// Conflict: Same symbol name and execution model
+/// -------------------------------------------------------------------------
+/// No need to resolve this explicitly as it will be resolved as part of
+/// resolving the conflict between the 2 associated functions.
+/// =========================================================================
+///
+/// \param modules The spv.module ops to combine. Each one is cloned before
+///        its symbols are renamed, so the input modules are left unmodified.
+///        All of them must use the same addressing and memory model;
+///        otherwise an error is emitted on the combined module and combining
+///        stops early. Nothing is created if `modules` is empty.
+/// \param combinedModuleBuilder Builder used to create the combined
+///        spv.module at its current insertion point. When at least one module
+///        is given, the builder's insertion point is left inside the combined
+///        module's body on return.
+void combine(llvm::SmallVector<ModuleOp, 4> modules,
+             OpBuilder &combinedModuleBuilder);
+} // namespace spirv
+} // namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_MODULECOMBINER_H_
diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
--- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt
+++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt
@@ -34,5 +34,6 @@
   MLIRTransforms
   )
 
+add_subdirectory(Linking)
 add_subdirectory(Serialization)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(ModuleCombiner)
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_dialect_library(MLIRSPIRVModuleCombiner
+  ModuleCombiner.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SPIRV
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSPIRV
+  )
diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp
@@ -0,0 +1,130 @@
+//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the SPIR-V module combiner library.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/SymbolTable.h"
+#include "llvm/ADT/StringExtras.h"
+
+using namespace mlir;
+using namespace spirv;
+
+/// Returns a new symbol name of the form `oldSymName` + "_" + N that does not
+/// resolve to any symbol in `module`. N is taken from `nextConflictID`, which
+/// is pre-incremented for every candidate tried, so a suffix is never handed
+/// out twice across calls sharing the same counter.
+static SmallString<64> renameSymbol(StringRef oldSymName,
+                                    unsigned &nextConflictID,
+                                    const spirv::ModuleOp module) {
+  SmallString<64> newSymName;
+  do {
+    // Rebuild the candidate from the original name on every attempt.
+    // Appending to the previous candidate would yield `foo_12`, `foo_123`, ...
+    // instead of `foo_2`, `foo_3`, ...
+    newSymName = oldSymName;
+    newSymName.push_back('_');
+    newSymName += llvm::utostr(++nextConflictID);
+  } while (SymbolTable::lookupSymbolIn(module, newSymName));
+
+  return newSymName;
+}
+
+/// Walks `target` for operations of type OpTy. For each such operation whose
+/// symbol name also names a symbol in `source`, renames the walked operation
+/// and updates all of its uses within `target`.
+template <typename OpTy>
+static void updateSymbolAndAllUses(spirv::ModuleOp target,
+                                   spirv::ModuleOp source,
+                                   unsigned &nextConflictID) {
+  target.walk([&](OpTy symbolOp) {
+    StringRef oldSymName = symbolOp.sym_name();
+    if (!SymbolTable::lookupSymbolIn(source, oldSymName))
+      return;
+
+    // The replacement name must be free in BOTH modules: `renameSymbol` only
+    // checks `target`, so keep drawing fresh candidates until `source` does
+    // not define it either. Otherwise the two modules could end up sharing
+    // the renamed symbol once their ops are merged.
+    SmallString<64> newSymName;
+    do {
+      newSymName = renameSymbol(oldSymName, nextConflictID, target);
+    } while (SymbolTable::lookupSymbolIn(source, newSymName));
+
+    if (failed(SymbolTable::replaceAllSymbolUses(symbolOp, newSymName,
+                                                 target)))
+      symbolOp.emitError("unable to update all symbol uses for ")
+          << oldSymName << " to " << newSymName;
+
+    SymbolTable::setSymbolName(symbolOp, newSymName);
+  });
+}
+
+namespace mlir {
+namespace spirv {
+
+/// Implements spirv::combine; the conflict-resolution rules are listed in
+/// ModuleCombiner.h.
+void combine(llvm::SmallVector<ModuleOp, 4> modules,
+             OpBuilder &combinedModuleBuilder) {
+  if (modules.empty())
+    return;
+
+  // Shared across all input modules so generated suffixes are never reused.
+  unsigned nextConflictID = 0;
+
+  auto addressingModel = modules[0].addressing_model();
+  auto memoryModel = modules[0].memory_model();
+
+  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
+      modules[0].getLoc(), addressingModel, memoryModel);
+  combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
+
+  for (auto module : modules) {
+    if (module.addressing_model() != addressingModel ||
+        module.memory_model() != memoryModel) {
+      combinedModule.emitError(
+          "input modules differ in addressing model and/or memory model");
+      return;
+    }
+
+    // Work on a detached clone so the input module is never modified.
+    spirv::ModuleOp moduleClone = module.clone();
+
+    // A global variable is renamed if it conflicts with a function or a spec
+    // constant:
+    // (1) Rename global variables from the current input module that are
+    // conflicting with any other module-level op currently in the combined
+    // module.
+    updateSymbolAndAllUses<spirv::GlobalVariableOp>(moduleClone, combinedModule,
+                                                    nextConflictID);
+    // (2) Rename global variables currently in the combined module that are
+    // conflicting with any other module-level op in the current input module.
+    updateSymbolAndAllUses<spirv::GlobalVariableOp>(combinedModule, moduleClone,
+                                                    nextConflictID);
+
+    // A spec constant is renamed if it conflicts with a function or another
+    // spec constant:
+    // (1) Rename spec constants from the current input module that are
+    // conflicting with any symbol currently in the combined module.
+    updateSymbolAndAllUses<spirv::SpecConstantOp>(moduleClone, combinedModule,
+                                                  nextConflictID);
+    updateSymbolAndAllUses<spirv::SpecConstantCompositeOp>(
+        moduleClone, combinedModule, nextConflictID);
+    // (2) Rename spec constants currently in the combined module that are
+    // conflicting with any symbol in the current input module.
+    updateSymbolAndAllUses<spirv::SpecConstantOp>(combinedModule, moduleClone,
+                                                  nextConflictID);
+    updateSymbolAndAllUses<spirv::SpecConstantCompositeOp>(
+        combinedModule, moduleClone, nextConflictID);
+
+    // Rename functions in the current input module that are conflicting with
+    // functions already in the combined module.
+    updateSymbolAndAllUses<spirv::FuncOp>(moduleClone, combinedModule,
+                                          nextConflictID);
+
+    // Clone all the module's ops (except its terminator) to the combined
+    // module, preserving their order.
+    for (Operation &op : *moduleClone.getBody()) {
+      if (isa<spirv::ModuleEndOp>(op))
+        continue;
+
+      combinedModuleBuilder.insert(op.clone());
+    }
+
+    // The clone is detached scratch space; destroy it now that its ops have
+    // been copied, otherwise it leaks on every iteration.
+    moduleClone.erase();
+  }
+}
+
+} // namespace spirv
+} // namespace mlir
diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/basic.mlir
@@ -0,0 +1,50 @@
+// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK:      module {
+// CHECK-NEXT:   spv.module Logical GLSL450 {
+// CHECK-NEXT:     spv.specConstant @m1_sc
+// CHECK-NEXT:     spv.specConstant @m2_sc
+// CHECK-NEXT:     spv.func @variable_init_spec_constant
+// CHECK-NEXT:       spv._reference_of @m2_sc
+// CHECK-NEXT:       spv.Variable init
+// CHECK-NEXT:       spv.Return
+// CHECK-NEXT:     }
+// CHECK-NEXT:   }
+// CHECK-NEXT: }
+
+module {
+spv.module Logical GLSL450 {
+  spv.specConstant @m1_sc = 42.42 : f32
+}
+
+spv.module Logical GLSL450 {
+  spv.specConstant @m2_sc = 42 : i32
+
spv.func @variable_init_spec_constant() -> () "None" { + %0 = spv._reference_of @m2_sc : i32 + %1 = spv.Variable init(%0) : !spv.ptr + spv.Return + } +} +} + +// ----- + +module { +// expected-error @+1 {{input modules differ in addressing model and/or memory model}} +spv.module Physical64 GLSL450 { +} + +spv.module Logical GLSL450 { +} +} + +// ----- + +module { +// expected-error @+1 {{input modules differ in addressing model and/or memory model}} +spv.module Logical Simple { +} + +spv.module Logical GLSL450 { +} +} diff --git a/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Linking/ModuleCombiner/conflict_resolution.mlir @@ -0,0 +1,682 @@ +// RUN: mlir-opt -test-spirv-module-combiner -split-input-file -verify-diagnostics %s | FileCheck %s + +// Test basic renaming of conflicting funcOps. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Test basic renaming of conflicting funcOps across 3 modules. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_2 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Test properly updating references to a renamed funcOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.FunctionCall @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.func @bar(%arg0 : f32) -> f32 "None" { + %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32) + spv.ReturnValue %0 : f32 + } +} +} + +// ----- + +// Test properly updating references to a renamed funcOp if the functionCallOp +// preceeds the callee funcOp definition. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.FunctionCall @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @bar(%arg0 : f32) -> f32 "None" { + %0 = spv.FunctionCall @foo(%arg0) : (f32) -> (f32) + spv.ReturnValue %0 : f32 + } + + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Test properly updating entryPointOp and executionModeOp attached to renamed +// funcOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1 +// CHECK-NEXT: spv.ExecutionMode @foo_1 "ContractionOff" +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} +} + +// ----- + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @fo +// CHECK-NEXT: spv.ExecutionMode @foo "ContractionOff" + +// CHECK-NEXT: spv.func @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.EntryPoint "GLCompute" @foo_1 +// CHECK-NEXT: spv.ExecutionMode 
@foo_1 "ContractionOff" +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : f32) -> f32 "None" { + spv.ReturnValue %arg0 : f32 + } + + spv.EntryPoint "GLCompute" @foo + spv.ExecutionMode @foo "ContractionOff" +} +} + +// ----- + +// Resolve conflicting funcOp and globalVariableOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting funcOp and globalVariableOp and update the global variable's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._address_of @foo_1 +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr + + spv.func @bar() -> f32 "None" { + %0 = spv._address_of @foo : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} +} + +// ----- + +// Resolve conflicting globalVariableOp and funcOp and update the global variable's +// references. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._address_of @foo_1 +// CHECK-NEXT: spv.Load +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr + + spv.func @bar() -> f32 "None" { + %0 = spv._address_of @foo : !spv.ptr + %1 = spv.Load "Input" %0 : f32 + spv.ReturnValue %1 : f32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantOp and update the spec constant's +// references. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 + + spv.func @bar() -> i32 "None" { + %0 = spv._reference_of @foo : i32 + spv.ReturnValue %0 : i32 + } +} +} + +// ----- + +// Resolve conflicting specConstantOp and funcOp and update the spec constant's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @foo_1 +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 + + spv.func @bar() -> i32 "None" { + %0 = spv._reference_of @foo : i32 + spv.ReturnValue %0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantCompositeOp. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} +} + +// ----- + +// Resolve conflicting funcOp and specConstantCompositeOp and update the spec +// constant's references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} +} + +// ----- + +// Resolve conflicting specConstantCompositeOp and funcOp and update the spec +// constant's references. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo_1 (@bar, @bar) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_1 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } +} +} + +// ----- + +// Resolve conflicting spec constants and funcOps and update the spec constant's +// references. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar_1 +// CHECK-NEXT: spv.specConstantComposite @foo_2 (@bar_1, @bar_1) +// CHECK-NEXT: spv.func @baz +// CHECK-NEXT: spv._reference_of @foo_2 +// CHECK-NEXT: spv.CompositeExtract +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @foo +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } + +// CHECK-NEXT: spv.func @bar +// CHECK-NEXT: spv.ReturnValue +// CHECK-NEXT: } +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> + + spv.func @baz() -> i32 "None" { + %0 = spv._reference_of @foo : !spv.array<2 x i32> + %1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32> + spv.ReturnValue %1 : i32 + } +} + +spv.module Logical GLSL450 { + spv.func @foo(%arg0 : i32) -> i32 "None" { + spv.ReturnValue %arg0 : i32 + } + + spv.func @bar(%arg0 : f32) -> f32 "None" { + 
spv.ReturnValue %arg0 : f32 + } +} +} + +// ----- + +// Resolve conflicting globalVariableOps. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 + +// CHECK-NEXT: spv.specConstant @foo +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} +} + +// ----- + +// Resolve conflicting specConstantOp and globalVariableOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @foo + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @foo = -5 : i32 +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantCompositeOp. + +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.globalVariable @foo_1 + +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo (@bar, @bar) +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} + +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} +} + +// ----- + +// Resolve conflicting globalVariableOp and specConstantComposite. 
+ +// CHECK: module { +// CHECK-NEXT: spv.module Logical GLSL450 { +// CHECK-NEXT: spv.specConstant @bar +// CHECK-NEXT: spv.specConstantComposite @foo (@bar, @bar) + +// CHECK-NEXT: spv.globalVariable @foo_1 +// CHECK-NEXT: } + +module { +spv.module Logical GLSL450 { + spv.specConstant @bar = -5 : i32 + spv.specConstantComposite @foo (@bar, @bar) : !spv.array<2 x i32> +} + +spv.module Logical GLSL450 { + spv.globalVariable @foo bind(1, 0) : !spv.ptr +} +} diff --git a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt --- a/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/test/lib/Dialect/SPIRV/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRSPIRVTestPasses TestAvailability.cpp TestEntryPointAbi.cpp + TestModuleCombiner.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp @@ -0,0 +1,46 @@ +//===- TestModuleCombiner.cpp - Pass to test SPIR-V module combiner lib ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/ModuleCombiner.h"
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that combines all top-level spv.module ops of the enclosing
+/// builtin module into a single spv.module and then erases the inputs.
+class TestModuleCombinerPass
+    : public PassWrapper<TestModuleCombinerPass, OperationPass<ModuleOp>> {
+public:
+  TestModuleCombinerPass() = default;
+  TestModuleCombinerPass(const TestModuleCombinerPass &) {}
+  void runOnOperation() override;
+};
+} // namespace
+
+void TestModuleCombinerPass::runOnOperation() {
+  auto modules = llvm::to_vector<4>(getOperation().getOps<spirv::ModuleOp>());
+
+  // Nothing to combine; this also guards the `modules[0]` access below.
+  if (modules.empty())
+    return;
+
+  // Create the combined module right before the first input module.
+  OpBuilder combinedModuleBuilder(modules[0]);
+  spirv::combine(modules, combinedModuleBuilder);
+
+  // `combine` works on clones, so the inputs can now be dropped.
+  for (auto module : modules)
+    module.erase();
+}
+
+namespace mlir {
+void registerTestSpirvModuleCombinerPass() {
+  PassRegistration<TestModuleCombinerPass> registration(
+      "test-spirv-module-combiner", "Tests SPIR-V module combiner library");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@
 void registerTestRecursiveTypesPass();
 void registerTestReducer();
 void registerTestSpirvEntryPointABIPass();
+void registerTestSpirvModuleCombinerPass();
 void registerTestSCFUtilsPass();
 void registerTestTraitsPass();
 void registerTestVectorConversions();
@@ -136,6 +137,7 @@
   registerTestReducer();
   registerTestGpuParallelLoopMappingPass();
   registerTestSpirvEntryPointABIPass();
+  registerTestSpirvModuleCombinerPass();
   registerTestSCFUtilsPass();
   registerTestTraitsPass();
   registerTestVectorConversions();