diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -68,6 +68,19 @@ ConversionPatternRewriter &rewriter) const override; }; +/// This is separate because in Vulkan workgroup size is exposed to shaders via +/// a constant with WorkgroupSize decoration. So here we cannot generate a +/// builtin variable; instead the information in the `spv.entry_point_abi` +/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp. +class WorkGroupSizeConversion : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + PatternMatchResult + matchAndRewrite(gpu::BlockDimOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + /// Pattern to convert a kernel function in GPU dialect within a spv.module. class KernelFnConversion final : public SPIRVOpLowering { public: @@ -240,34 +253,54 @@ // Builtins. 
//===----------------------------------------------------------------------===// -template -PatternMatchResult LaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - auto dimAttr = - op.getOperation()->template getAttrOfType("dimension"); +static Optional getLaunchConfigIndex(Operation *op) { + auto dimAttr = op->getAttrOfType("dimension"); if (!dimAttr) { - return this->matchFailure(); + return {}; } - int32_t index = 0; if (dimAttr.getValue() == "x") { - index = 0; + return 0; } else if (dimAttr.getValue() == "y") { - index = 1; + return 1; } else if (dimAttr.getValue() == "z") { - index = 2; - } else { - return this->matchFailure(); + return 2; } + return {}; +} + +template +PatternMatchResult LaunchConfigConversion::matchAndRewrite( + SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto index = getLaunchConfigIndex(op); + if (!index) + return this->matchFailure(); // SPIR-V invocation builtin variables are a vector of type <3xi32> auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); rewriter.replaceOpWithNewOp( op, rewriter.getIntegerType(32), spirvBuiltin, - rewriter.getI32ArrayAttr({index})); + rewriter.getI32ArrayAttr({index.getValue()})); return this->matchSuccess(); } +PatternMatchResult WorkGroupSizeConversion::matchAndRewrite( + gpu::BlockDimOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + auto index = getLaunchConfigIndex(op); + if (!index) + return matchFailure(); + + auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); + auto val = workGroupSizeAttr.getValue(index.getValue()); + auto convertedType = typeConverter.convertType(op.getResult().getType()); + if (!convertedType) + return matchFailure(); + rewriter.replaceOpWithNewOp( + op, convertedType, IntegerAttr::get(convertedType, val)); + return matchSuccess(); +} + //===----------------------------------------------------------------------===// 
// GPUFuncOp //===----------------------------------------------------------------------===// @@ -401,13 +434,11 @@ populateWithGenerated(context, &patterns); patterns.insert(context, typeConverter, workGroupSize); patterns.insert< - ForOpConversion, GPUReturnOpConversion, IfOpConversion, - GPUModuleConversion, - GPUReturnOpConversion, ForOpConversion, GPUModuleConversion, - LaunchConfigConversion, + ForOpConversion, GPUModuleConversion, GPUReturnOpConversion, + IfOpConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion, - TerminatorOpConversion>(context, typeConverter); + TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter); } diff --git a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir --- a/mlir/test/Conversion/GPUToSPIRV/builtins.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/builtins.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -pass-pipeline='convert-gpu-to-spirv{workgroup-size=32,4}' %s -o - | FileCheck %s module attributes {gpu.container_module} { func @builtin() { @@ -77,13 +77,11 @@ } // CHECK-LABEL: spv.module "Logical" "GLSL450" - // CHECK: spv.globalVariable [[WORKGROUPSIZE:@.*]] built_in("WorkgroupSize") gpu.module @kernels { gpu.func @builtin_workgroup_size_x() attributes {gpu.kernel} { - // CHECK: [[ADDRESS:%.*]] = spv._address_of [[WORKGROUPSIZE]] - // CHECK-NEXT: [[VEC:%.*]] = spv.Load "Input" [[ADDRESS]] - // CHECK-NEXT: {{%.*}} = spv.CompositeExtract [[VEC]]{{\[}}0 : i32{{\]}} + // The constant value is obtained from the command line option above. 
+ // CHECK: spv.constant 32 : i32 %0 = "gpu.block_dim"() {dimension = "x"} : () -> index gpu.return } @@ -92,6 +90,48 @@ // ----- +module attributes {gpu.container_module} { + func @builtin() { + %c0 = constant 1 : index + "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = "builtin_workgroup_size_y", kernel_module = @kernels} : (index, index, index, index, index, index) -> () + return + } + + // CHECK-LABEL: spv.module "Logical" "GLSL450" + gpu.module @kernels { + gpu.func @builtin_workgroup_size_y() + attributes {gpu.kernel} { + // The constant value is obtained from the command line option above. + // CHECK: spv.constant 4 : i32 + %0 = "gpu.block_dim"() {dimension = "y"} : () -> index + gpu.return + } + } +} + +// ----- + +module attributes {gpu.container_module} { + func @builtin() { + %c0 = constant 1 : index + "gpu.launch_func"(%c0, %c0, %c0, %c0, %c0, %c0) {kernel = "builtin_workgroup_size_z", kernel_module = @kernels} : (index, index, index, index, index, index) -> () + return + } + + // CHECK-LABEL: spv.module "Logical" "GLSL450" + gpu.module @kernels { + gpu.func @builtin_workgroup_size_z() + attributes {gpu.kernel} { + // The constant value is obtained from the command line option above (1 is default). 
+ // CHECK: spv.constant 1 : i32 + %0 = "gpu.block_dim"() {dimension = "z"} : () -> index + gpu.return + } + } +} + +// ----- + module attributes {gpu.container_module} { func @builtin() { %c0 = constant 1 : index diff --git a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir --- a/mlir/test/Conversion/GPUToSPIRV/load-store.mlir +++ b/mlir/test/Conversion/GPUToSPIRV/load-store.mlir @@ -17,7 +17,6 @@ // CHECK-LABEL: spv.module "Logical" "GLSL450" gpu.module @kernels { - // CHECK-DAG: spv.globalVariable [[WORKGROUPSIZEVAR:@.*]] built_in("WorkgroupSize") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[NUMWORKGROUPSVAR:@.*]] built_in("NumWorkgroups") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[LOCALINVOCATIONIDVAR:@.*]] built_in("LocalInvocationId") : !spv.ptr, Input> // CHECK-DAG: spv.globalVariable [[WORKGROUPIDVAR:@.*]] built_in("WorkgroupId") : !spv.ptr, Input>