diff --git a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h --- a/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h +++ b/mlir/include/mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h @@ -23,18 +23,18 @@ namespace spirv { /// Mapping from numeric MemRef memory spaces into SPIR-V symbolic ones. using MemorySpaceToStorageClassMap = - std::function(unsigned)>; + std::function(Attribute)>; /// Maps MemRef memory spaces to storage classes for Vulkan-flavored SPIR-V /// using the default rule. Returns None if the memory space is unknown. -Optional mapMemorySpaceToVulkanStorageClass(unsigned); +Optional mapMemorySpaceToVulkanStorageClass(Attribute); /// Maps storage classes for Vulkan-flavored SPIR-V to MemRef memory spaces /// using the default rule. Returns None if the storage class is unsupported. Optional mapVulkanStorageClassToMemorySpace(spirv::StorageClass); /// Maps MemRef memory spaces to storage classes for OpenCL-flavored SPIR-V /// using the default rule. Returns None if the memory space is unknown. -Optional mapMemorySpaceToOpenCLStorageClass(unsigned); +Optional mapMemorySpaceToOpenCLStorageClass(Attribute); /// Maps storage classes for OpenCL-flavored SPIR-V to MemRef memory spaces /// using the default rule. Returns None if the storage class is unsupported. Optional mapOpenCLStorageClassToMemorySpace(spirv::StorageClass); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -56,7 +57,18 @@ MAP_FN(spirv::StorageClass::Output, 10) Optional -spirv::mapMemorySpaceToVulkanStorageClass(unsigned memorySpace) { +spirv::mapMemorySpaceToVulkanStorageClass(Attribute memorySpaceAttr) { + // Handle null memory space attribute specially. + if (!memorySpaceAttr) + return spirv::StorageClass::StorageBuffer; + + // Unknown dialect custom attributes are not supported by default. + // Downstream callers should plug in more specialized ones. + auto intAttr = memorySpaceAttr.dyn_cast(); + if (!intAttr) + return llvm::None; + unsigned memorySpace = intAttr.getInt(); + #define STORAGE_SPACE_MAP_FN(storage, space) \ case space: \ return storage; @@ -99,7 +111,18 @@ MAP_FN(spirv::StorageClass::Image, 7) Optional -spirv::mapMemorySpaceToOpenCLStorageClass(unsigned memorySpace) { +spirv::mapMemorySpaceToOpenCLStorageClass(Attribute memorySpaceAttr) { + // Handle null memory space attribute specially. + if (!memorySpaceAttr) + return spirv::StorageClass::CrossWorkgroup; + + // Unknown dialect custom attributes are not supported by default. + // Downstream callers should plug in more specialized ones. + auto intAttr = memorySpaceAttr.dyn_cast(); + if (!intAttr) + return llvm::None; + unsigned memorySpace = intAttr.getInt(); + #define STORAGE_SPACE_MAP_FN(storage, space) \ case space: \ return storage; @@ -143,17 +166,8 @@ addConversion([](Type type) { return type; }); addConversion([this](BaseMemRefType memRefType) -> Optional { - // Expect IntegerAttr memory spaces. The attribute can be missing for the - // case of memory space == 0. - Attribute spaceAttr = memRefType.getMemorySpace(); - if (spaceAttr && !spaceAttr.isa()) { - LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType - << " due to non-IntegerAttr memory space\n"); - return llvm::None; - } - - unsigned space = memRefType.getMemorySpaceAsInt(); - auto storage = this->memorySpaceMap(space); + Optional storage = + this->memorySpaceMap(memRefType.getMemorySpace()); if (!storage) { LLVM_DEBUG(llvm::dbgs() << "cannot convert " << memRefType