Index: mlir/tools/CMakeLists.txt
===================================================================
--- mlir/tools/CMakeLists.txt
+++ mlir/tools/CMakeLists.txt
@@ -3,3 +3,4 @@
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-tblgen)
 add_subdirectory(mlir-translate)
+add_subdirectory(mlir-vulkan-runner)
Index: mlir/tools/mlir-vulkan-runner/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/tools/mlir-vulkan-runner/CMakeLists.txt
@@ -0,0 +1,46 @@
+set(LLVM_OPTIONAL_SOURCES
+  VulkanRuntime.cpp
+  )
+
+if (MLIR_VULKAN_RUNNER_ENABLED)
+  message(STATUS "Building the Vulkan runner")
+
+  # At first try "FindVulkan" from:
+  # https://cmake.org/cmake/help/v3.7/module/FindVulkan.html
+  if (NOT CMAKE_VERSION VERSION_LESS 3.7.0)
+    find_package(Vulkan)
+  endif()
+
+  # If Vulkan is not found try a path specified by VULKAN_SDK.
+  if (NOT Vulkan_FOUND)
+    if ("$ENV{VULKAN_SDK}" STREQUAL "")
+      message(FATAL_ERROR "Please use at least CMake 3.7.0 or provide "
+                          "VULKAN_SDK path as an environment variable")
+    endif()
+
+    find_library(Vulkan_LIBRARY vulkan HINTS "$ENV{VULKAN_SDK}/lib")
+    if (Vulkan_LIBRARY)
+      set(Vulkan_FOUND ON)
+      set(Vulkan_INCLUDE_DIR "$ENV{VULKAN_SDK}/include")
+      message(STATUS "Found Vulkan: " "${Vulkan_LIBRARY}")
+    endif()
+  endif()
+
+  if (NOT Vulkan_FOUND)
+    message(FATAL_ERROR "Cannot find Vulkan library")
+  endif()
+
+  # Build the runtime as a library for now.
+  add_llvm_library(MLIRVulkanRuntime
+    VulkanRuntime.cpp
+    )
+
+  target_include_directories(MLIRVulkanRuntime PRIVATE ${Vulkan_INCLUDE_DIR})
+
+  target_link_libraries(MLIRVulkanRuntime
+    MLIRSPIRVSerialization
+    LLVMCore
+    LLVMSupport
+    ${Vulkan_LIBRARY}
+    )
+endif()
Index: mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp
===================================================================
--- /dev/null
+++ mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp
@@ -0,0 +1,1029 @@
+//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file provides a library for running a module on a Vulkan device.
+// Implements a Vulkan runtime to run a spirv::ModuleOp. It also defines a few
+// utility functions to extract information from a spirv::ModuleOp.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+#include <vulkan/vulkan.h>
+
+#include <cstring>
+
+using namespace mlir;
+
+using DescriptorSetIndex = uint32_t;
+using BindingIndex = uint32_t;
+
+/// Struct containing information regarding a device memory buffer.
+struct VulkanDeviceMemoryBuffer {
+  BindingIndex bindingIndex{0};
+  VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
+  VkDescriptorBufferInfo bufferInfo{VK_NULL_HANDLE};
+  VkBuffer buffer{VK_NULL_HANDLE};
+  VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
+};
+
+/// Struct containing information regarding a host memory buffer.
+struct VulkanHostMemoryBuffer {
+  /// Pointer to a host memory.
+  void *ptr{nullptr};
+  /// Size of a host memory in bytes.
+  uint32_t size{0};
+};
+
+/// Struct containing the number of local workgroups to dispatch for each
+/// dimension.
+struct NumWorkGroups {
+  uint32_t x{1};
+  uint32_t y{1};
+  uint32_t z{1};
+};
+
+/// Struct containing information regarding a descriptor set.
+struct DescriptorSetInfo {
+  /// Index of a descriptor set in descriptor sets.
+  DescriptorSetIndex descriptorSet{0};
+  /// Number of descriptors in a set.
+  uint32_t descriptorSize{0};
+  /// Type of a descriptor set.
+  VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
+};
+
+/// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
+using ResourceData =
+    llvm::DenseMap<DescriptorSetIndex,
+                   llvm::DenseMap<BindingIndex, VulkanHostMemoryBuffer>>;
+
+/// StorageClass mapped into a descriptor set and a binding.
+using ResourceStorageClassData =
+    llvm::DenseMap<DescriptorSetIndex,
+                   llvm::DenseMap<BindingIndex, mlir::spirv::StorageClass>>;
+
+/// Reports to stderr that the operation named by `message` failed with the
+/// Vulkan result code `error`.
+inline void emitVulkanError(const llvm::Twine &message, VkResult error) {
+  llvm::errs()
+      << message.concat(" failed with error code ").concat(llvm::Twine{error});
+}
+
+/// Evaluates `result` exactly once; if it is not VK_SUCCESS, reports `msg`
+/// with the error code and returns failure() from the enclosing function.
+/// Callers pass the Vulkan call itself as `result`, so evaluating it twice
+/// would re-issue the call on the error path. The do/while(0) wrapper makes
+/// the macro a single statement that is safe inside an unbraced if/else.
+#define RETURN_ON_VULKAN_ERROR(result, msg)                                    \
+  do {                                                                         \
+    const VkResult returnOnVulkanErrorResult = (result);                       \
+    if (returnOnVulkanErrorResult != VK_SUCCESS) {                             \
+      emitVulkanError(msg, returnOnVulkanErrorResult);                         \
+      return failure();                                                        \
+    }                                                                          \
+  } while (0)
+
+namespace {
+/// Processes spv.module and collects all needed information for VulkanRuntime.
+class SPIRVModuleInfoCollector {
+public:
+  explicit SPIRVModuleInfoCollector() = default;
+  std::string getEntryPoint() { return entryPoint; }
+  ResourceStorageClassData &getResourceStorageClassData() {
+    return resourceStorageClassData;
+  }
+  void processModule(spirv::ModuleOp module);
+
+private:
+  SPIRVModuleInfoCollector(const SPIRVModuleInfoCollector &) = delete;
+  SPIRVModuleInfoCollector &operator=(const SPIRVModuleInfoCollector &) = delete;
+  void processGlobalVariable(spirv::GlobalVariableOp varOp);
+  void processEntryPoint(spirv::EntryPointOp op);
+
+  std::string entryPoint;
+  ResourceStorageClassData resourceStorageClassData;
+};
+
+/// Processes ModuleOp to collect module specific information.
+/// Note: ModuleOp must be valid.
+void SPIRVModuleInfoCollector::processModule(spirv::ModuleOp module) {
+  for (auto &op : module.getBlock()) {
+    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
+      processEntryPoint(entryPointOp);
+    } else if (auto varOp = dyn_cast<spirv::GlobalVariableOp>(op)) {
+      processGlobalVariable(varOp);
+    }
+  }
+}
+
+/// Processes EntryPointOp to collect the entry point function name. If the
+/// module declares several entry points, the last one processed wins.
+void SPIRVModuleInfoCollector::processEntryPoint(spirv::EntryPointOp op) {
+  entryPoint = op.fn();
+}
+
+/// Processes GlobalVariableOp to collect storage classes for resource data.
+/// Only variables carrying both descriptor set and binding attributes and
+/// having a pointer type are recorded.
+void SPIRVModuleInfoCollector::processGlobalVariable(
+    spirv::GlobalVariableOp varOp) {
+  auto descriptorSetName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet));
+  auto bindingName =
+      convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding));
+  auto descriptorSet = varOp.getAttrOfType<IntegerAttr>(descriptorSetName);
+  auto binding = varOp.getAttrOfType<IntegerAttr>(bindingName);
+
+  if (descriptorSet && binding) {
+    if (auto ptrType = varOp.type().dyn_cast<spirv::PointerType>()) {
+      auto descriptorBindingIndex = binding.getInt();
+      auto descriptorSetIndex = descriptorSet.getInt();
+      resourceStorageClassData[descriptorSetIndex][descriptorBindingIndex] =
+          ptrType.getStorageClass();
+    }
+  }
+}
+} // namespace
+
+namespace {
+
+/// Vulkan runtime.
+/// The purpose of this class is to run SPIR-V computation shader on Vulkan
+/// device.
+/// Before the run, user must provide and set resource data with descriptors,
+/// spir-v shader, number of work groups and entry point. After the creation of
+/// VulkanRuntime, special methods must be called in the following
+/// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
+/// each method in the sequence returns success or failure depending on the
+/// Vulkan result code. All Vulkan handles start out as VK_NULL_HANDLE, so
+/// destroy() may also be called after a run() that failed part-way.
+class VulkanRuntime {
+public:
+  explicit VulkanRuntime() = default;
+  VulkanRuntime(const VulkanRuntime &) = delete;
+  VulkanRuntime &operator=(const VulkanRuntime &) = delete;
+
+  /// Sets needed data for Vulkan runtime.
+  void setResourceData(const ResourceData &resData);
+  void setShaderModule(llvm::ArrayRef<uint32_t> binaryRef);
+  void setNumWorkGroups(const NumWorkGroups &nWorkGroups);
+  void setResourceStorageClassData(const ResourceStorageClassData &stClassData);
+  void setEntryPoint(llvm::StringRef entryPointName);
+
+  /// Runtime initialization.
+  LogicalResult initRuntime();
+
+  /// Runs runtime.
+  LogicalResult run();
+
+  /// Updates host memory buffers.
+  LogicalResult updateHostMemoryBuffers();
+
+  /// Destroys all created vulkan objects and resources.
+  LogicalResult destroy();
+
+private:
+  //===--------------------------------------------------------------------===//
+  // Pipeline creation methods.
+  //===--------------------------------------------------------------------===//
+
+  LogicalResult createInstance();
+  LogicalResult createDevice();
+  LogicalResult getBestComputeQueue(const VkPhysicalDevice &physicalDevice);
+  LogicalResult createMemoryBuffers();
+  LogicalResult createShaderModule();
+  void initDescriptorSetLayoutBindingMap();
+  LogicalResult createDescriptorSetLayout();
+  LogicalResult createPipelineLayout();
+  LogicalResult createComputePipeline();
+  LogicalResult createDescriptorPool();
+  LogicalResult allocateDescriptorSets();
+  LogicalResult setWriteDescriptors();
+  LogicalResult createCommandPool();
+  LogicalResult createComputeCommandBuffer();
+  LogicalResult submitCommandBuffersToQueue();
+
+  //===--------------------------------------------------------------------===//
+  // Helper methods.
+  //===--------------------------------------------------------------------===//
+
+  /// Maps storage class to a descriptor type.
+  LogicalResult
+  mapStorageClassToDescriptorType(spirv::StorageClass storageClass,
+                                  VkDescriptorType &descriptorType);
+
+  /// Maps storage class to buffer usage flags.
+  LogicalResult
+  mapStorageClassToBufferUsageFlag(spirv::StorageClass storageClass,
+                                   VkBufferUsageFlagBits &bufferUsage);
+
+  LogicalResult countDeviceMemorySize();
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan objects.
+  //===--------------------------------------------------------------------===//
+
+  VkInstance instance{VK_NULL_HANDLE};
+  VkDevice device{VK_NULL_HANDLE};
+  VkQueue queue{VK_NULL_HANDLE};
+
+  /// Specifies VulkanDeviceMemoryBuffers divided into sets.
+  llvm::DenseMap<DescriptorSetIndex,
+                 llvm::SmallVector<VulkanDeviceMemoryBuffer, 1>>
+      deviceMemoryBufferMap;
+
+  /// Specifies shader module.
+  VkShaderModule shaderModule{VK_NULL_HANDLE};
+
+  /// Specifies layout bindings.
+  llvm::DenseMap<DescriptorSetIndex,
+                 llvm::SmallVector<VkDescriptorSetLayoutBinding, 1>>
+      descriptorSetLayoutBindingMap;
+
+  /// Specifies layouts of descriptor sets, indexed by descriptor set number.
+  llvm::SmallVector<VkDescriptorSetLayout, 1> descriptorSetLayouts;
+  VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
+
+  /// Specifies descriptor sets, parallel to descriptorSetLayouts.
+  llvm::SmallVector<VkDescriptorSet, 1> descriptorSets;
+
+  /// Specifies a pool of descriptor set info, each descriptor set must have
+  /// information such as type, index and amount of bindings.
+  llvm::SmallVector<DescriptorSetInfo, 1> descriptorSetInfoPool;
+  VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
+
+  /// Computation pipeline.
+  VkPipeline pipeline{VK_NULL_HANDLE};
+  VkCommandPool commandPool{VK_NULL_HANDLE};
+  llvm::SmallVector<VkCommandBuffer, 1> commandBuffers;
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan memory context.
+  //===--------------------------------------------------------------------===//
+
+  uint32_t queueFamilyIndex{0};
+  uint32_t memoryTypeIndex{VK_MAX_MEMORY_TYPES};
+  VkDeviceSize memorySize{0};
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan execution context.
+  //===--------------------------------------------------------------------===//
+
+  NumWorkGroups numWorkGroups;
+  std::string entryPoint;
+  llvm::SmallVector<uint32_t, 0> binary;
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan resource data and storage classes.
+  //===--------------------------------------------------------------------===//
+
+  ResourceData resourceData;
+  ResourceStorageClassData resourceStorageClassData;
+};
+
+void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &nWorkGroups) {
+  numWorkGroups = nWorkGroups;
+}
+
+void VulkanRuntime::setResourceStorageClassData(
+    const ResourceStorageClassData &stClassData) {
+  resourceStorageClassData = stClassData;
+}
+
+void VulkanRuntime::setEntryPoint(llvm::StringRef entryPointName) {
+  entryPoint = entryPointName;
+}
+
+void VulkanRuntime::setResourceData(const ResourceData &resData) {
+  resourceData = resData;
+}
+
+void VulkanRuntime::setShaderModule(llvm::ArrayRef<uint32_t> binaryRef) {
+  binary.assign(binaryRef.begin(), binaryRef.end());
+}
+
+LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
+    spirv::StorageClass storageClass, VkDescriptorType &descriptorType) {
+  switch (storageClass) {
+  case spirv::StorageClass::StorageBuffer:
+    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+    break;
+  case spirv::StorageClass::Uniform:
+    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+    break;
+  default:
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
+    spirv::StorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
+  switch (storageClass) {
+  case spirv::StorageClass::StorageBuffer:
+    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+    break;
+  case spirv::StorageClass::Uniform:
+    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
+    break;
+  default:
+    return failure();
+  }
+  return success();
+}
+
+/// Sums the byte sizes of all resources into memorySize. Fails if any
+/// resource has a zero size.
+LogicalResult VulkanRuntime::countDeviceMemorySize() {
+  // Recompute from scratch so that repeated calls do not accumulate.
+  memorySize = 0;
+  for (const auto &resourceDataMapPair : resourceData) {
+    const auto &resourceDataMap = resourceDataMapPair.second;
+    for (const auto &resourceDataBindingPair : resourceDataMap) {
+      const auto bufferSize = resourceDataBindingPair.second.size;
+      if (!bufferSize) {
+        llvm::errs() << "expected buffer size greater than zero for resource "
+                        "with descriptor binding: "
+                     << resourceDataBindingPair.first
+                     << " in the descriptor set: " << resourceDataMapPair.first;
+        return failure();
+      }
+      memorySize += bufferSize;
+    }
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::initRuntime() {
+  if (resourceData.empty()) {
+    llvm::errs() << "Vulkan runtime needs at least one resource";
+    return failure();
+  }
+  if (binary.empty()) {
+    llvm::errs() << "binary shader size must be greater than zero";
+    return failure();
+  }
+  if (failed(countDeviceMemorySize())) {
+    return failure();
+  }
+  return success();
+}
+
+/// Releases every Vulkan object this runtime created. Objects that were never
+/// created are still VK_NULL_HANDLE, which the vkDestroy*/vkFree* calls
+/// accept, so this is safe after a partially failed run().
+LogicalResult VulkanRuntime::destroy() {
+  VkResult waitResult = VK_SUCCESS;
+  if (device != VK_NULL_HANDLE) {
+    // Nothing may be destroyed while the device may still be using it, so
+    // wait first. The result is reported only after cleanup so that a failed
+    // wait (e.g. a lost device) does not leak every other object.
+    waitResult = vkDeviceWaitIdle(device);
+
+    // Destroying the pool would also free its command buffers; free them
+    // explicitly to keep ownership obvious.
+    if (commandPool != VK_NULL_HANDLE && !commandBuffers.empty())
+      vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
+                           commandBuffers.data());
+    vkDestroyCommandPool(device, commandPool, nullptr);
+
+    // The descriptor pool is created without
+    // VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT, so its sets must not
+    // be freed individually; destroying the pool releases all of them.
+    vkDestroyDescriptorPool(device, descriptorPool, nullptr);
+    vkDestroyPipeline(device, pipeline, nullptr);
+    vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
+    for (auto &descriptorSetLayout : descriptorSetLayouts) {
+      vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
+    }
+    vkDestroyShaderModule(device, shaderModule, nullptr);
+
+    // For each descriptor set.
+    for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
+      auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
+      // For each descriptor binding: destroy the buffer before releasing the
+      // memory bound to it.
+      for (auto &memoryBuffer : deviceMemoryBuffers) {
+        vkDestroyBuffer(device, memoryBuffer.buffer, nullptr);
+        vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
+      }
+    }
+
+    vkDestroyDevice(device, nullptr);
+  }
+  vkDestroyInstance(instance, nullptr);
+
+  RETURN_ON_VULKAN_ERROR(waitResult, "vkDeviceWaitIdle");
+  return success();
+}
+
+LogicalResult VulkanRuntime::run() {
+  // Create logical device, shader module and memory buffers.
+  if (failed(createInstance()) || failed(createDevice()) ||
+      failed(createMemoryBuffers()) || failed(createShaderModule())) {
+    return failure();
+  }
+
+  // Descriptor bindings divided into sets. Each descriptor binding
+  // must have a layout binding attached into a descriptor set layout.
+  // Each layout set must be bound into a pipeline layout.
+  initDescriptorSetLayoutBindingMap();
+  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
+      // Each descriptor set must be allocated from a descriptor pool.
+      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
+      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
+      // Create command buffer.
+      failed(createCommandPool()) || failed(createComputeCommandBuffer())) {
+    return failure();
+  }
+
+  // Get working queue.
+  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
+
+  // Submit command buffer into the queue.
+  if (failed(submitCommandBuffersToQueue())) {
+    return failure();
+  }
+
+  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createInstance() {
+  VkApplicationInfo applicationInfo = {};
+  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+  applicationInfo.pNext = nullptr;
+  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
+  applicationInfo.applicationVersion = 0;
+  applicationInfo.pEngineName = "mlir";
+  applicationInfo.engineVersion = 0;
+  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
+
+  VkInstanceCreateInfo instanceCreateInfo = {};
+  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+  instanceCreateInfo.pNext = nullptr;
+  instanceCreateInfo.flags = 0;
+  instanceCreateInfo.pApplicationInfo = &applicationInfo;
+  instanceCreateInfo.enabledLayerCount = 0;
+  instanceCreateInfo.ppEnabledLayerNames = nullptr;
+  instanceCreateInfo.enabledExtensionCount = 0;
+  instanceCreateInfo.ppEnabledExtensionNames = nullptr;
+
+  RETURN_ON_VULKAN_ERROR(
+      vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
+      "vkCreateInstance");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createDevice() {
+  uint32_t physicalDeviceCount = 0;
+  RETURN_ON_VULKAN_ERROR(
+      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
+      "vkEnumeratePhysicalDevices");
+
+  // Bail out before the second query if there is nothing to choose from.
+  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
+                         "physicalDeviceCount");
+
+  llvm::SmallVector<VkPhysicalDevice, 1> physicalDevices(physicalDeviceCount);
+  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
+                                                    &physicalDeviceCount,
+                                                    physicalDevices.data()),
+                         "vkEnumeratePhysicalDevices");
+
+  // TODO(denis0x0D): find the best device.
+  const auto &physicalDevice = physicalDevices.front();
+  // queueFamilyIndex is only meaningful if a compute-capable family exists.
+  if (failed(getBestComputeQueue(physicalDevice)))
+    return failure();
+
+  const float queuePriority = 1.0f;
+  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
+  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+  deviceQueueCreateInfo.pNext = nullptr;
+  deviceQueueCreateInfo.flags = 0;
+  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
+  deviceQueueCreateInfo.queueCount = 1;
+  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
+
+  // Structure specifying parameters of a newly created device.
+  VkDeviceCreateInfo deviceCreateInfo = {};
+  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+  deviceCreateInfo.pNext = nullptr;
+  deviceCreateInfo.flags = 0;
+  deviceCreateInfo.queueCreateInfoCount = 1;
+  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
+  deviceCreateInfo.enabledLayerCount = 0;
+  deviceCreateInfo.ppEnabledLayerNames = nullptr;
+  deviceCreateInfo.enabledExtensionCount = 0;
+  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
+  deviceCreateInfo.pEnabledFeatures = nullptr;
+
+  RETURN_ON_VULKAN_ERROR(
+      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
+      "vkCreateDevice");
+
+  VkPhysicalDeviceMemoryProperties properties = {};
+  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
+
+  // Try to find memory type with following properties:
+  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
+  // with this type can be mapped for host access using vkMapMemory;
+  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
+  // management commands vkFlushMappedMemoryRanges and
+  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
+  // device or make device writes visible to the host, respectively.
+  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
+    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
+         properties.memoryTypes[i].propertyFlags) &&
+        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
+         properties.memoryTypes[i].propertyFlags) &&
+        (memorySize <=
+         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
+      memoryTypeIndex = i;
+      break;
+    }
+  }
+
+  RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE
+                                                                : VK_SUCCESS,
+                         "invalid memoryTypeIndex");
+  return success();
+}
+
+/// Selects queueFamilyIndex: prefers a family with compute but without
+/// graphics support, otherwise any compute-capable family.
+LogicalResult
+VulkanRuntime::getBestComputeQueue(const VkPhysicalDevice &physicalDevice) {
+  uint32_t queueFamilyPropertiesCount = 0;
+  vkGetPhysicalDeviceQueueFamilyProperties(
+      physicalDevice, &queueFamilyPropertiesCount, nullptr);
+  SmallVector<VkQueueFamilyProperties, 1> queueFamilyProperties(
+      queueFamilyPropertiesCount);
+
+  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
+                                           &queueFamilyPropertiesCount,
+                                           queueFamilyProperties.data());
+
+  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
+  // compute operations.
+  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
+    const VkQueueFlags maskedFlags =
+        (~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
+         queueFamilyProperties[i].queueFlags);
+
+    if (!(VK_QUEUE_GRAPHICS_BIT & maskedFlags) &&
+        (VK_QUEUE_COMPUTE_BIT & maskedFlags)) {
+      queueFamilyIndex = i;
+      return success();
+    }
+  }
+
+  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
+    const VkQueueFlags maskedFlags =
+        (~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
+         queueFamilyProperties[i].queueFlags);
+
+    if (VK_QUEUE_COMPUTE_BIT & maskedFlags) {
+      queueFamilyIndex = i;
+      return success();
+    }
+  }
+
+  llvm::errs() << "cannot find valid queue";
+  return failure();
+}
+
+/// Creates one device buffer per resource, backs it with host-visible memory
+/// and copies the host data into it.
+LogicalResult VulkanRuntime::createMemoryBuffers() {
+  // For each descriptor set.
+  for (const auto &resourceDataMapPair : resourceData) {
+    const auto descriptorSetIndex = resourceDataMapPair.first;
+    const auto &resourceDataMap = resourceDataMapPair.second;
+    // Buffers are recorded in the member map as soon as they are created so
+    // that destroy() can release them even if a later step fails.
+    auto &deviceMemoryBuffers = deviceMemoryBufferMap[descriptorSetIndex];
+
+    // For each descriptor binding.
+    for (const auto &resourceDataBindingPair : resourceDataMap) {
+      const BindingIndex bindingIndex = resourceDataBindingPair.first;
+      const VulkanHostMemoryBuffer &hostMemoryBuffer =
+          resourceDataBindingPair.second;
+
+      // Check that descriptor set has storage class map.
+      const auto resourceStorageClassMapIt =
+          resourceStorageClassData.find(descriptorSetIndex);
+      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
+        llvm::errs()
+            << "cannot find storage class for resource in descriptor set: "
+            << descriptorSetIndex;
+        return failure();
+      }
+
+      // Check that specific descriptor binding has storage class.
+      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
+      const auto resourceStorageClassIt =
+          resourceStorageClassMap.find(bindingIndex);
+      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
+        llvm::errs()
+            << "cannot find storage class for resource with descriptor index: "
+            << bindingIndex;
+        return failure();
+      }
+
+      VkDescriptorType descriptorType = {};
+      VkBufferUsageFlagBits bufferUsage = {};
+      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
+      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
+                                                 descriptorType)) ||
+          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
+                                                  bufferUsage))) {
+        llvm::errs() << "storage class for resource with descriptor binding: "
+                     << bindingIndex
+                     << " in the descriptor set: " << descriptorSetIndex
+                     << " is not supported ";
+        return failure();
+      }
+
+      // Create the device memory buffer entry for this binding.
+      deviceMemoryBuffers.emplace_back();
+      VulkanDeviceMemoryBuffer &memoryBuffer = deviceMemoryBuffers.back();
+      memoryBuffer.bindingIndex = bindingIndex;
+      memoryBuffer.descriptorType = descriptorType;
+      const auto bufferSize = hostMemoryBuffer.size;
+
+      // Create the buffer first: its memory requirements determine how much
+      // memory has to be allocated and which memory types may back it.
+      VkBufferCreateInfo bufferCreateInfo = {};
+      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+      bufferCreateInfo.pNext = nullptr;
+      bufferCreateInfo.flags = 0;
+      bufferCreateInfo.size = bufferSize;
+      bufferCreateInfo.usage = bufferUsage;
+      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+      bufferCreateInfo.queueFamilyIndexCount = 1;
+      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
+      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
+                                            &memoryBuffer.buffer),
+                             "vkCreateBuffer");
+
+      VkMemoryRequirements memoryRequirements = {};
+      vkGetBufferMemoryRequirements(device, memoryBuffer.buffer,
+                                    &memoryRequirements);
+      if (!(memoryRequirements.memoryTypeBits & (1u << memoryTypeIndex))) {
+        llvm::errs() << "memory type " << memoryTypeIndex
+                     << " cannot back the buffer for descriptor binding: "
+                     << bindingIndex
+                     << " in the descriptor set: " << descriptorSetIndex;
+        return failure();
+      }
+
+      // Specify memory allocation info. The implementation may need more
+      // memory than the payload itself, so allocate the size it reports.
+      VkMemoryAllocateInfo memoryAllocateInfo = {};
+      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      memoryAllocateInfo.pNext = nullptr;
+      memoryAllocateInfo.allocationSize = memoryRequirements.size;
+      memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex;
+
+      // Allocate device memory.
+      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
+                                              nullptr,
+                                              &memoryBuffer.deviceMemory),
+                             "vkAllocateMemory");
+      void *payload = nullptr;
+      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0,
+                                         bufferSize, 0, &payload),
+                             "vkMapMemory");
+
+      // Copy host memory into the mapped area.
+      std::memcpy(payload, hostMemoryBuffer.ptr, bufferSize);
+      vkUnmapMemory(device, memoryBuffer.deviceMemory);
+
+      // Bind buffer and device memory.
+      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer,
+                                                memoryBuffer.deviceMemory, 0),
+                             "vkBindBufferMemory");
+
+      // Update buffer info.
+      memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer;
+      memoryBuffer.bufferInfo.offset = 0;
+      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
+    }
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::createShaderModule() {
+  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
+  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+  shaderModuleCreateInfo.pNext = nullptr;
+  shaderModuleCreateInfo.flags = 0;
+  // Set size in bytes.
+  shaderModuleCreateInfo.codeSize = binary.size() * sizeof(uint32_t);
+  // Set pointer to the binary shader.
+  shaderModuleCreateInfo.pCode = binary.data();
+  RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
+                                              nullptr, &shaderModule),
+                         "vkCreateShaderModule");
+  return success();
+}
+
+void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
+  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
+    SmallVector<VkDescriptorSetLayoutBinding, 1> descriptorSetLayoutBindings;
+    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
+    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
+
+    // Create a layout binding for each descriptor.
+    for (const auto &memBuffer : deviceMemoryBuffers) {
+      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
+      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
+      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
+      descriptorSetLayoutBinding.descriptorCount = 1;
+      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+      descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
+      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
+    }
+    descriptorSetLayoutBindingMap[descriptorSetIndex] =
+        descriptorSetLayoutBindings;
+  }
+}
+
+/// Creates one descriptor set layout per descriptor set, in ascending set
+/// number order, and records the matching DescriptorSetInfo.
+LogicalResult VulkanRuntime::createDescriptorSetLayout() {
+  // The pipeline layout uses descriptorSetLayouts[i] for set number i and
+  // vkCmdBindDescriptorSets binds descriptorSets[i] to set number i, so the
+  // layouts must be created in ascending set order without gaps. DenseMap
+  // iteration order is unspecified, hence the explicit sort.
+  llvm::SmallVector<DescriptorSetIndex, 4> descriptorSetIndices;
+  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap)
+    descriptorSetIndices.push_back(deviceMemoryBufferMapPair.first);
+  llvm::sort(descriptorSetIndices);
+
+  for (size_t position = 0, e = descriptorSetIndices.size(); position < e;
+       ++position) {
+    const DescriptorSetIndex descriptorSetIndex =
+        descriptorSetIndices[position];
+    if (descriptorSetIndex != position) {
+      llvm::errs() << "descriptor sets must be numbered contiguously from 0, "
+                      "cannot find the set with number: "
+                   << position;
+      return failure();
+    }
+
+    const auto &deviceMemoryBuffers = deviceMemoryBufferMap[descriptorSetIndex];
+    if (deviceMemoryBuffers.empty()) {
+      llvm::errs() << "descriptor set with number: " << descriptorSetIndex
+                   << " has no resources";
+      return failure();
+    }
+
+    // Each descriptor in a descriptor set must be the same type: the
+    // descriptor pool is sized with a single descriptor type per set.
+    const VkDescriptorType descriptorType =
+        deviceMemoryBuffers.front().descriptorType;
+    for (const auto &memoryBuffer : deviceMemoryBuffers) {
+      if (memoryBuffer.descriptorType != descriptorType) {
+        llvm::errs() << "descriptors in the set with number: "
+                     << descriptorSetIndex << " must have the same type";
+        return failure();
+      }
+    }
+    const uint32_t descriptorSize = deviceMemoryBuffers.size();
+    const auto descriptorSetLayoutBindingIt =
+        descriptorSetLayoutBindingMap.find(descriptorSetIndex);
+
+    if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
+      llvm::errs() << "cannot find layout bindings for the set with number: "
+                   << descriptorSetIndex;
+      return failure();
+    }
+
+    const auto &descriptorSetLayoutBindings =
+        descriptorSetLayoutBindingIt->second;
+    // Create descriptor set layout.
+    VkDescriptorSetLayout descriptorSetLayout = {};
+    VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
+
+    descriptorSetLayoutCreateInfo.sType =
+        VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
+    descriptorSetLayoutCreateInfo.pNext = nullptr;
+    descriptorSetLayoutCreateInfo.flags = 0;
+    // Amount of descriptor bindings in a layout set.
+    descriptorSetLayoutCreateInfo.bindingCount =
+        descriptorSetLayoutBindings.size();
+    descriptorSetLayoutCreateInfo.pBindings =
+        descriptorSetLayoutBindings.data();
+    RETURN_ON_VULKAN_ERROR(
+        vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo,
+                                    nullptr, &descriptorSetLayout),
+        "vkCreateDescriptorSetLayout");
+
+    descriptorSetLayouts.push_back(descriptorSetLayout);
+    descriptorSetInfoPool.push_back(
+        {descriptorSetIndex, descriptorSize, descriptorType});
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::createPipelineLayout() {
+  // Associate descriptor sets with a pipeline layout.
+  VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
+  pipelineLayoutCreateInfo.sType =
+      VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
+  pipelineLayoutCreateInfo.pNext = nullptr;
+  pipelineLayoutCreateInfo.flags = 0;
+  pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
+  pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
+  pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
+  pipelineLayoutCreateInfo.pPushConstantRanges = nullptr;
+  RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
+                                                &pipelineLayoutCreateInfo,
+                                                nullptr, &pipelineLayout),
+                         "vkCreatePipelineLayout");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createComputePipeline() {
+  VkPipelineShaderStageCreateInfo stageInfo = {};
+  stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
+  stageInfo.pNext = nullptr;
+  stageInfo.flags = 0;
+  stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
+  stageInfo.module = shaderModule;
+  // Set entry point.
+  stageInfo.pName = entryPoint.c_str();
+  stageInfo.pSpecializationInfo = nullptr;
+
+  VkComputePipelineCreateInfo computePipelineCreateInfo = {};
+  computePipelineCreateInfo.sType =
+      VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
+  computePipelineCreateInfo.pNext = nullptr;
+  computePipelineCreateInfo.flags = 0;
+  computePipelineCreateInfo.stage = stageInfo;
+  computePipelineCreateInfo.layout = pipelineLayout;
+  computePipelineCreateInfo.basePipelineHandle = VK_NULL_HANDLE;
+  computePipelineCreateInfo.basePipelineIndex = 0;
+  RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1,
+                                                  &computePipelineCreateInfo,
+                                                  nullptr, &pipeline),
+                         "vkCreateComputePipelines");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createDescriptorPool() {
+  llvm::SmallVector<VkDescriptorPoolSize, 1> descriptorPoolSizes;
+  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
+    // For each descriptor set populate descriptor pool size.
+    VkDescriptorPoolSize descriptorPoolSize = {};
+    descriptorPoolSize.type = descriptorSetInfo.descriptorType;
+    descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
+    descriptorPoolSizes.push_back(descriptorPoolSize);
+  }
+
+  VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
+  descriptorPoolCreateInfo.sType =
+      VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
+  descriptorPoolCreateInfo.pNext = nullptr;
+  descriptorPoolCreateInfo.flags = 0;
+  descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
+  descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
+  descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
+  RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
+                                                &descriptorPoolCreateInfo,
+                                                nullptr, &descriptorPool),
+                         "vkCreateDescriptorPool");
+  return success();
+}
+
+LogicalResult VulkanRuntime::allocateDescriptorSets() {
+  VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
+  // Size of descriptor sets and descriptor layout sets is the same.
+  descriptorSets.resize(descriptorSetLayouts.size());
+  descriptorSetAllocateInfo.sType =
+      VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
+  descriptorSetAllocateInfo.pNext = nullptr;
+  descriptorSetAllocateInfo.descriptorPool = descriptorPool;
+  descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
+  descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
+  RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
+                                                  &descriptorSetAllocateInfo,
+                                                  descriptorSets.data()),
+                         "vkAllocateDescriptorSets");
+  return success();
+}
+
+LogicalResult VulkanRuntime::setWriteDescriptors() {
+  if (descriptorSets.size() != descriptorSetInfoPool.size()) {
+    llvm::errs() << "Each descriptor set must have descriptor set information";
+    return failure();
+  }
+  // For each descriptor set.
+  auto descriptorSetIt = descriptorSets.begin();
+  // Each descriptor set is associated with descriptor set info.
+  for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
+    // For each device memory buffer in the descriptor set.
+    const auto &deviceMemoryBuffers =
+        deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
+    for (const auto &memoryBuffer : deviceMemoryBuffers) {
+      // Structure describing descriptor sets to write to.
+      VkWriteDescriptorSet wSet = {};
+      wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
+      wSet.pNext = nullptr;
+      // Descriptor set.
+      wSet.dstSet = *descriptorSetIt;
+      wSet.dstBinding = memoryBuffer.bindingIndex;
+      wSet.dstArrayElement = 0;
+      wSet.descriptorCount = 1;
+      wSet.descriptorType = memoryBuffer.descriptorType;
+      wSet.pImageInfo = nullptr;
+      wSet.pBufferInfo = &memoryBuffer.bufferInfo;
+      wSet.pTexelBufferView = nullptr;
+      vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
+    }
+    // Increment descriptor set iterator.
+    ++descriptorSetIt;
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::createCommandPool() {
+  VkCommandPoolCreateInfo commandPoolCreateInfo = {};
+  commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
+  commandPoolCreateInfo.pNext = nullptr;
+  commandPoolCreateInfo.flags = 0;
+  commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
+  RETURN_ON_VULKAN_ERROR(vkCreateCommandPool(device, &commandPoolCreateInfo,
+                                             nullptr, &commandPool),
+                         "vkCreateCommandPool");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createComputeCommandBuffer() {
+  VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
+  VkCommandBuffer commandBuffer = VK_NULL_HANDLE;
+  commandBufferAllocateInfo.sType =
+      VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
+  commandBufferAllocateInfo.pNext = nullptr;
+  commandBufferAllocateInfo.commandPool = commandPool;
+  commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
+  commandBufferAllocateInfo.commandBufferCount = 1;
+  RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
+                                                  &commandBufferAllocateInfo,
+                                                  &commandBuffer),
+                         "vkAllocateCommandBuffers");
+
+  VkCommandBufferBeginInfo commandBufferBeginInfo = {};
+  commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
+  commandBufferBeginInfo.pNext = nullptr;
+  commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
+  commandBufferBeginInfo.pInheritanceInfo = nullptr;
+
+  // Commands begin.
+  RETURN_ON_VULKAN_ERROR(
+      vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
+      "vkBeginCommandBuffer");
+
+  // Commands.
+  vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
+  vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                          pipelineLayout, 0, descriptorSets.size(),
+                          descriptorSets.data(), 0, nullptr);
+  vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
+                numWorkGroups.z);
+
+  // Commands end.
+  RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
+                         "vkEndCommandBuffer");
+
+  commandBuffers.push_back(commandBuffer);
+  return success();
+}
+
+LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
+  VkSubmitInfo submitInfo = {};
+  submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
+  submitInfo.pNext = nullptr;
+  submitInfo.waitSemaphoreCount = 0;
+  submitInfo.pWaitSemaphores = nullptr;
+  submitInfo.pWaitDstStageMask = nullptr;
+  submitInfo.commandBufferCount = commandBuffers.size();
+  submitInfo.pCommandBuffers = commandBuffers.data();
+  submitInfo.signalSemaphoreCount = 0;
+  submitInfo.pSignalSemaphores = nullptr;
+  RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, VK_NULL_HANDLE),
+                         "vkQueueSubmit");
+  return success();
+}
+
+LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
+  // For each descriptor set.
+  for (auto &resourceDataMapPair : resourceData) {
+    auto &resourceDataMap = resourceDataMapPair.second;
+    auto &deviceMemoryBuffers =
+        deviceMemoryBufferMap[resourceDataMapPair.first];
+    // For each device memory buffer in the set.
+ for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { + if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { + void *payload; + auto &hostMemoryBuffer = + resourceDataMap[deviceMemoryBuffer.bindingIndex]; + RETURN_ON_VULKAN_ERROR(vkMapMemory(device, + deviceMemoryBuffer.deviceMemory, 0, + hostMemoryBuffer.size, 0, + reinterpret_cast(&payload)), + "vkMapMemory"); + std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); + vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory); + } + } + } + return success(); +} +} // namespace + +static LogicalResult runOnVulkan(spirv::ModuleOp module, + ResourceData &resourceData, + const NumWorkGroups &numWorkGroups) { + SPIRVModuleInfoCollector moduleHandler; + moduleHandler.processModule(module); + + SmallVector binary; + if (failed(spirv::serialize(module, binary))) { + llvm::errs() << "cannot serialize module" << '\n'; + return failure(); + } + + VulkanRuntime runtime; + runtime.setEntryPoint(moduleHandler.getEntryPoint()); + runtime.setNumWorkGroups(numWorkGroups); + runtime.setResourceData(resourceData); + runtime.setShaderModule(binary); + runtime.setResourceStorageClassData( + moduleHandler.getResourceStorageClassData()); + + if (failed(runtime.initRuntime()) || failed(runtime.run()) || + failed(runtime.updateHostMemoryBuffers()) || failed(runtime.destroy())) { + return failure(); + } + return success(); +} + +/// The purpose of the function is to run spirv::ModuleOp on Vulkan device. +/// +/// This function: +/// 1. Consumes mlir::ModuleOp, ResouceData and NumWorkGroups; +/// 2. Verifies mlir::ModuleOp and gets spirv::ModuleOp from mlir::ModuleOp; +/// 3. Collects entry point and storage classes of resource data; +/// 4. Serializes spirv::ModuleOp into the binary form; +/// 5. Creates Vulkan runtime and runs shader on Vulkan device; +/// 6. 
Updates resource data after the run; +LogicalResult runOnVulkan(mlir::ModuleOp module, ResourceData &resourceData, + const NumWorkGroups &numWorkGroups) { + if (failed(module.verify())) { + return failure(); + } + + auto result = failure(); + bool done = false; + for (auto spirvModule : module.getOps()) { + if (done) { + spirvModule.emitError("found more than one spv.module"); + } + done = true; + result = runOnVulkan(spirvModule, resourceData, numWorkGroups); + } + + if (failed(result)) { + return failure(); + } + return success(); +} Index: mlir/unittests/Dialect/SPIRV/CMakeLists.txt =================================================================== --- mlir/unittests/Dialect/SPIRV/CMakeLists.txt +++ mlir/unittests/Dialect/SPIRV/CMakeLists.txt @@ -9,3 +9,5 @@ whole_archive_link(MLIRSPIRVTests MLIRSPIRV) +add_subdirectory(Runtime) + Index: mlir/unittests/Dialect/SPIRV/Runtime/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/unittests/Dialect/SPIRV/Runtime/CMakeLists.txt @@ -0,0 +1,11 @@ +if (MLIR_VULKAN_RUNNER_ENABLED) +add_mlir_unittest(MLIRSPIRVRuntimeTests + RuntimeTests.cpp +) +target_link_libraries(MLIRSPIRVRuntimeTests + PRIVATE + MLIRSPIRV + MLIRVulkanRuntime + ) +whole_archive_link(MLIRSPIRVRuntimeTests MLIRSPIRV) +endif() Index: mlir/unittests/Dialect/SPIRV/Runtime/RuntimeTests.cpp =================================================================== --- /dev/null +++ mlir/unittests/Dialect/SPIRV/Runtime/RuntimeTests.cpp @@ -0,0 +1,233 @@ +//===- RuntimeTest.cpp - SPIR-V Runtime Tests -----------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file tests the Vulkan runtime API, which takes a spirv::ModuleOp, a set
+// of resources and the number of work groups.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+#include "gmock/gmock.h"
+
+#include <cmath>
+#include <memory>
+#include <random>
+
+using namespace mlir;
+using namespace llvm;
+
+using DescriptorSetIndex = uint32_t;
+using BindingIndex = uint32_t;
+
+// Struct containing information regarding a host memory buffer.
+struct VulkanHostMemoryBuffer {
+  void *ptr{nullptr};
+  uint64_t size{0};
+};
+
+// Struct containing the number of local workgroups to dispatch for each
+// dimension.
+struct NumWorkGroups {
+  uint32_t x{1};
+  uint32_t y{1};
+  uint32_t z{1};
+};
+
+// This is a temporary function and will be removed in the future.
+// See the full description in tools/mlir-vulkan-runner/VulkanRuntime.cpp
+extern mlir::LogicalResult runOnVulkan(
+    mlir::ModuleOp,
+    llvm::DenseMap<DescriptorSetIndex,
+                   llvm::DenseMap<BindingIndex, VulkanHostMemoryBuffer>> &,
+    const NumWorkGroups &);
+
+class RuntimeTest : public ::testing::Test {
+protected:
+  /// Parses `sourceFile` as an MLIR module and runs it on the Vulkan device
+  /// with the resources currently registered in `vars`.
+  LogicalResult parseAndRunModule(llvm::StringRef sourceFile,
+                                  NumWorkGroups numWorkGroups) {
+    SourceMgr sourceMgr;
+    sourceMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(sourceFile),
+                                 SMLoc());
+
+    MLIRContext context;
+    OwningModuleRef moduleRef(parseSourceFile(sourceMgr, &context));
+    if (!moduleRef) {
+      llvm::errs() << "\ncannot parse the file as a MLIR module" << '\n';
+      return failure();
+    }
+
+    return runOnVulkan(moduleRef.get(), vars, numWorkGroups);
+  }
+
+  /// Allocates `elementCount` floats with uniformly random values between 0
+  /// and 10, and registers them as the resource at (descriptorSet, binding)
+  /// in `vars`. Returns the owning pointer; it must outlive any use of the
+  /// registered buffer.
+  std::unique_ptr<float[]> createResourceVarFloat(uint32_t descriptorSet,
+                                                  uint32_t binding,
+                                                  uint32_t elementCount) {
+    std::unique_ptr<float[]> ptr(new float[elementCount]);
+    std::mt19937 gen(randomDevice());
+    std::uniform_real_distribution<> distribution(0.0, 10.0);
+    for (uint32_t i = 0; i < elementCount; ++i)
+      ptr[i] = static_cast<float>(distribution(gen));
+
+    VulkanHostMemoryBuffer hostMemoryBuffer;
+    hostMemoryBuffer.ptr = ptr.get();
+    hostMemoryBuffer.size = sizeof(float) * elementCount;
+    vars[descriptorSet][binding] = hostMemoryBuffer;
+    return ptr;
+  }
+
+  /// Releases a buffer returned by FMul/FAdd and clears the descriptor.
+  void destroyResourceVarFloat(VulkanHostMemoryBuffer &hostMemoryBuffer) {
+    // The storage was allocated with new float[], so it must be array-deleted.
+    delete[] static_cast<float *>(hostMemoryBuffer.ptr);
+    hostMemoryBuffer.ptr = nullptr;
+    hostMemoryBuffer.size = 0;
+  }
+
+  /// Returns a newly allocated buffer holding op(lhs[i], rhs[i]) for every
+  /// element of `lhsBuffer`. Release it with destroyResourceVarFloat.
+  template <typename BinaryOp>
+  VulkanHostMemoryBuffer
+  elementwiseFloat(const VulkanHostMemoryBuffer &lhsBuffer,
+                   const VulkanHostMemoryBuffer &rhsBuffer, BinaryOp op) {
+    const uint64_t count = lhsBuffer.size / sizeof(float);
+    float *result = new float[count];
+    const float *lhs = static_cast<const float *>(lhsBuffer.ptr);
+    const float *rhs = static_cast<const float *>(rhsBuffer.ptr);
+    for (uint64_t i = 0; i < count; ++i)
+      result[i] = op(lhs[i], rhs[i]);
+
+    VulkanHostMemoryBuffer resultHostMemoryBuffer;
+    resultHostMemoryBuffer.ptr = static_cast<void *>(result);
+    resultHostMemoryBuffer.size = count * sizeof(float);
+    return resultHostMemoryBuffer;
+  }
+
+  /// Element-wise product; the result must be freed with
+  /// destroyResourceVarFloat.
+  VulkanHostMemoryBuffer FMul(VulkanHostMemoryBuffer &var1,
+                              VulkanHostMemoryBuffer &var2) {
+    return elementwiseFloat(var1, var2,
+                            [](float lhs, float rhs) { return lhs * rhs; });
+  }
+
+  /// Element-wise sum; the result must be freed with destroyResourceVarFloat.
+  VulkanHostMemoryBuffer FAdd(VulkanHostMemoryBuffer &var1,
+                              VulkanHostMemoryBuffer &var2) {
+    return elementwiseFloat(var1, var2,
+                            [](float lhs, float rhs) { return lhs + rhs; });
+  }
+
+  /// Returns true if both buffers have the same byte size and every float
+  /// element differs by at most 1e-4.
+  bool isEqualFloat(const VulkanHostMemoryBuffer &hostMemoryBuffer1,
+                    const VulkanHostMemoryBuffer &hostMemoryBuffer2) {
+    if (hostMemoryBuffer1.size != hostMemoryBuffer2.size)
+      return false;
+
+    const uint64_t count = hostMemoryBuffer1.size / sizeof(float);
+    const float *lhs = static_cast<const float *>(hostMemoryBuffer1.ptr);
+    const float *rhs = static_cast<const float *>(hostMemoryBuffer2.ptr);
+    const float epsilon = 0.0001f;
+    for (uint64_t i = 0; i < count; ++i) {
+      if (std::fabs(lhs[i] - rhs[i]) > epsilon)
+        return false;
+    }
+    return true;
+  }
+
+protected:
+  /// Host resources passed to runOnVulkan, keyed by descriptor set and then
+  /// by binding.
+  llvm::DenseMap<DescriptorSetIndex,
+                 llvm::DenseMap<BindingIndex, VulkanHostMemoryBuffer>>
+      vars;
+  std::random_device randomDevice;
+};
+
+TEST_F(RuntimeTest, SimpleTest) {
+  // SPIR-V module embedded into the string. It contains 4 resource variables
+  // divided into 2 sets and computes var3 = var0 * var1 + var2 element-wise,
+  // one element per invocation. The array length (1024) must match the
+  // element counts and the dispatch size below.
+  std::string spirvModuleSource = R"mlir(
+spv.module "Logical" "GLSL450" {
+  spv.globalVariable @var3 bind(1, 1) : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+  spv.globalVariable @var2 bind(1, 0) : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+  spv.globalVariable @var1 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+  spv.globalVariable @var0 bind(0, 0) : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+  spv.globalVariable @globalInvocationID built_in("GlobalInvocationId"): !spv.ptr<vector<3xi32>, Input>
+  func @kernel() -> () {
+    %c0 = spv.constant 0 : i32
+
+    %0 = spv._address_of @var0 : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %1 = spv._address_of @var1 : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %2 = spv._address_of @var2 : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %3 = spv._address_of @var3 : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+
+    %ptr_id = spv._address_of @globalInvocationID: !spv.ptr<vector<3xi32>, Input>
+
+    %id = spv.AccessChain %ptr_id[%c0] : !spv.ptr<vector<3xi32>, Input>
+    %index = spv.Load "Input" %id: i32
+
+    %4 = spv.AccessChain %0[%c0, %index] : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %5 = spv.AccessChain %1[%c0, %index] : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %6 = spv.AccessChain %2[%c0, %index] : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+    %7 = spv.AccessChain %3[%c0, %index] : !spv.ptr<!spv.struct<!spv.array<1024 x f32 [4]> [0]>, StorageBuffer>
+
+    %8 = spv.Load "StorageBuffer" %4 : f32
+    %9 = spv.Load "StorageBuffer" %5 : f32
+    %10 = spv.Load "StorageBuffer" %6 : f32
+
+    %11 = spv.FMul %8, %9 : f32
+    %12 = spv.FAdd %11, %10 : f32
+
+    spv.Store "StorageBuffer" %7, %12 : f32
+    spv.Return
+  }
+  spv.EntryPoint "GLCompute" @kernel, @globalInvocationID
+  spv.ExecutionMode @kernel "LocalSize", 1, 1, 1
+} attributes {
+  capabilities = ["Shader"],
+  extensions = ["SPV_KHR_storage_buffer_storage_class"]
+}
+)mlir";
+
+  auto resOne = createResourceVarFloat(0, 0, 1024);
+  auto resTwo = createResourceVarFloat(0, 1, 1024);
+  auto resThree = createResourceVarFloat(1, 0, 1024);
+  auto resFour = createResourceVarFloat(1, 1, 1024);
+
+  auto fmulResult = FMul(vars[0][0], vars[0][1]);
+  auto expected = FAdd(vars[1][0], fmulResult);
+
+  NumWorkGroups numWorkGroups;
+  numWorkGroups.x = 1024;
+  // EXPECT_* (not ASSERT_*) so that the cleanup below always runs.
+  const bool ran =
+      succeeded(parseAndRunModule(spirvModuleSource, numWorkGroups));
+  EXPECT_TRUE(ran);
+  if (ran)
+    EXPECT_TRUE(isEqualFloat(expected, vars[1][1]));
+
+  destroyResourceVarFloat(fmulResult);
+  destroyResourceVarFloat(expected);
+}