diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -338,6 +338,15 @@
                               spvModuleRegion.begin());
   // The spirv.module build method adds a block. Remove that.
   rewriter.eraseBlock(&spvModuleRegion.back());
+
+  // Some patterns call `lookupTargetEnv` during conversion; they would fail
+  // after GPUModuleConversion if the `TargetEnv` attribute were not preserved
+  // on the new spirv.module.
+  // Copy TargetEnvAttr only if it is attached directly to the GPUModuleOp.
+  if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
+          spirv::getTargetEnvAttrName()))
+    spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
+
   rewriter.eraseOp(moduleOp);
   return success();
 }
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -63,37 +63,41 @@
     gpuModules.push_back(builder.clone(*moduleOp.getOperation()));
   });
 
-  // Map MemRef memory space to SPIR-V storage class first if requested.
-  if (mapMemorySpace) {
+  // Run conversion for each module independently as they can have different
+  // TargetEnv attributes.
+  for (Operation *gpuModule : gpuModules) {
+    // Map MemRef memory space to SPIR-V storage class first if requested.
+    if (mapMemorySpace) {
+      std::unique_ptr<ConversionTarget> target =
+          spirv::getMemorySpaceToStorageClassTarget(*context);
+      spirv::MemorySpaceToStorageClassMap memorySpaceMap =
+          spirv::mapMemorySpaceToVulkanStorageClass;
+      spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+
+      RewritePatternSet patterns(context);
+      spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+
+      if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
+        return signalPassFailure();
+    }
+
+    auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
     std::unique_ptr<ConversionTarget> target =
-        spirv::getMemorySpaceToStorageClassTarget(*context);
-    spirv::MemorySpaceToStorageClassMap memorySpaceMap =
-        spirv::mapMemorySpaceToVulkanStorageClass;
-    spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap);
+        SPIRVConversionTarget::get(targetAttr);
+    SPIRVTypeConverter typeConverter(targetAttr);
 
     RewritePatternSet patterns(context);
-    spirv::populateMemorySpaceToStorageClassPatterns(converter, patterns);
+    populateGPUToSPIRVPatterns(typeConverter, patterns);
 
-    if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
+    // TODO: Change SPIR-V conversion to be progressive and remove the following
+    // patterns.
+    mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
+    populateMemRefToSPIRVPatterns(typeConverter, patterns);
+    populateFuncToSPIRVPatterns(typeConverter, patterns);
+
+    if (failed(applyFullConversion(gpuModule, *target, std::move(patterns))))
       return signalPassFailure();
   }
-
-  auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
-  std::unique_ptr<ConversionTarget> target =
-      SPIRVConversionTarget::get(targetAttr);
-
-  SPIRVTypeConverter typeConverter(targetAttr);
-  RewritePatternSet patterns(context);
-  populateGPUToSPIRVPatterns(typeConverter, patterns);
-
-  // TODO: Change SPIR-V conversion to be progressive and remove the following
-  // patterns.
-  mlir::arith::populateArithToSPIRVPatterns(typeConverter, patterns);
-  populateMemRefToSPIRVPatterns(typeConverter, patterns);
-  populateFuncToSPIRVPatterns(typeConverter, patterns);
-
-  if (failed(applyFullConversion(gpuModules, *target, std::move(patterns))))
-    return signalPassFailure();
 }
 
 std::unique_ptr<OperationPass<ModuleOp>>
diff --git a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/module-opencl.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics -split-input-file %s -o - | FileCheck %s
 
 module attributes {
   gpu.container_module,
@@ -28,3 +28,36 @@
     return
   }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module
+} {
+  gpu.module @kernels attributes {
+    spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
+  } {
+    // CHECK-LABEL: spirv.module @{{.*}} Physical64 OpenCL
+    //  CHECK-SAME:   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Addresses], []>, #spirv.resource_limits<>>
+    //       CHECK:   spirv.func
+    //  CHECK-SAME:     {{%.*}}: f32
+    //   CHECK-NOT:     spirv.interface_var_abi
+    //  CHECK-SAME:     {{%.*}}: !spirv.ptr<!spirv.array<12 x f32>, CrossWorkgroup>
+    //   CHECK-NOT:     spirv.interface_var_abi
+    //  CHECK-SAME:     spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]> : vector<3xi32>>
+    gpu.func @basic_module_structure(%arg0 : f32, %arg1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>) kernel
+        attributes {spirv.entry_point_abi = #spirv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
+      gpu.return
+    }
+  }
+
+  func.func @main() {
+    %0 = "op"() : () -> (f32)
+    %1 = "op"() : () -> (memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
+    %cst = arith.constant 1 : index
+    gpu.launch_func @kernels::@basic_module_structure
+        blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
+        args(%0 : f32, %1 : memref<12xf32, #spirv.storage_class<CrossWorkgroup>>)
+    return
+  }
+}