diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MapMemRefStorageClassPass.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionInterfaces.h" #include "mlir/Transforms/DialectConversion.h" @@ -325,6 +326,17 @@ MLIRContext *context = &getContext(); Operation *op = getOperation(); + // Let the target env, if any, choose memorySpaceMap: Kernel (checked first) + // selects OpenCL, else Shader selects Vulkan; otherwise it is left unchanged. + if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) { + spirv::TargetEnv targetEnv(attr); + if (targetEnv.allows(spirv::Capability::Kernel)) { + memorySpaceMap = spirv::mapMemorySpaceToOpenCLStorageClass; + } else if (targetEnv.allows(spirv::Capability::Shader)) { + memorySpaceMap = spirv::mapMemorySpaceToVulkanStorageClass; + } + } + auto target = spirv::getMemorySpaceToStorageClassTarget(*context); spirv::MemorySpaceToStorageClassConverter converter(memorySpaceMap); diff --git a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/map-storage-class.mlir @@ -114,3 +114,53 @@ %0 = "dialect.memref_producer"() : () -> (memref) return } + +// ----- + +/// Checks that the Kernel capability selects the OpenCL memory space mapping. 
+module attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } {
+func.func @operand_result() {
+  // CHECK: memref>
+  %0 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<4xi32, #spv.storage_class>
+  %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+  // CHECK: memref>
+  %2 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<*xf16, #spv.storage_class>
+  %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+  "dialect.memref_consumer"(%0) : (memref) -> ()
+  // CHECK: memref<4xi32, #spv.storage_class>
+  "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+  // CHECK: memref>
+  "dialect.memref_consumer"(%2) : (memref) -> ()
+  // CHECK: memref<*xf16, #spv.storage_class>
+  "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+  return
+}
+}
+
+// -----
+
+/// Checks that the Shader capability selects the Vulkan memory space mapping.
+module attributes { spv.target_env = #spv.target_env<#spv.vce, #spv.resource_limits<>> } {
+func.func @operand_result() {
+  // CHECK: memref>
+  %0 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<4xi32, #spv.storage_class>
+  %1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
+  // CHECK: memref>
+  %2 = "dialect.memref_producer"() : () -> (memref)
+  // CHECK: memref<*xf16, #spv.storage_class>
+  %3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
+
+  "dialect.memref_consumer"(%0) : (memref) -> ()
+  // CHECK: memref<4xi32, #spv.storage_class>
+  "dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
+  // CHECK: memref>
+  "dialect.memref_consumer"(%2) : (memref) -> ()
+  // CHECK: memref<*xf16, #spv.storage_class>
+  "dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
+  return
+}
+}