diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -954,7 +954,17 @@ let options = [
     Option<"useOpaquePointers", "use-opaque-pointers", "bool",
                  /*default=*/"true",
                  "Generate LLVM IR using opaque pointers "
-                 "instead of typed pointers">
+                 "instead of typed pointers">,
+    Option<"clientAPI", "client-api", "::mlir::spirv::ClientAPI",
+           /*default=*/"::mlir::spirv::ClientAPI::Unknown",
+           "Derive StorageClass to address space mapping from the client API",
+           [{::llvm::cl::values(
+             clEnumValN(::mlir::spirv::ClientAPI::Unknown, "Unknown", "Unknown (default)"),
+             clEnumValN(::mlir::spirv::ClientAPI::Metal, "Metal", "Metal"),
+             clEnumValN(::mlir::spirv::ClientAPI::OpenCL, "OpenCL", "OpenCL"),
+             clEnumValN(::mlir::spirv::ClientAPI::Vulkan, "Vulkan", "Vulkan"),
+             clEnumValN(::mlir::spirv::ClientAPI::WebGPU, "WebGPU", "WebGPU")
+           )}]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
@@ -15,6 +15,8 @@
 
 #include "mlir/Transforms/DialectConversion.h"
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+
 namespace mlir {
 class LLVMTypeConverter;
 class MLIRContext;
@@ -37,11 +39,23 @@
 void encodeBindAttribute(ModuleOp module);
 
 /// Populates type conversions with additional SPIR-V types.
-void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
+/// `clientAPIForAddressSpaceMapping` selects how SPIR-V pointer storage
+/// classes are mapped to LLVM address spaces. Only `OpenCL` currently has a
+/// mapping; for every other client API all storage classes map to address
+/// space 0.
+void populateSPIRVToLLVMTypeConversion(
+    LLVMTypeConverter &typeConverter,
+    spirv::ClientAPI clientAPIForAddressSpaceMapping =
+        spirv::ClientAPI::Unknown);
 
 /// Populates the given list with patterns that convert from SPIR-V to LLVM.
-void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
-                                           RewritePatternSet &patterns);
+/// `clientAPIForAddressSpaceMapping` selects the address space of converted
+/// `spirv.GlobalVariable` ops, using the same storage class mapping as
+/// populateSPIRVToLLVMTypeConversion.
+void populateSPIRVToLLVMConversionPatterns(
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    spirv::ClientAPI clientAPIForAddressSpaceMapping =
+        spirv::ClientAPI::Unknown);
 
 /// Populates the given list with patterns for function conversion from SPIR-V
 /// to LLVM.
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h
@@ -15,6 +15,8 @@
 
 #include <memory>
 
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
+
 namespace mlir {
 class Pass;
 
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -29,6 +29,12 @@
 
 using namespace mlir;
 
+//===----------------------------------------------------------------------===//
+// Constants
+//===----------------------------------------------------------------------===//
+
+static constexpr unsigned defaultAddressSpace = 0;
+
 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//
@@ -271,12 +277,49 @@
   return LLVM::LLVMArrayType::get(llvmElementType, numElements);
 }
 
+static unsigned mapToOpenCLAddressSpace(spirv::StorageClass storageClass) {
+  // Based on
+  // https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html#_binary_form
+  // and clang/lib/Basic/Targets/SPIR.h.
+  switch (storageClass) {
+#define STORAGE_SPACE_MAP(storage, space)                                      \
+  case spirv::StorageClass::storage:                                           \
+    return space;
+    STORAGE_SPACE_MAP(Function, 0)
+    STORAGE_SPACE_MAP(CrossWorkgroup, 1)
+    STORAGE_SPACE_MAP(Input, 1)
+    STORAGE_SPACE_MAP(UniformConstant, 2)
+    STORAGE_SPACE_MAP(Workgroup, 3)
+    STORAGE_SPACE_MAP(Generic, 4)
+    STORAGE_SPACE_MAP(DeviceOnlyINTEL, 5)
+    STORAGE_SPACE_MAP(HostOnlyINTEL, 6)
+#undef STORAGE_SPACE_MAP
+  default:
+    return defaultAddressSpace;
+  }
+}
+
+static unsigned mapToAddressSpace(spirv::ClientAPI clientAPI,
+                                  spirv::StorageClass storageClass) {
+  switch (clientAPI) {
+#define CLIENT_MAP(client, storage)                                            \
+  case spirv::ClientAPI::client:                                               \
+    return mapTo##client##AddressSpace(storage);
+    CLIENT_MAP(OpenCL, storageClass)
+#undef CLIENT_MAP
+  default:
+    return defaultAddressSpace;
+  }
+}
+
-/// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
-/// modelled at the moment.
+/// Converts SPIR-V pointer type to LLVM pointer. The storage class is mapped
+/// to an LLVM address space for `clientAPI` (see mapToAddressSpace).
 static Type convertPointerType(spirv::PointerType type,
-                               LLVMTypeConverter &converter) {
+                               LLVMTypeConverter &converter,
+                               spirv::ClientAPI clientAPI) {
   auto pointeeType = converter.convertType(type.getPointeeType());
-  return converter.getPointerType(pointeeType);
+  unsigned addressSpace = mapToAddressSpace(clientAPI, type.getStorageClass());
+  return converter.getPointerType(pointeeType, addressSpace);
 }
 
 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
@@ -734,7 +777,11 @@
 class GlobalVariablePattern
     : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
 public:
-  using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
+  template <typename... Args>
+  GlobalVariablePattern(spirv::ClientAPI clientAPI, Args &&...args)
+      : SPIRVToLLVMConversion<spirv::GlobalVariableOp>(
+            std::forward<Args>(args)...),
+        clientAPI(clientAPI) {}
 
   LogicalResult
   matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor,
@@ -779,7 +826,7 @@
                        : LLVM::Linkage::External;
     auto newGlobalOp = rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
         op, dstType, isConstant, linkage, op.getSymName(), Attribute(),
-        /*alignment=*/0);
+        /*alignment=*/0, mapToAddressSpace(clientAPI, storageClass));
 
     // Attach location attribute if applicable
     if (op.getLocationAttr())
@@ -787,6 +834,9 @@
 
     return success();
   }
+
+private:
+  spirv::ClientAPI clientAPI;
 };
 
 /// Converts SPIR-V cast ops that do not have straightforward LLVM
@@ -1494,12 +1544,13 @@
 // Pattern population
 //===----------------------------------------------------------------------===//
 
-void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
+void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter,
+                                             spirv::ClientAPI clientAPI) {
   typeConverter.addConversion([&](spirv::ArrayType type) {
     return convertArrayType(type, typeConverter);
   });
-  typeConverter.addConversion([&](spirv::PointerType type) {
-    return convertPointerType(type, typeConverter);
+  typeConverter.addConversion([&, clientAPI](spirv::PointerType type) {
+    return convertPointerType(type, typeConverter, clientAPI);
   });
   typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
     return convertRuntimeArrayType(type, typeConverter);
@@ -1510,7 +1561,8 @@
 }
 
 void mlir::populateSPIRVToLLVMConversionPatterns(
-    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
+    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    spirv::ClientAPI clientAPI) {
   patterns.add<
       // Arithmetic ops
       DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
@@ -1605,9 +1657,8 @@
       NotPattern<spirv::LogicalNotOp>,
 
       // Memory ops
-      AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
-      LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
-      VariablePattern,
+      AccessChainPattern, AddressOfPattern, LoadStorePattern<spirv::LoadOp>,
+      LoadStorePattern<spirv::StoreOp>, VariablePattern,
 
       // Miscellaneous ops
       CompositeExtractPattern, CompositeInsertPattern,
@@ -1622,6 +1673,9 @@
 
       // Return ops
       ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+
+  patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
+                                      typeConverter);
 }
 
 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -50,16 +51,22 @@
 
   RewritePatternSet patterns(context);
 
-  populateSPIRVToLLVMTypeConversion(converter);
+  populateSPIRVToLLVMTypeConversion(converter, clientAPI);
 
   populateSPIRVToLLVMModuleConversionPatterns(converter, patterns);
-  populateSPIRVToLLVMConversionPatterns(converter, patterns);
+  populateSPIRVToLLVMConversionPatterns(converter, patterns, clientAPI);
   populateSPIRVToLLVMFunctionConversionPatterns(converter, patterns);
 
   ConversionTarget target(*context);
   target.addIllegalDialect<spirv::SPIRVDialect>();
   target.addLegalDialect<LLVM::LLVMDialect>();
 
+  if (clientAPI != spirv::ClientAPI::OpenCL &&
+      clientAPI != spirv::ClientAPI::Unknown)
+    getOperation()->emitWarning()
+        << "address space mapping for client '"
+        << spirv::stringifyClientAPI(clientAPI) << "' not implemented";
+
   // Set `ModuleOp` as legal for `spirv.module` conversion.
   target.addLegalOp<ModuleOp>();
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping-unsupported.mlir
@@ -0,0 +1,5 @@
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=Metal' -verify-diagnostics %s
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=Vulkan' -verify-diagnostics %s
+// RUN: mlir-opt -convert-spirv-to-llvm='client-api=WebGPU' -verify-diagnostics %s
+
+module {} // expected-warning-re {{address space mapping for client '{{.*}}' not implemented}}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/spirv-storage-class-mapping.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt -convert-spirv-to-llvm='use-opaque-pointers=1' -verify-diagnostics %s | FileCheck %s --check-prefixes=CHECK-UNKNOWN,CHECK-ALL
+// RUN: mlir-opt -convert-spirv-to-llvm='use-opaque-pointers=1 client-api=OpenCL' -verify-diagnostics %s | FileCheck %s --check-prefixes=CHECK-OPENCL,CHECK-ALL
+
+// CHECK-OPENCL: llvm.func @pointerUniformConstant(!llvm.ptr<2>)
+// CHECK-UNKNOWN: llvm.func @pointerUniformConstant(!llvm.ptr)
+spirv.func @pointerUniformConstant(!spirv.ptr<i1, UniformConstant>) "None"
+
+// CHECK-OPENCL: llvm.mlir.global external constant @varUniformConstant() {addr_space = 2 : i32} : i1
+// CHECK-UNKNOWN: llvm.mlir.global external constant @varUniformConstant() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varUniformConstant : !spirv.ptr<i1, UniformConstant>
+
+// CHECK-OPENCL: llvm.func @pointerInput(!llvm.ptr<1>)
+// CHECK-UNKNOWN: llvm.func @pointerInput(!llvm.ptr)
+spirv.func @pointerInput(!spirv.ptr<i1, Input>) "None"
+
+// CHECK-OPENCL: llvm.mlir.global external constant @varInput() {addr_space = 1 : i32} : i1
+// CHECK-UNKNOWN: llvm.mlir.global external constant @varInput() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varInput : !spirv.ptr<i1, Input>
+
+// CHECK-ALL: llvm.func @pointerUniform(!llvm.ptr)
+spirv.func @pointerUniform(!spirv.ptr<i1, Uniform>) "None"
+
+// CHECK-ALL: llvm.func @pointerOutput(!llvm.ptr)
+spirv.func @pointerOutput(!spirv.ptr<i1, Output>) "None"
+
+// CHECK-ALL: llvm.mlir.global external @varOutput() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varOutput : !spirv.ptr<i1, Output>
+
+// CHECK-OPENCL: llvm.func @pointerWorkgroup(!llvm.ptr<3>)
+// CHECK-UNKNOWN: llvm.func @pointerWorkgroup(!llvm.ptr)
+spirv.func @pointerWorkgroup(!spirv.ptr<i1, Workgroup>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerCrossWorkgroup(!llvm.ptr<1>)
+// CHECK-UNKNOWN: llvm.func @pointerCrossWorkgroup(!llvm.ptr)
+spirv.func @pointerCrossWorkgroup(!spirv.ptr<i1, CrossWorkgroup>) "None"
+
+// CHECK-ALL: llvm.func @pointerPrivate(!llvm.ptr)
+spirv.func @pointerPrivate(!spirv.ptr<i1, Private>) "None"
+
+// CHECK-ALL: llvm.mlir.global private @varPrivate() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varPrivate : !spirv.ptr<i1, Private>
+
+// CHECK-ALL: llvm.func @pointerFunction(!llvm.ptr)
+spirv.func @pointerFunction(!spirv.ptr<i1, Function>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerGeneric(!llvm.ptr<4>)
+// CHECK-UNKNOWN: llvm.func @pointerGeneric(!llvm.ptr)
+spirv.func @pointerGeneric(!spirv.ptr<i1, Generic>) "None"
+
+// CHECK-ALL: llvm.func @pointerPushConstant(!llvm.ptr)
+spirv.func @pointerPushConstant(!spirv.ptr<i1, PushConstant>) "None"
+
+// CHECK-ALL: llvm.func @pointerAtomicCounter(!llvm.ptr)
+spirv.func @pointerAtomicCounter(!spirv.ptr<i1, AtomicCounter>) "None"
+
+// CHECK-ALL: llvm.func @pointerImage(!llvm.ptr)
+spirv.func @pointerImage(!spirv.ptr<i1, Image>) "None"
+
+// CHECK-ALL: llvm.func @pointerStorageBuffer(!llvm.ptr)
+spirv.func @pointerStorageBuffer(!spirv.ptr<i1, StorageBuffer>) "None"
+
+// CHECK-ALL: llvm.mlir.global external @varStorageBuffer() {addr_space = 0 : i32} : i1
+spirv.GlobalVariable @varStorageBuffer : !spirv.ptr<i1, StorageBuffer>
+
+// CHECK-ALL: llvm.func @pointerCallableDataKHR(!llvm.ptr)
+spirv.func @pointerCallableDataKHR(!spirv.ptr<i1, CallableDataKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerIncomingCallableDataKHR(!llvm.ptr)
+spirv.func @pointerIncomingCallableDataKHR(!spirv.ptr<i1, IncomingCallableDataKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerRayPayloadKHR(!llvm.ptr)
+spirv.func @pointerRayPayloadKHR(!spirv.ptr<i1, RayPayloadKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerHitAttributeKHR(!llvm.ptr)
+spirv.func @pointerHitAttributeKHR(!spirv.ptr<i1, HitAttributeKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerIncomingRayPayloadKHR(!llvm.ptr)
+spirv.func @pointerIncomingRayPayloadKHR(!spirv.ptr<i1, IncomingRayPayloadKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerShaderRecordBufferKHR(!llvm.ptr)
+spirv.func @pointerShaderRecordBufferKHR(!spirv.ptr<i1, ShaderRecordBufferKHR>) "None"
+
+// CHECK-ALL: llvm.func @pointerPhysicalStorageBuffer(!llvm.ptr)
+spirv.func @pointerPhysicalStorageBuffer(!spirv.ptr<i1, PhysicalStorageBuffer>) "None"
+
+// CHECK-ALL: llvm.func @pointerCodeSectionINTEL(!llvm.ptr)
+spirv.func @pointerCodeSectionINTEL(!spirv.ptr<i1, CodeSectionINTEL>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerDeviceOnlyINTEL(!llvm.ptr<5>)
+// CHECK-UNKNOWN: llvm.func @pointerDeviceOnlyINTEL(!llvm.ptr)
+spirv.func @pointerDeviceOnlyINTEL(!spirv.ptr<i1, DeviceOnlyINTEL>) "None"
+
+// CHECK-OPENCL: llvm.func @pointerHostOnlyINTEL(!llvm.ptr<6>)
+// CHECK-UNKNOWN: llvm.func @pointerHostOnlyINTEL(!llvm.ptr)
+spirv.func @pointerHostOnlyINTEL(!spirv.ptr<i1, HostOnlyINTEL>) "None"