diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -22,9 +22,15 @@ namespace spirv { /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones. -using MemorySpaceToStorageClassMap = DenseMap; -/// Returns the default map for targeting Vulkan-flavored SPIR-V. -MemorySpaceToStorageClassMap getDefaultVulkanStorageClassMap(); +using MemorySpaceToStorageClassMap = + std::function(unsigned)>; + +/// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V +/// using the default rule. Returns None if the memory space is unknown. +Optional mapMemorySpaceToVulkanStorageClass(unsigned); +/// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces +/// using the default rule. Returns None if the storage class is unsupported. +Optional mapVulkanStorageClassToMemorySpace(spirv::StorageClass); /// Type converter for converting numeric MemRef memory spaces into SPIR-V /// symbolic ones. @@ -34,7 +40,7 @@ const MemorySpaceToStorageClassMap &memorySpaceMap); private: - const MemorySpaceToStorageClassMap &memorySpaceMap; + MemorySpaceToStorageClassMap memorySpaceMap; }; /// Creates the target that populates legality of ops with MemRef types. diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp @@ -64,7 +64,7 @@ std::unique_ptr target = spirv::getMemorySpaceToStorageClassTarget(*context); spirv::MemorySpaceToStorageClassMap memorySpaceMap = - spirv::getDefaultVulkanStorageClassMap(); + spirv::mapMemorySpaceToVulkanStorageClass; spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -16,6 +16,7 @@ #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -30,7 +31,6 @@ // Mappings //===----------------------------------------------------------------------===// -spirv::MemorySpaceToStorageClassMap spirv::getDefaultVulkanStorageClassMap() { /// Mapping between SPIR-V storage classes to memref memory spaces. /// /// Note: memref does not have a defined semantics for each memory space; it @@ -47,29 +47,42 @@ MAP_FN(spirv::StorageClass::PushConstant, 7) \ MAP_FN(spirv::StorageClass::UniformConstant, 8) \ MAP_FN(spirv::StorageClass::Input, 9) \ - MAP_FN(spirv::StorageClass::Output, 10) \ - MAP_FN(spirv::StorageClass::CrossWorkgroup, 11) \ - MAP_FN(spirv::StorageClass::AtomicCounter, 12) \ - MAP_FN(spirv::StorageClass::Image, 13) \ - MAP_FN(spirv::StorageClass::CallableDataKHR, 14) \ - MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15) \ - MAP_FN(spirv::StorageClass::RayPayloadKHR, 16) \ - MAP_FN(spirv::StorageClass::HitAttributeKHR, 17) \ - MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18) \ - MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19) \ - MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20) \ - MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21) \ - MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22) \ - MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23) - -#define STORAGE_SPACE_MAP_FN(storage, space) {space, storage}, - - return {STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)}; + MAP_FN(spirv::StorageClass::Output, 10) + +Optional +spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case space: \ + return storage; + + switch (memorySpace) { + STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; + +#undef STORAGE_SPACE_MAP_FN +} + +Optional +spirv::mapVulkanStorageClassToMemorySpace(spirv::StorageClass storageClass) { +#define STORAGE_SPACE_MAP_FN(storage, space) \ + case storage: \ + return space; + + switch (storageClass) { + STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) + default: + break; + } + return llvm::None; #undef STORAGE_SPACE_MAP_FN -#undef STORAGE_SPACE_MAP_LIST } +#undef STORAGE_SPACE_MAP_LIST + //===----------------------------------------------------------------------===// // Type Converter //===----------------------------------------------------------------------===// @@ -91,8 +104,8 @@ } unsigned space = memRefType.getMemorySpaceAsInt(); - auto it = this->memorySpaceMap.find(space); - if (it == this->memorySpaceMap.end()) { + auto storage = this->memorySpaceMap(space); + if (!storage) { LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType << " due to being unable to find memory space in map\n"); @@ -100,7 +113,7 @@ } auto storageAttr = - spirv::StorageClassAttr::get(memRefType.getContext(), it->second); + spirv::StorageClassAttr::get(memRefType.getContext(), *storage); if (auto rankedType = memRefType.dyn_cast()) { return MemRefType::get(memRefType.getShape(), memRefType.getElementType(), rankedType.getLayout(), storageAttr); @@ -231,16 +244,7 @@ : public MapMemRefStorageClassBase { public: explicit MapMemRefStorageClassPass() { - memorySpaceMap = spirv::getDefaultVulkanStorageClassMap(); - - LLVM_DEBUG({ - llvm::dbgs() << "memory space to storage class mapping:\n"; - if (memorySpaceMap.empty()) - llvm::dbgs() << " [empty]\n"; - for (auto kv : memorySpaceMap) - llvm::dbgs() << " " << kv.first << " -> " - << spirv::stringifyStorageClass(kv.second) << "\n"; - }); + memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass; } explicit MapMemRefStorageClassPass( const spirv::MemorySpaceToStorageClassMap &memorySpaceMap)