diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
--- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -343,6 +343,37 @@
   return newFuncOp;
 }
 
+/// Populates `argABI` with the default spv.interface_var_abi attributes used
+/// when lowering gpu.func to spv.func and no argument carries the attribute
+/// already. Argument `i` gets `spirv::getInterfaceVarABIAttr(0, i, sc, ...)`,
+/// where `sc` is StorageBuffer for int/index/float arguments and unset
+/// otherwise.
+///
+/// Returns failure if any argument already has the ABI attribute set. Every
+/// argument is checked before anything is appended, so `argABI` is left
+/// unmodified on failure.
+static LogicalResult
+getDefaultABIAttrs(MLIRContext *context, gpu::GPUFuncOp funcOp,
+                   SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
+  unsigned numArgs = funcOp.getNumArguments();
+  // Defaults are all-or-nothing: bail out before appending anything so that
+  // callers never observe a partially populated `argABI`.
+  for (auto argIndex : llvm::seq<unsigned>(0, numArgs))
+    if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+            argIndex, spirv::getInterfaceVarABIAttrName()))
+      return failure();
+
+  for (auto argIndex : llvm::seq<unsigned>(0, numArgs)) {
+    // Vulkan's interface variable requirements need scalars to be wrapped in a
+    // struct. The struct is held in a storage buffer.
+    Optional<spirv::StorageClass> sc;
+    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
+      sc = spirv::StorageClass::StorageBuffer;
+    argABI.push_back(spirv::getInterfaceVarABIAttr(0, argIndex, sc, context));
+  }
+  return success();
+}
+
 LogicalResult GPUFuncOpConversion::matchAndRewrite(
     gpu::GPUFuncOp funcOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -350,22 +381,21 @@
     return failure();
 
   SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
-  for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
-    // If the ABI is already specified, use it.
-    auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
-        argIndex, spirv::getInterfaceVarABIAttrName());
-    if (abiAttr) {
+  if (failed(getDefaultABIAttrs(rewriter.getContext(), funcOp, argABI))) {
+    // Some argument has an explicit ABI attribute, so defaults are not used and
+    // every argument must carry one. `argABI` is still empty at this point.
+    for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
+      auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
+          argIndex, spirv::getInterfaceVarABIAttrName());
+      if (!abiAttr) {
+        funcOp.emitRemark(
+            "match failure: missing 'spv.interface_var_abi' attribute at "
+            "argument ")
+            << argIndex;
+        return failure();
+      }
       argABI.push_back(abiAttr);
-      continue;
     }
-    // todo(ravishankarm): Use the "default ABI". Remove this in a follow up
-    // CL. Staging this to make this easy to revert in case of breakages out of
-    // tree.
-    Optional<spirv::StorageClass> sc;
-    if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
-      sc = spirv::StorageClass::StorageBuffer;
-    argABI.push_back(
-        spirv::getInterfaceVarABIAttr(0, argIndex, sc, rewriter.getContext()));
   }
 
   auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
diff --git a/mlir/test/Conversion/GPUToSPIRV/simple.mlir b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
--- a/mlir/test/Conversion/GPUToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/GPUToSPIRV/simple.mlir
@@ -26,6 +26,41 @@
 
 // -----
 
+// ABI attributes preset on every argument are kept as-is.
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    // CHECK:       spv.module Logical GLSL450 {
+    // CHECK-LABEL: spv.func @basic_module_structure_preset_ABI
+    // CHECK-SAME: {{%[a-zA-Z0-9_]*}}: f32
+    // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>
+    // CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<12 x f32 [4]> [0]>, StorageBuffer>
+    // CHECK-SAME: spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>
+    // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
+    gpu.func @basic_module_structure_preset_ABI(
+      %arg0 : f32
+      {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
+      %arg1 : memref<12xf32>
+      {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
+      attributes
+      {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
+      // CHECK: spv.Return
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %0 = "op"() : () -> (f32)
+    %1 = "op"() : () -> (memref<12xf32>)
+    %cst = constant 1 : index
+    "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
+    { kernel = @kernels::@basic_module_structure_preset_ABI }
+      : (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
 module attributes {gpu.container_module} {
   gpu.module @kernels {
     // expected-error @below {{failed to legalize operation 'gpu.func'}}
@@ -44,3 +79,59 @@
     return
   }
 }
+
+// -----
+
+// Partially preset ABI attributes are rejected: argument 1 has none.
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    // expected-error @below {{failed to legalize operation 'gpu.func'}}
+    // expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 1}}
+    gpu.func @missing_interface_var_abi_on_arg1(
+      %arg0 : f32
+      {spv.interface_var_abi = #spv.interface_var_abi<(1, 2), StorageBuffer>},
+      %arg1 : memref<12xf32>) kernel
+      attributes
+      {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %0 = "op"() : () -> (f32)
+    %1 = "op"() : () -> (memref<12xf32>)
+    %cst = constant 1 : index
+    "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
+    { kernel = @kernels::@missing_interface_var_abi_on_arg1 }
+      : (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
+    return
+  }
+}
+
+// -----
+
+// Partially preset ABI attributes are rejected: argument 0 has none.
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    // expected-error @below {{failed to legalize operation 'gpu.func'}}
+    // expected-remark @below {{match failure: missing 'spv.interface_var_abi' attribute at argument 0}}
+    gpu.func @missing_interface_var_abi_on_arg0(
+      %arg0 : f32,
+      %arg1 : memref<12xf32>
+      {spv.interface_var_abi = #spv.interface_var_abi<(3, 0)>}) kernel
+      attributes
+      {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %0 = "op"() : () -> (f32)
+    %1 = "op"() : () -> (memref<12xf32>)
+    %cst = constant 1 : index
+    "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0, %1)
+    { kernel = @kernels::@missing_interface_var_abi_on_arg0 }
+      : (index, index, index, index, index, index, f32, memref<12xf32>) -> ()
+    return
+  }
+}