diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -144,14 +144,31 @@ SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *typeConverter = this->template getTypeConverter(); - auto indexType = typeConverter->getIndexType(); - - // SPIR-V invocation builtin variables are a vector of type <3xi32> - auto spirvBuiltin = - spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter); - rewriter.replaceOpWithNewOp( - op, indexType, spirvBuiltin, + Type indexType = typeConverter->getIndexType(); + + // For Vulkan, these SPIR-V builtin variables are required to be a vector of + // type <3xi32> by the spec: + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html + // + // For OpenCL, it depends on the Physical32/Physical64 addressing model: + // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables + bool forShader = + typeConverter->getTargetEnv().allows(spirv::Capability::Shader); + Type builtinType = forShader ? 
rewriter.getIntegerType(32) : indexType; + + Value vector = + spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter); + Value dim = rewriter.create( + op.getLoc(), builtinType, vector, rewriter.getI32ArrayAttr({static_cast(op.getDimension())})); + if (forShader && builtinType != indexType) + dim = rewriter.create(op.getLoc(), indexType, dim); + rewriter.replaceOp(op, dim); return success(); } @@ -161,11 +178,23 @@ SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *typeConverter = this->template getTypeConverter(); - auto indexType = typeConverter->getIndexType(); - - auto spirvBuiltin = - spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter); - rewriter.replaceOp(op, spirvBuiltin); + Type indexType = typeConverter->getIndexType(); + Type i32Type = rewriter.getIntegerType(32); + + // For Vulkan, these SPIR-V builtin variables are required to be a scalar of + // type i32 by the spec: + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html + // + // For OpenCL, they are also required to be i32: + // https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables + Value builtinValue = + spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter); + if (i32Type != indexType) + builtinValue = rewriter.create(op.getLoc(), indexType, + builtinValue); + rewriter.replaceOp(op, builtinValue); return success(); } diff --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir --- a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -split-input-file -convert-gpu-to-spirv="use-64bit-index=false" %s -o - | FileCheck %s --check-prefix=INDEX32 +// 
RUN: mlir-opt -split-input-file -convert-gpu-to-spirv="use-64bit-index=true" %s -o - | FileCheck %s --check-prefix=INDEX64 module attributes { gpu.container_module, @@ -13,12 +14,15 @@ // INDEX32-LABEL: spirv.module @{{.*}} Logical GLSL450 // INDEX32: spirv.GlobalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") : !spirv.ptr, Input> + // INDEX64-LABEL: spirv.module @{{.*}} Logical GLSL450 + // INDEX64: spirv.GlobalVariable [[WORKGROUPID:@.*]] built_in("WorkgroupId") : !spirv.ptr, Input> gpu.module @kernels { gpu.func @builtin_workgroup_id_x() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // INDEX32: [[ADDRESS:%.*]] = spirv.mlir.addressof [[WORKGROUPID]] // INDEX32-NEXT: [[VEC:%.*]] = spirv.Load "Input" [[ADDRESS]] // INDEX32-NEXT: {{%.*}} = spirv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}} + // INDEX64: spirv.UConvert %{{.+}} : i32 to i64 %0 = gpu.block_id x gpu.return } @@ -422,11 +426,14 @@ } { // INDEX32-LABEL: spirv.module @{{.*}} Logical GLSL450 // INDEX32: spirv.GlobalVariable [[SUBGROUPSIZE:@.*]] built_in("SubgroupSize") : !spirv.ptr + // INDEX64-LABEL: spirv.module @{{.*}} Logical GLSL450 + // INDEX64: spirv.GlobalVariable [[SUBGROUPSIZE:@.*]] built_in("SubgroupSize") : !spirv.ptr gpu.module @kernels { gpu.func @builtin_subgroup_size() kernel attributes {spirv.entry_point_abi = #spirv.entry_point_abi} { // INDEX32: [[ADDRESS:%.*]] = spirv.mlir.addressof [[SUBGROUPSIZE]] // INDEX32-NEXT: {{%.*}} = spirv.Load "Input" [[ADDRESS]] + // INDEX64: spirv.UConvert %{{.+}} : i32 to i64 %0 = gpu.subgroup_size : index gpu.return }