Index: mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
===================================================================
--- mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
+++ mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/TargetAndABI.h"
 #include "mlir/IR/Module.h"
 
 using namespace mlir;
@@ -272,9 +273,18 @@
 LogicalResult GPUModuleConversion::matchAndRewrite(
     gpu::GPUModuleOp moduleOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
+  // Query capabilities through spirv::TargetEnv::allows(), the same check
+  // LowerABIAttributesPass uses, so both passes agree on the target's model.
+  spirv::TargetEnv targetEnv(spirv::lookupTargetEnvOrDefault(moduleOp));
+  auto addressingModel = targetEnv.allows(spirv::Capability::Addresses)
+                             ? spirv::AddressingModel::Physical64
+                             : spirv::AddressingModel::Logical;
+  auto memoryModel = targetEnv.allows(spirv::Capability::Kernel)
+                         ? spirv::MemoryModel::OpenCL
+                         : spirv::MemoryModel::GLSL450;
+
   auto spvModule = rewriter.create<spirv::ModuleOp>(
-      moduleOp.getLoc(), spirv::AddressingModel::Logical,
-      spirv::MemoryModel::GLSL450);
+      moduleOp.getLoc(), addressingModel, memoryModel);
 
   // Move the region from the module op into the SPIR-V module.
   Region &spvModuleRegion = spvModule.body();
Index: mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -119,8 +119,15 @@
   if (failed(getInterfaceVariables(funcOp, interfaceVars))) {
     return failure();
   }
-  builder.create<spirv::EntryPointOp>(
-      funcOp.getLoc(), spirv::ExecutionModel::GLCompute, funcOp, interfaceVars);
+
+  // Targets that allow the Kernel capability get a Kernel entry point.
+  spirv::TargetEnv targetEnv(spirv::lookupTargetEnvOrDefault(funcOp));
+  auto executionModel = spirv::ExecutionModel::GLCompute;
+  if (targetEnv.allows(spirv::Capability::Kernel))
+    executionModel = spirv::ExecutionModel::Kernel;
+
+  builder.create<spirv::EntryPointOp>(funcOp.getLoc(), executionModel, funcOp,
+                                      interfaceVars);
   // Specifies the spv.ExecutionModeOp.
   auto localSizeAttr = entryPointAttr.local_size();
   SmallVector<int32_t, 3> localSize(localSizeAttr.getValues<int32_t>());
@@ -146,6 +153,20 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// A pattern that removes all interface variable ABI attributes from a
+/// function's arguments, leaving the argument types unchanged.
+///
+/// This is used when targeting environments that allow normal arguments
+/// for entry points (e.g. when the Kernel capability is available).
+class ProcessInterfaceVarABIRemove final
+    : public SPIRVOpLowering<spirv::FuncOp> {
+public:
+  using SPIRVOpLowering<spirv::FuncOp>::SPIRVOpLowering;
+  LogicalResult
+  matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 /// Pass to implement the ABI information specified as attributes.
 class LowerABIAttributesPass final
     : public SPIRVLowerABIAttributesBase<LowerABIAttributesPass> {
@@ -213,6 +234,21 @@
   return success();
 }
 
+LogicalResult ProcessInterfaceVarABIRemove::matchAndRewrite(
+    spirv::FuncOp funcOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Identifier attrId = Identifier::get(spirv::getInterfaceVarABIAttrName(),
+                                      funcOp.getContext());
+  // Only argument attributes change, so modify the op in place and bracket
+  // the modification with start/finalizeRootUpdate.
+  rewriter.startRootUpdate(funcOp);
+  for (unsigned argIndex = 0, e = funcOp.getNumArguments(); argIndex != e;
+       ++argIndex)
+    funcOp.removeArgAttr(argIndex, attrId);
+  rewriter.finalizeRootUpdate(funcOp);
+  return success();
+}
+
 void LowerABIAttributesPass::runOnOperation() {
   // Uses the signature conversion methodology of the dialect conversion
   // framework to implement the conversion.
@@ -233,7 +269,11 @@
   });
 
   OwningRewritePatternList patterns;
-  patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
+  if (targetEnv.allows(spirv::Capability::Kernel)) {
+    patterns.insert<ProcessInterfaceVarABIRemove>(context, typeConverter);
+  } else {
+    patterns.insert<ProcessInterfaceVarABI>(context, typeConverter);
+  }
 
   ConversionTarget target(*context);
   // "Legal" function ops should have no interface variable ABI attributes.
Index: mlir/test/Conversion/GPUToSPIRV/test_opencl_spirv.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/GPUToSPIRV/test_opencl_spirv.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Kernel, Addresses], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+  gpu.module @kernels {
+    // CHECK: spv.module Physical64 OpenCL {
+    // CHECK-LABEL: spv.func @basic_module_structure
+    // CHECK-SAME: {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32, stride=4> [0]>, CrossWorkgroup> {spv.interface_var_abi = #spv.interface_var_abi<(0, 0)>}
+    // CHECK-SAME: spv.entry_point_abi = {local_size = dense<[32, 4, 1]> : vector<3xi32>}
+    gpu.func @basic_module_structure(%arg0 : memref<12xf32, 11>) kernel
+        attributes {spv.entry_point_abi = {local_size = dense<[32, 4, 1]>: vector<3xi32>}} {
+      // CHECK: spv.Return
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %0 = "op"() : () -> (memref<12xf32, 11>)
+    %cst = constant 1 : index
+    "gpu.launch_func"(%cst, %cst, %cst, %cst, %cst, %cst, %0) { kernel = @kernels::@basic_module_structure }
+        : (index, index, index, index, index, index, memref<12xf32, 11>) -> ()
+    return
+  }
+}
Index: mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/SPIRV/Transforms/abi-interface-opencl.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt -spirv-lower-abi-attrs -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Kernel, Addresses], []>,
+    {max_compute_workgroup_invocations = 128 : i32,
+     max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}>
+} {
+
+// CHECK: spv.module Physical64 OpenCL
+spv.module Physical64 OpenCL {
+  // Kernel targets keep normal entry-point arguments; only the ABI attributes
+  // are stripped, so the arguments must be printed without any attributes.
+  // CHECK-LABEL: spv.func @kernel
+  // CHECK-SAME: ({{%.*}}: f32, {{%.*}}: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, CrossWorkgroup>)
+  spv.func @kernel(
+    %arg0: f32
+           {spv.interface_var_abi = #spv.interface_var_abi<(0, 0), UniformConstant>},
+    %arg1: !spv.ptr<!spv.struct<!spv.array<12 x f32>>, CrossWorkgroup>
+           {spv.interface_var_abi = #spv.interface_var_abi<(0, 1)>}) "None"
+  attributes {spv.entry_point_abi = {local_size = dense<[32, 1, 1]> : vector<3xi32>}} {
+    // CHECK: spv.Return
+    spv.Return
+  }
+  // CHECK: spv.EntryPoint "Kernel" @kernel
+  // CHECK: spv.ExecutionMode @kernel "LocalSize", 32, 1, 1
+} // end spv.module
+
+} // end module