diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h --- a/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h +++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.h @@ -32,6 +32,10 @@ LLVMTypeConverter &typeConverter; }; +/// Encodes the descriptor set and binding of each SPIR-V global variable that +/// has both into its symbol name, then removes those two attributes. +void encodeBindAttribute(ModuleOp module); + /// Populates type conversions with additional SPIR-V types. void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -23,6 +23,7 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" #define DEBUG_TYPE "spirv-to-llvm-pattern" @@ -1332,8 +1333,6 @@ // TODO: Support EntryPoint/ExecutionMode properly. ErasePattern, ErasePattern, - // Function Call op - // GLSL extended instruction set ops DirectConversionPattern, DirectConversionPattern, @@ -1386,3 +1385,42 @@ patterns.insert( context, typeConverter); } + +//===----------------------------------------------------------------------===// +// Pre-conversion hooks +//===----------------------------------------------------------------------===// + +/// Pre-conversion hook encoding descriptor set/binding numbers into names.
+static constexpr StringRef kBinding = "binding";
+static constexpr StringRef kDescriptorSet = "descriptor_set";
+void mlir::encodeBindAttribute(ModuleOp module) {
+  for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
+    spvModule.walk([&](spirv::GlobalVariableOp op) {
+      // Only rename globals that carry both a descriptor set and a binding.
+      auto descriptorSet = op.getAttrOfType<IntegerAttr>(kDescriptorSet);
+      auto binding = op.getAttrOfType<IntegerAttr>(kBinding);
+      if (!descriptorSet || !binding)
+        return;
+
+      // Encode both numbers into the symbol name, prefixed with the SPIR-V
+      // module's name when the module has one.
+      std::string moduleAndName =
+          spvModule.getName().hasValue()
+              ? spvModule.getName().getValue().str() + "_" + op.sym_name().str()
+              : op.sym_name().str();
+      std::string name =
+          llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
+                        descriptorSet.getInt(), binding.getInt());
+
+      // Rewrite all uses first; on failure, report it and keep the original
+      // name and attributes so existing uses are not left dangling.
+      if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule))) {
+        op.emitError("unable to replace all symbol uses for ") << name;
+        return;
+      }
+      SymbolTable::setSymbolName(op, name);
+      op.removeAttr(kDescriptorSet);
+      op.removeAttr(kBinding);
+    });
+  }
+}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.cpp
@@ -33,6 +33,9 @@
   ModuleOp module = getOperation();
   LLVMTypeConverter converter(&getContext());
 
+  // Encode global variables' descriptor sets and bindings if they exist.
+ encodeBindAttribute(module); + OwningRewritePatternList patterns; populateSPIRVToLLVMTypeConversion(converter); @@ -45,7 +48,7 @@ target.addIllegalDialect(); target.addLegalDialect(); - // set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module` + // Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module` // conversion. target.addLegalOp(); target.addLegalOp(); diff --git a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir @@ -37,7 +37,7 @@ spv.module Logical GLSL450 { // CHECK: llvm.mlir.global private @struct() : !llvm.struct)> // CHECK-LABEL: @func - // CHECK: llvm.mlir.addressof @struct : !llvm.ptr)>> + // CHECK: llvm.mlir.addressof @struct : !llvm.ptr)>> spv.globalVariable @struct : !spv.ptr>, Private> spv.func @func() "None" { %0 = spv._address_of @struct : !spv.ptr>, Private> @@ -45,6 +45,28 @@ } } +spv.module Logical GLSL450 { + // CHECK: llvm.mlir.global external @bar_descriptor_set0_binding0() : !llvm.i32 + // CHECK-LABEL: @foo + // CHECK: llvm.mlir.addressof @bar_descriptor_set0_binding0 : !llvm.ptr + spv.globalVariable @bar bind(0, 0) : !spv.ptr + spv.func @foo() "None" { + %0 = spv._address_of @bar : !spv.ptr + spv.Return + } +} + +spv.module @name Logical GLSL450 { + // CHECK: llvm.mlir.global external @name_bar_descriptor_set0_binding0() : !llvm.i32 + // CHECK-LABEL: @foo + // CHECK: llvm.mlir.addressof @name_bar_descriptor_set0_binding0 : !llvm.ptr + spv.globalVariable @bar bind(0, 0) : !spv.ptr + spv.func @foo() "None" { + %0 = spv._address_of @bar : !spv.ptr + spv.Return + } +} + //===----------------------------------------------------------------------===// // spv.Load //===----------------------------------------------------------------------===//