diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -32,6 +32,13 @@ /// using the default rule. Returns None if the storage class is unsupported. Optional<unsigned> mapVulkanStorageClassToMemorySpace(spirv::StorageClass); +/// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V +/// using the default rule. Returns None if the memory space is unknown. +Optional<spirv::StorageClass> mapMemorySpaceToOpenCLStorageClass(unsigned); +/// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces +/// using the default rule. Returns None if the storage class is unsupported. +Optional<unsigned> mapOpenCLStorageClassToMemorySpace(spirv::StorageClass); + /// Type converter for converting numeric MemRef memory spaces into SPIR-V /// symbolic ones. class MemorySpaceToStorageClassConverter : public TypeConverter { diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -37,7 +37,7 @@ /// depends on the context where it is used. There are no particular reasons /// behind the number assignments; we try to follow NVVM conventions and largely /// give common storage classes a smaller number.
-#define STORAGE_SPACE_MAP_LIST(MAP_FN) \ +#define VULKAN_STORAGE_SPACE_MAP_LIST(MAP_FN) \ MAP_FN(spirv::StorageClass::StorageBuffer, 0) \ MAP_FN(spirv::StorageClass::Generic, 1) \ MAP_FN(spirv::StorageClass::Workgroup, 3) \ @@ -56,7 +56,7 @@ return storage; switch (memorySpace) { - STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) default: break; } @@ -72,7 +72,7 @@ return space; switch (storageClass) { - STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + VULKAN_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) default: break; } @@ -81,7 +81,53 @@ #undef STORAGE_SPACE_MAP_FN } -#undef STORAGE_SPACE_MAP_LIST +#undef VULKAN_STORAGE_SPACE_MAP_LIST + +#define OPENCL_STORAGE_SPACE_MAP_LIST(MAP_FN) \ + MAP_FN(spirv::StorageClass::CrossWorkgroup, 0) \ + MAP_FN(spirv::StorageClass::Generic, 1) \ + MAP_FN(spirv::StorageClass::Workgroup, 3) \ + MAP_FN(spirv::StorageClass::Uniform, 4) \ + MAP_FN(spirv::StorageClass::Private, 5) \ + MAP_FN(spirv::StorageClass::Function, 6) \ + MAP_FN(spirv::StorageClass::Image, 7) \ + MAP_FN(spirv::StorageClass::UniformConstant, 8) \ + MAP_FN(spirv::StorageClass::Input, 9) \ + MAP_FN(spirv::StorageClass::Output, 10) + +Optional<spirv::StorageClass> +spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case space: \ + return storage; + + switch (memorySpace) { + OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; + +#undef STORAGE_SPACE_MAP_FN +} + +Optional<unsigned> +spirv::mapOpenCLStorageClassToMemorySpace(spirv::StorageClass storageClass) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case storage: \ + return space; + + switch (storageClass) { + OPENCL_STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; + +#undef STORAGE_SPACE_MAP_FN +} + +#undef OPENCL_STORAGE_SPACE_MAP_LIST //===----------------------------------------------------------------------===// // Type Converter @@
-263,7 +309,11 @@ if (failed(Pass::initializeOptions(options))) return failure(); - if (clientAPI != "vulkan") + if (clientAPI == "opencl") { + memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass; + } + + if (clientAPI != "vulkan" && clientAPI != "opencl") return failure(); return success(); diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=opencl' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=OPENCL // Vulkan Mappings: // 0 -> StorageBuffer @@ -7,24 +8,39 @@ // 3 -> Workgroup // 4 -> Uniform +// OpenCL Mappings: +// 0 -> CrossWorkgroup +// 1 -> Generic +// 2 -> [null] +// 3 -> Workgroup +// 4 -> Uniform + // VULKAN-LABEL: func @operand_result +// OPENCL-LABEL: func @operand_result func.func @operand_result() { // VULKAN: memref<f32, #spv.storage_class<StorageBuffer>> + // OPENCL: memref<f32, #spv.storage_class<CrossWorkgroup>> %0 = "dialect.memref_producer"() : () -> (memref<f32>) // VULKAN: memref<4xi32, #spv.storage_class<Generic>> + // OPENCL: memref<4xi32, #spv.storage_class<Generic>> %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>) // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>> + // OPENCL: memref<?x4xf16, #spv.storage_class<Workgroup>> %2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>) // VULKAN: memref<*xf16, #spv.storage_class<Uniform>> + // OPENCL: memref<*xf16, #spv.storage_class<Uniform>> %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>) "dialect.memref_consumer"(%0) : (memref<f32>) -> () // VULKAN: memref<4xi32, #spv.storage_class<Generic>> + // OPENCL: memref<4xi32, #spv.storage_class<Generic>> "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> () // VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>> + // OPENCL: memref<?x4xf16, #spv.storage_class<Workgroup>>
"dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> () // VULKAN: memref<*xf16, #spv.storage_class<Uniform>> + // OPENCL: memref<*xf16, #spv.storage_class<Uniform>> "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> () return @@ -33,8 +49,10 @@ // ----- // VULKAN-LABEL: func @type_attribute +// OPENCL-LABEL: func @type_attribute func.func @type_attribute() { // VULKAN: attr = memref<i32, #spv.storage_class<Generic>> + // OPENCL: attr = memref<i32, #spv.storage_class<Generic>> "dialect.memref_producer"() { attr = memref<i32, 1> } : () -> () return } @@ -42,10 +60,13 @@ // ----- // VULKAN-LABEL: func.func @function_io +// OPENCL-LABEL: func.func @function_io func.func @function_io // VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>) + // OPENCL-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>) (%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>) // VULKAN-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>) + // OPENCL-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>) -> (memref<f64, 1>, memref<4xi32, 3>) { return %arg0, %arg1: memref<f64, 1>, memref<4xi32, 3> } @@ -54,17 +75,22 @@ gpu.module @kernel { // VULKAN-LABEL: gpu.func @function_io +// OPENCL-LABEL: gpu.func @function_io // VULKAN-SAME: memref<8xi32, #spv.storage_class<StorageBuffer>> +// OPENCL-SAME: memref<8xi32, #spv.storage_class<CrossWorkgroup>> gpu.func @function_io(%arg0 : memref<8xi32>) kernel { gpu.return } } // ----- // VULKAN-LABEL: func.func @region +// OPENCL-LABEL: func.func @region func.func @region(%cond: i1, %arg0: memref<f32, 1>) { scf.if %cond { // VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>} + // OPENCL: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>} // VULKAN-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>> + // OPENCL-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>> %0 = "dialect.memref_consumer"(%arg0) { attr = memref<i64, 3> } : (memref<f32, 1>) -> (memref<f32, 1>) } return @@ -73,8 +99,10 @@ // ----- // VULKAN-LABEL: func @non_memref_types +// OPENCL-LABEL: func @non_memref_types func.func @non_memref_types(%arg: f32) -> f32 { // VULKAN: "dialect.op"(%{{.+}}) {attr = 16 : i64} : (f32) -> f32 + // OPENCL: "dialect.op"(%{{.+}}) {attr = 16 :
i64} : (f32) -> f32 %0 = "dialect.op"(%arg) { attr = 16 } : (f32) -> (f32) return %0 : f32 }