diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -48,6 +48,7 @@
 add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
 
 set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
+set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")
 
 include_directories( "include")
 include_directories( ${MLIR_INCLUDE_DIR})
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
@@ -27,7 +27,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
-#include "llvm/Support/FormatVariadic.h"
+#include "llvm/ADT/SmallString.h"
 
 using namespace mlir;
 
@@ -80,7 +80,7 @@
   /// populates the given `numWorkGroups`.
   LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
                                     mlir::gpu::LaunchFuncOp launchOp,
-                                    SmallVector<Value, 3> &numWorkGroups);
+                                    SmallVectorImpl<Value> &numWorkGroups);
 
   /// Declares all needed runtime functions.
   void declareVulkanFunctions(Location loc);
@@ -153,17 +153,15 @@
 
 Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
     StringRef name, Location loc, OpBuilder &builder) {
-  std::vector<char> shaderName(name.begin(), name.end());
+  SmallString<16> shaderName(name.begin(), name.end());
   // Append `\0` to follow C style string given that LLVM::createGlobalString()
   // won't handle this directly for us.
   shaderName.push_back('\0');
 
-  std::string entryPointGlobalName =
-      std::string(llvm::formatv("{0}_spv_entry_point_name", name));
-  return LLVM::createGlobalString(
-      loc, builder, entryPointGlobalName,
-      StringRef(shaderName.data(), shaderName.size()), LLVM::Linkage::Internal,
-      getLLVMDialect());
+  std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
+  return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
+                                  shaderName, LLVM::Linkage::Internal,
+                                  getLLVMDialect());
 }
 
 LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
@@ -171,14 +169,12 @@
   bool done = false;
   SmallVector<uint32_t, 0> binary;
   for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
-    if (done) {
-      spirvModule.emitError("should only contain one 'spv.module' op");
-      return failure();
-    }
+    if (done)
+      return spirvModule.emitError("should only contain one 'spv.module' op");
     done = true;
-    if (failed(spirv::serialize(spirvModule, binary))) {
+
+    if (failed(spirv::serialize(spirvModule, binary)))
       return failure();
-    }
   }
 
   binaryShader.resize(binary.size() * sizeof(uint32_t));
@@ -189,14 +185,13 @@
 
 LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
     Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
-    SmallVector<Value, 3> &numWorkGroups) {
+    SmallVectorImpl<Value> &numWorkGroups) {
   for (auto index : llvm::seq(0, 3)) {
     auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
         launchOp.getOperand(index).getDefiningOp());
 
-    if (!numWorkGroupDimConstant) {
+    if (!numWorkGroupDimConstant)
       return failure();
-    }
 
     auto numWorkGroupDimValue =
         numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
@@ -207,7 +202,6 @@
   return success();
 }
 
-// Translates gpu launch op to the sequence of Vulkan runtime calls.
 void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
     mlir::gpu::LaunchFuncOp launchOp) {
   ModuleOp module = getModule();
@@ -217,9 +211,8 @@
   // Serialize `spirv::Module` into binary form.
   std::vector<char> binary;
   if (failed(
-          GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary))) {
+          GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary)))
     return signalPassFailure();
-  }
 
   // Create LLVM global with SPIR-V binary data, so we can pass a pointer with
   // that data to runtime call.
@@ -246,9 +239,8 @@
 
   // Create number of local workgroup for each dimension.
   SmallVector<Value, 3> numWorkGroups;
-  if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups))) {
+  if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups)))
     return signalPassFailure();
-  }
 
   // Create call `setNumWorkGroups` runtime function with the given numbers of
   // local workgroup.
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -15,6 +15,11 @@
 # Passed to lit.site.cfg.py.in to set up the path where to find the libraries
 # for the mlir cuda runner tests.
 set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
+
+# lit.site.cfg.py.in substitutes this flag directly into Python code, so it
+# must be 0/1 rather than a CMake boolean spelling such as ON/OFF.
+llvm_canonicalize_cmake_booleans(MLIR_VULKAN_RUNNER_ENABLED)
 
 configure_lit_site_cfg(
   ${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
@@ -61,6 +66,12 @@
   )
 endif()
 
+if(MLIR_VULKAN_RUNNER_ENABLED)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-vulkan-runner
+  )
+endif()
+
 add_lit_testsuite(check-mlir "Running the MLIR regression tests"
   ${CMAKE_CURRENT_BINARY_DIR}
   DEPENDS ${MLIR_TEST_DEPENDS}
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -67,7 +67,8 @@
     ToolSubst('toy-ch4', unresolved='ignore'),
     ToolSubst('toy-ch5', unresolved='ignore'),
     ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
-    ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore')
+    ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore'),
+    ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore')
 ])
 
 llvm_config.add_tool_substitutions(tools, tool_dirs)
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -36,6 +36,8 @@
 config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
 config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
 config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
+config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
+config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
 
 # Support substitution of the tools_dir with user parameters. This is
 # used when we can't determine the tool dir at configuration time.
diff --git a/mlir/test/mlir-vulkan-runner/addf.mlir b/mlir/test/mlir-vulkan-runner/addf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/addf.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
+
+// CHECK: [3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3]
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    gpu.func @kernel_add(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>)
+      attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
+      %0 = "gpu.block_id"() {dimension = "x"} : () -> index
+      %1 = load %arg0[%0] : memref<8xf32>
+      %2 = load %arg1[%0] : memref<8xf32>
+      %3 = addf %1, %2 : f32
+      store %3, %arg2[%0] : memref<8xf32>
+      gpu.return
+    }
+  }
+
+  func @main() {
+    %arg0 = alloc() : memref<8xf32>
+    %arg1 = alloc() : memref<8xf32>
+    %arg2 = alloc() : memref<8xf32>
+    %0 = constant 0 : i32
+    %1 = constant 1 : i32
+    %2 = constant 2 : i32
+    %value0 = constant 0.0 : f32
+    %value1 = constant 1.1 : f32
+    %value2 = constant 2.2 : f32
+    %arg3 = memref_cast %arg0 : memref<8xf32> to memref<?xf32>
+    %arg4 = memref_cast %arg1 : memref<8xf32> to memref<?xf32>
+    %arg5 = memref_cast %arg2 : memref<8xf32> to memref<?xf32>
+    call @setResourceData(%0, %0, %arg3, %value1) : (i32, i32, memref<?xf32>, f32) -> ()
+    call @setResourceData(%0, %1, %arg4, %value2) : (i32, i32, memref<?xf32>, f32) -> ()
+    call @setResourceData(%0, %2, %arg5, %value0) : (i32, i32, memref<?xf32>, f32) -> ()
+
+    %cst1 = constant 1 : index
+    %cst8 = constant 8 : index
+    "gpu.launch_func"(%cst8, %cst1, %cst1, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = "kernel_add", kernel_module = @kernels }
+        : (index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
+    %arg6 = memref_cast %arg5 : memref<?xf32> to memref<*xf32>
+    call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
+    return
+  }
+  func @setResourceData(%0 : i32, %1 : i32, %2 : memref<?xf32>, %4 : f32)
+  func @print_memref_f32(%ptr : memref<*xf32>)
+}
+
diff --git a/mlir/test/mlir-vulkan-runner/lit.local.cfg b/mlir/test/mlir-vulkan-runner/lit.local.cfg
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-vulkan-runner/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_vulkan_runner:
+  config.unsupported = True
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -3,3 +3,4 @@
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-tblgen)
 add_subdirectory(mlir-translate)
+add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt
@@ -0,0 +1,103 @@
+set(LLVM_OPTIONAL_SOURCES
+  mlir-vulkan-runner.cpp
+  vulkan-runtime-wrappers.cpp
+  VulkanRuntime.cpp
+  VulkanRuntime.h
+  )
+
+if (MLIR_VULKAN_RUNNER_ENABLED)
+  message(STATUS "Building the Vulkan runner")
+
+  # At first try "FindVulkan" from:
+  # https://cmake.org/cmake/help/v3.7/module/FindVulkan.html
+  if (NOT CMAKE_VERSION VERSION_LESS 3.7.0)
+    find_package(Vulkan)
+  endif()
+
+  # If Vulkan is not found try a path specified by VULKAN_SDK.
+  if (NOT Vulkan_FOUND)
+    if ("$ENV{VULKAN_SDK}" STREQUAL "")
+      message(FATAL_ERROR "Please use at least CMAKE 3.7.0 or provide VULKAN_SDK
+                          path as an environment variable")
+    endif()
+
+    # No REQUIRED here: find_library() only understands it since CMake 3.18,
+    # and a missing library is already reported as a fatal error below.
+    find_library(Vulkan_LIBRARY vulkan HINTS "$ENV{VULKAN_SDK}/lib")
+    if (Vulkan_LIBRARY)
+      set(Vulkan_FOUND ON)
+      set(Vulkan_INCLUDE_DIR "$ENV{VULKAN_SDK}/include")
+      message(STATUS "Found Vulkan: " ${Vulkan_LIBRARY})
+    endif()
+  endif()
+
+  if (NOT Vulkan_FOUND)
+    message(FATAL_ERROR "Cannot find Vulkan library")
+  endif()
+
+  add_llvm_library(vulkan-runtime-wrappers SHARED
+    vulkan-runtime-wrappers.cpp
+    VulkanRuntime.cpp
+  )
+
+  target_include_directories(vulkan-runtime-wrappers
+    PRIVATE ${Vulkan_INCLUDE_DIR}
+  )
+
+  target_link_libraries(vulkan-runtime-wrappers
+    MLIRSPIRVSerialization
+    LLVMCore
+    LLVMSupport
+    ${Vulkan_LIBRARY}
+  )
+
+  set(LIBS
+    LLVMCore
+    LLVMSupport
+    MLIRJitRunner
+    MLIRAffineOps
+    MLIRAnalysis
+    MLIREDSC
+    MLIRExecutionEngine
+    MLIRFxpMathOps
+    MLIRGPU
+    MLIRGPUtoCUDATransforms
+    MLIRGPUtoNVVMTransforms
+    MLIRGPUtoSPIRVTransforms
+    MLIRGPUtoVulkanTransforms
+    MLIRIR
+    MLIRLLVMIR
+    MLIRLinalgOps
+    MLIRLoopToStandard
+    MLIROpenMP
+    MLIRParser
+    MLIRQuantOps
+    MLIRROCDLIR
+    MLIRSPIRV
+    MLIRSPIRVTransforms
+    MLIRStandardOps
+    MLIRStandardToLLVM
+    MLIRSupport
+    MLIRTargetLLVMIR
+    MLIRTransforms
+    MLIRTranslation
+    ${Vulkan_LIBRARY}
+  )
+
+  # Manually expand the target library, since our MLIR libraries
+  # aren't plugged into the LLVM dependency tracking. If we don't
+  # do this then we can't insert the CodeGen library after ourselves
+  llvm_expand_pseudo_components(TARGET_LIBS AllTargetsCodeGens)
+  # Prepend LLVM in front of every target, this is how the library
+  # are named with CMake
+  SET(targets_to_link)
+  FOREACH(t ${TARGET_LIBS})
+    LIST(APPEND targets_to_link "LLVM${t}")
+  ENDFOREACH(t)
+
+  add_llvm_tool(mlir-vulkan-runner
+    mlir-vulkan-runner.cpp
+  )
+  add_dependencies(mlir-vulkan-runner vulkan-runtime-wrappers)
+  llvm_update_compile_flags(mlir-vulkan-runner)
+  target_link_libraries(mlir-vulkan-runner PRIVATE ${FULL_LINK_LIBS} ${LIBS})
+
+endif()
diff --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h
@@ -0,0 +1,225 @@
+//===- VulkanRuntime.h - MLIR Vulkan runtime --------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares Vulkan runtime API.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef VULKAN_RUNTIME_H
+#define VULKAN_RUNTIME_H
+
+#include "mlir/Analysis/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Serialization.h"
+#include "mlir/IR/Module.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StringExtras.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+#include <vulkan/vulkan.h>
+
+using namespace mlir;
+
+using DescriptorSetIndex = uint32_t;
+using BindingIndex = uint32_t;
+
+/// Struct containing information regarding a device memory buffer.
+struct VulkanDeviceMemoryBuffer {
+  BindingIndex bindingIndex{0};
+  VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
+  VkDescriptorBufferInfo bufferInfo{};
+  VkBuffer buffer{VK_NULL_HANDLE};
+  VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
+};
+
+/// Struct containing information regarding a host memory buffer.
+struct VulkanHostMemoryBuffer {
+  /// Pointer to a host memory.
+  void *ptr{nullptr};
+  /// Size of a host memory in bytes.
+  uint32_t size{0};
+};
+
+/// Struct containing the number of local workgroups to dispatch for each
+/// dimension.
+struct NumWorkGroups {
+  uint32_t x{1};
+  uint32_t y{1};
+  uint32_t z{1};
+};
+
+/// Struct containing information regarding a descriptor set.
+struct DescriptorSetInfo {
+  /// Index of a descriptor set in descriptor sets.
+  DescriptorSetIndex descriptorSet{0};
+  /// Number of descriptors in a set.
+  uint32_t descriptorSize{0};
+  /// Type of a descriptor set.
+  VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
+};
+
+/// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
+using ResourceData =
+    llvm::DenseMap<DescriptorSetIndex,
+                   llvm::DenseMap<BindingIndex, VulkanHostMemoryBuffer>>;
+
+/// StorageClass mapped into a descriptor set and a binding.
+using ResourceStorageClassBindingMap =
+    llvm::DenseMap<DescriptorSetIndex,
+                   llvm::DenseMap<BindingIndex, mlir::spirv::StorageClass>>;
+
+/// Prints "<message> failed with error code <error>" to llvm::errs().
+inline void emitVulkanError(const llvm::Twine &message, VkResult error) {
+  llvm::errs()
+      << message.concat(" failed with error code ").concat(llvm::Twine{error});
+}
+
+/// Evaluates `result` exactly once. If it is not VK_SUCCESS, reports `msg`
+/// together with the error code and returns failure() from the enclosing
+/// function. The expression is stored in a local first so that a failing
+/// Vulkan call is never re-invoked just to print its error code. The
+/// do/while(false) wrapper makes the macro behave as a single statement.
+#define RETURN_ON_VULKAN_ERROR(result, msg)                                    \
+  do {                                                                         \
+    const VkResult returnOnVulkanErrorResult = (result);                       \
+    if (returnOnVulkanErrorResult != VK_SUCCESS) {                             \
+      emitVulkanError(msg, returnOnVulkanErrorResult);                         \
+      return failure();                                                        \
+    }                                                                          \
+  } while (false)
+
+/// Vulkan runtime.
+/// The purpose of this class is to run a SPIR-V compute shader on a Vulkan
+/// device.
+/// Before the run, user must provide and set resource data with descriptors,
+/// SPIR-V shader, number of work groups and entry point. After the creation of
+/// VulkanRuntime, special methods must be called in the following
+/// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
+/// each method in the sequence returns success or failure depending on the
+/// Vulkan result code.
+class VulkanRuntime {
+public:
+  explicit VulkanRuntime() = default;
+  VulkanRuntime(const VulkanRuntime &) = delete;
+  VulkanRuntime &operator=(const VulkanRuntime &) = delete;
+
+  /// Sets needed data for Vulkan runtime.
+  void setResourceData(const ResourceData &resData);
+  void setResourceData(const DescriptorSetIndex desIndex,
+                       const BindingIndex bindIndex,
+                       const VulkanHostMemoryBuffer &hostMemBuffer);
+  void setShaderModule(uint8_t *shader, uint32_t size);
+  void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
+  void setResourceStorageClassBindingMap(
+      const ResourceStorageClassBindingMap &stClassData);
+  void setEntryPoint(const char *entryPointName);
+
+  /// Runtime initialization.
+  LogicalResult initRuntime();
+
+  /// Runs runtime.
+  LogicalResult run();
+
+  /// Updates host memory buffers.
+  LogicalResult updateHostMemoryBuffers();
+
+  /// Destroys all created vulkan objects and resources.
+  LogicalResult destroy();
+
+private:
+  //===--------------------------------------------------------------------===//
+  // Pipeline creation methods.
+  //===--------------------------------------------------------------------===//
+
+  LogicalResult createInstance();
+  LogicalResult createDevice();
+  LogicalResult getBestComputeQueue(const VkPhysicalDevice &physicalDevice);
+  LogicalResult createMemoryBuffers();
+  LogicalResult createShaderModule();
+  void initDescriptorSetLayoutBindingMap();
+  LogicalResult createDescriptorSetLayout();
+  LogicalResult createPipelineLayout();
+  LogicalResult createComputePipeline();
+  LogicalResult createDescriptorPool();
+  LogicalResult allocateDescriptorSets();
+  LogicalResult setWriteDescriptors();
+  LogicalResult createCommandPool();
+  LogicalResult createComputeCommandBuffer();
+  LogicalResult submitCommandBuffersToQueue();
+
+  //===--------------------------------------------------------------------===//
+  // Helper methods.
+  //===--------------------------------------------------------------------===//
+
+  /// Maps storage class to a descriptor type.
+  LogicalResult
+  mapStorageClassToDescriptorType(spirv::StorageClass storageClass,
+                                  VkDescriptorType &descriptorType);
+
+  /// Maps storage class to buffer usage flags.
+  LogicalResult
+  mapStorageClassToBufferUsageFlag(spirv::StorageClass storageClass,
+                                   VkBufferUsageFlagBits &bufferUsage);
+
+  LogicalResult countDeviceMemorySize();
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan objects.
+  //
+  // Handles start out as VK_NULL_HANDLE so their state is well defined even
+  // if a creation step fails before they are assigned.
+  //===--------------------------------------------------------------------===//
+
+  VkInstance instance{VK_NULL_HANDLE};
+  VkDevice device{VK_NULL_HANDLE};
+  VkQueue queue{VK_NULL_HANDLE};
+
+  /// Specifies VulkanDeviceMemoryBuffers divided into sets.
+  llvm::DenseMap<DescriptorSetIndex,
+                 llvm::SmallVector<VulkanDeviceMemoryBuffer, 1>>
+      deviceMemoryBufferMap;
+
+  /// Specifies shader module.
+  VkShaderModule shaderModule{VK_NULL_HANDLE};
+
+  /// Specifies layout bindings.
+  llvm::DenseMap<DescriptorSetIndex,
+                 llvm::SmallVector<VkDescriptorSetLayoutBinding, 1>>
+      descriptorSetLayoutBindingMap;
+
+  /// Specifies layouts of descriptor sets.
+  llvm::SmallVector<VkDescriptorSetLayout, 1> descriptorSetLayouts;
+  VkPipelineLayout pipelineLayout{VK_NULL_HANDLE};
+
+  /// Specifies descriptor sets.
+  llvm::SmallVector<VkDescriptorSet, 1> descriptorSets;
+
+  /// Specifies a pool of descriptor set info, each descriptor set must have
+  /// information such as type, index and amount of bindings.
+  llvm::SmallVector<DescriptorSetInfo, 1> descriptorSetInfoPool;
+  VkDescriptorPool descriptorPool{VK_NULL_HANDLE};
+
+  /// Computation pipeline.
+  VkPipeline pipeline{VK_NULL_HANDLE};
+  VkCommandPool commandPool{VK_NULL_HANDLE};
+  llvm::SmallVector<VkCommandBuffer, 1> commandBuffers;
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan memory context.
+  //===--------------------------------------------------------------------===//
+
+  uint32_t queueFamilyIndex{0};
+  uint32_t memoryTypeIndex{VK_MAX_MEMORY_TYPES};
+  VkDeviceSize memorySize{0};
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan execution context.
+  //===--------------------------------------------------------------------===//
+
+  NumWorkGroups numWorkGroups;
+  const char *entryPoint{nullptr};
+  uint8_t *binary{nullptr};
+  uint32_t binarySize{0};
+
+  //===--------------------------------------------------------------------===//
+  // Vulkan resource data and storage classes.
+  //===--------------------------------------------------------------------===//
+
+  ResourceData resourceData;
+  ResourceStorageClassBindingMap resourceStorageClassData;
+};
+#endif // VULKAN_RUNTIME_H
diff --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.cpp
@@ -0,0 +1,717 @@
+//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a library for running a module on a Vulkan device.
+// Implements a Vulkan runtime.
+//
+//===----------------------------------------------------------------------===//
+
+#include "VulkanRuntime.h"
+
+using namespace mlir;
+
+void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
+  numWorkGroups = numberWorkGroups;
+}
+
+void VulkanRuntime::setResourceStorageClassBindingMap(
+    const ResourceStorageClassBindingMap &stClassData) {
+  resourceStorageClassData = stClassData;
+}
+
+void VulkanRuntime::setResourceData(
+    const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
+    const VulkanHostMemoryBuffer &hostMemBuffer) {
+  resourceData[desIndex][bindIndex] = hostMemBuffer;
+  resourceStorageClassData[desIndex][bindIndex] =
+      spirv::StorageClass::StorageBuffer;
+}
+
+void VulkanRuntime::setEntryPoint(const char *entryPointName) {
+  entryPoint = entryPointName;
+}
+
+void VulkanRuntime::setResourceData(const ResourceData &resData) {
+  resourceData = resData;
+}
+
+void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
+  binary = shader;
+  binarySize = size;
+}
+
+LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
+    spirv::StorageClass storageClass, VkDescriptorType &descriptorType) {
+  switch (storageClass) {
+  case spirv::StorageClass::StorageBuffer:
+    descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+    break;
+  case spirv::StorageClass::Uniform:
+    descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
+    break;
+  default:
+    llvm::errs() << "unsupported storage class";
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
+    spirv::StorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
+  switch (storageClass) {
+  case spirv::StorageClass::StorageBuffer:
+    bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
+    break;
+  case spirv::StorageClass::Uniform:
+    bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
+    break;
+  default:
+    llvm::errs() << "unsupported storage class";
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::countDeviceMemorySize() {
+  // Recompute from scratch so that a repeated initRuntime() call does not
+  // keep adding the same buffers to the total.
+  memorySize = 0;
+  for (const auto &resourceDataMapPair : resourceData) {
+    const auto &resourceDataMap = resourceDataMapPair.second;
+    for (const auto &resourceDataBindingPair : resourceDataMap) {
+      if (resourceDataBindingPair.second.size) {
+        memorySize += resourceDataBindingPair.second.size;
+      } else {
+        llvm::errs()
+            << "expected buffer size greater than zero for resource data";
+        return failure();
+      }
+    }
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::initRuntime() {
+  if (resourceData.empty()) {
+    llvm::errs() << "Vulkan runtime needs at least one resource";
+    return failure();
+  }
+  if (!binarySize || !binary) {
+    llvm::errs() << "binary shader size must be greater than zero";
+    return failure();
+  }
+  if (failed(countDeviceMemorySize())) {
+    return failure();
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::destroy() {
+  // According to Vulkan spec:
+  // "To ensure that no work is active on the device, vkDeviceWaitIdle can be
+  // used to gate the destruction of the device. Prior to destroying a device,
+  // an application is responsible for destroying/freeing any Vulkan objects
+  // that were created using that device as the first parameter of the
+  // corresponding vkCreate* or vkAllocate* command."
+  RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
+
+  // Free and destroy.
+  vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
+                       commandBuffers.data());
+  vkDestroyCommandPool(device, commandPool, nullptr);
+  vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
+                       descriptorSets.data());
+  vkDestroyDescriptorPool(device, descriptorPool, nullptr);
+  vkDestroyPipeline(device, pipeline, nullptr);
+  vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
+  for (auto &descriptorSetLayout : descriptorSetLayouts) {
+    vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
+  }
+  vkDestroyShaderModule(device, shaderModule, nullptr);
+
+  // For each descriptor set.
+  for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
+    auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
+    // For each descriptor binding. Destroy the buffer before releasing the
+    // memory bound to it, so no live buffer ever refers to freed memory.
+    for (auto &memoryBuffer : deviceMemoryBuffers) {
+      vkDestroyBuffer(device, memoryBuffer.buffer, nullptr);
+      vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
+    }
+  }
+
+  vkDestroyDevice(device, nullptr);
+  vkDestroyInstance(instance, nullptr);
+  return success();
+}
+
+LogicalResult VulkanRuntime::run() {
+  // Create logical device, shader module and memory buffers.
+  if (failed(createInstance()) || failed(createDevice()) ||
+      failed(createMemoryBuffers()) || failed(createShaderModule())) {
+    return failure();
+  }
+
+  // Descriptor bindings divided into sets. Each descriptor binding
+  // must have a layout binding attached into a descriptor set layout.
+  // Each layout set must be binded into a pipeline layout.
+  initDescriptorSetLayoutBindingMap();
+  if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
+      // Each descriptor set must be allocated from a descriptor pool.
+      failed(createComputePipeline()) || failed(createDescriptorPool()) ||
+      failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
+      // Create command buffer.
+      failed(createCommandPool()) || failed(createComputeCommandBuffer())) {
+    return failure();
+  }
+
+  // Get working queue.
+  vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
+
+  // Submit command buffer into the queue.
+  if (failed(submitCommandBuffersToQueue())) {
+    return failure();
+  }
+
+  RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createInstance() {
+  VkApplicationInfo applicationInfo = {};
+  applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
+  applicationInfo.pNext = nullptr;
+  applicationInfo.pApplicationName = "MLIR Vulkan runtime";
+  applicationInfo.applicationVersion = 0;
+  applicationInfo.pEngineName = "mlir";
+  applicationInfo.engineVersion = 0;
+  applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
+
+  VkInstanceCreateInfo instanceCreateInfo = {};
+  instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
+  instanceCreateInfo.pNext = nullptr;
+  instanceCreateInfo.flags = 0;
+  instanceCreateInfo.pApplicationInfo = &applicationInfo;
+  instanceCreateInfo.enabledLayerCount = 0;
+  instanceCreateInfo.ppEnabledLayerNames = nullptr;
+  instanceCreateInfo.enabledExtensionCount = 0;
+  instanceCreateInfo.ppEnabledExtensionNames = nullptr;
+
+  RETURN_ON_VULKAN_ERROR(
+      vkCreateInstance(&instanceCreateInfo, nullptr, &instance),
+      "vkCreateInstance");
+  return success();
+}
+
+LogicalResult VulkanRuntime::createDevice() {
+  uint32_t physicalDeviceCount = 0;
+  RETURN_ON_VULKAN_ERROR(
+      vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, nullptr),
+      "vkEnumeratePhysicalDevices");
+
+  llvm::SmallVector<VkPhysicalDevice, 1> physicalDevices(physicalDeviceCount);
+  RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
+                                                    &physicalDeviceCount,
+                                                    physicalDevices.data()),
+                         "vkEnumeratePhysicalDevices");
+
+  RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
+                         "physicalDeviceCount");
+
+  // TODO(denis0x0D): find the best device.
+  const auto &physicalDevice = physicalDevices.front();
+  // `queueFamilyIndex` is consumed below. Without a compute-capable queue
+  // family there is nothing this runtime could submit work to, so stop here
+  // instead of silently falling back to family 0.
+  if (failed(getBestComputeQueue(physicalDevice)))
+    return failure();
+
+  const float queuePriority = 1.0f;
+  VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
+  deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
+  deviceQueueCreateInfo.pNext = nullptr;
+  deviceQueueCreateInfo.flags = 0;
+  deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
+  deviceQueueCreateInfo.queueCount = 1;
+  deviceQueueCreateInfo.pQueuePriorities = &queuePriority;
+
+  // Structure specifying parameters of a newly created device.
+  VkDeviceCreateInfo deviceCreateInfo = {};
+  deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
+  deviceCreateInfo.pNext = nullptr;
+  deviceCreateInfo.flags = 0;
+  deviceCreateInfo.queueCreateInfoCount = 1;
+  deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
+  deviceCreateInfo.enabledLayerCount = 0;
+  deviceCreateInfo.ppEnabledLayerNames = nullptr;
+  deviceCreateInfo.enabledExtensionCount = 0;
+  deviceCreateInfo.ppEnabledExtensionNames = nullptr;
+  deviceCreateInfo.pEnabledFeatures = nullptr;
+
+  RETURN_ON_VULKAN_ERROR(
+      vkCreateDevice(physicalDevice, &deviceCreateInfo, nullptr, &device),
+      "vkCreateDevice");
+
+  VkPhysicalDeviceMemoryProperties properties = {};
+  vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
+
+  // Try to find memory type with following properties:
+  // VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
+  // with this type can be mapped for host access using vkMapMemory;
+  // VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
+  // management commands vkFlushMappedMemoryRanges and
+  // vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
+  // device or make device writes visible to the host, respectively.
+  for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
+    if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
+         properties.memoryTypes[i].propertyFlags) &&
+        (VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
+         properties.memoryTypes[i].propertyFlags) &&
+        (memorySize <=
+         properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
+      memoryTypeIndex = i;
+      break;
+    }
+  }
+
+  RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE
+                                                                : VK_SUCCESS,
+                         "invalid memoryTypeIndex");
+  return success();
+}
+
+LogicalResult
+VulkanRuntime::getBestComputeQueue(const VkPhysicalDevice &physicalDevice) {
+  uint32_t queueFamilyPropertiesCount = 0;
+  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
+                                           &queueFamilyPropertiesCount, nullptr);
+  SmallVector<VkQueueFamilyProperties, 1> queueFamilyProperties(
+      queueFamilyPropertiesCount);
+
+  vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
+                                           &queueFamilyPropertiesCount,
+                                           queueFamilyProperties.data());
+
+  // VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
+  // compute operations. First prefer a family that supports compute but not
+  // graphics; transfer and sparse-binding bits are ignored for this decision.
+  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
+    const VkQueueFlags maskedFlags =
+        (~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
+         queueFamilyProperties[i].queueFlags);
+
+    if (!(VK_QUEUE_GRAPHICS_BIT & maskedFlags) &&
+        (VK_QUEUE_COMPUTE_BIT & maskedFlags)) {
+      queueFamilyIndex = i;
+      return success();
+    }
+  }
+
+  // Otherwise accept any family with compute support.
+  for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
+    const VkQueueFlags maskedFlags =
+        (~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
+         queueFamilyProperties[i].queueFlags);
+
+    if (VK_QUEUE_COMPUTE_BIT & maskedFlags) {
+      queueFamilyIndex = i;
+      return success();
+    }
+  }
+
+  llvm::errs() << "cannot find valid queue";
+  return failure();
+}
+
+LogicalResult VulkanRuntime::createMemoryBuffers() {
+  // For each descriptor set.
+  for (const auto &resourceDataMapPair : resourceData) {
+    llvm::SmallVector<VulkanDeviceMemoryBuffer, 1> deviceMemoryBuffers;
+    const auto descriptorSetIndex = resourceDataMapPair.first;
+    const auto &resourceDataMap = resourceDataMapPair.second;
+
+    // For each descriptor binding.
+    for (const auto &resourceDataBindingPair : resourceDataMap) {
+      // Create device memory buffer.
+      VulkanDeviceMemoryBuffer memoryBuffer;
+      memoryBuffer.bindingIndex = resourceDataBindingPair.first;
+      VkDescriptorType descriptorType = {};
+      VkBufferUsageFlagBits bufferUsage = {};
+
+      // Check that descriptor set has storage class map.
+      const auto resourceStorageClassMapIt =
+          resourceStorageClassData.find(descriptorSetIndex);
+      if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
+        llvm::errs()
+            << "cannot find storage class for resource in descriptor set: "
+            << descriptorSetIndex;
+        return failure();
+      }
+
+      // Check that specific descriptor binding has storage class.
+      const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
+      const auto resourceStorageClassIt =
+          resourceStorageClassMap.find(resourceDataBindingPair.first);
+      if (resourceStorageClassIt == resourceStorageClassMap.end()) {
+        llvm::errs()
+            << "cannot find storage class for resource with descriptor index: "
+            << resourceDataBindingPair.first;
+        return failure();
+      }
+
+      const auto resourceStorageClassBinding = resourceStorageClassIt->second;
+      if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
+                                                 descriptorType)) ||
+          failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
+                                                  bufferUsage))) {
+        llvm::errs() << "storage class for resource with descriptor binding: "
+                     << resourceDataBindingPair.first
+                     << " in the descriptor set: " << descriptorSetIndex
+                     << " is not supported ";
+        return failure();
+      }
+
+      // Set descriptor type for the specific device memory buffer.
+      memoryBuffer.descriptorType = descriptorType;
+      const auto bufferSize = resourceDataBindingPair.second.size;
+
+      // Specify memory allocation info.
+      VkMemoryAllocateInfo memoryAllocateInfo = {};
+      memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
+      memoryAllocateInfo.pNext = nullptr;
+      memoryAllocateInfo.allocationSize = bufferSize;
+      memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex;
+
+      // Allocate device memory.
+      RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo,
+                                              nullptr,
+                                              &memoryBuffer.deviceMemory),
+                             "vkAllocateMemory");
+      void *payload = nullptr;
+      RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0,
+                                         bufferSize, 0, &payload),
+                             "vkMapMemory");
+
+      // Copy host memory into the mapped area.
+      std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
+      vkUnmapMemory(device, memoryBuffer.deviceMemory);
+
+      VkBufferCreateInfo bufferCreateInfo = {};
+      bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
+      bufferCreateInfo.pNext = nullptr;
+      bufferCreateInfo.flags = 0;
+      bufferCreateInfo.size = bufferSize;
+      bufferCreateInfo.usage = bufferUsage;
+      bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
+      bufferCreateInfo.queueFamilyIndexCount = 1;
+      bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
+      RETURN_ON_VULKAN_ERROR(vkCreateBuffer(device, &bufferCreateInfo, nullptr,
+                                            &memoryBuffer.buffer),
+                             "vkCreateBuffer");
+
+      // Bind buffer and device memory.
+      RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer,
+                                                memoryBuffer.deviceMemory, 0),
+                             "vkBindBufferMemory");
+
+      // Update buffer info.
+      memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer;
+      memoryBuffer.bufferInfo.offset = 0;
+      memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
+      deviceMemoryBuffers.push_back(memoryBuffer);
+    }
+
+    // Associate device memory buffers with a descriptor set.
+    deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
+  }
+  return success();
+}
+
+LogicalResult VulkanRuntime::createShaderModule() {
+  VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
+  shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
+  shaderModuleCreateInfo.pNext = nullptr;
+  shaderModuleCreateInfo.flags = 0;
+  // Set size in bytes.
+  shaderModuleCreateInfo.codeSize = binarySize;
+  // Set pointer to the binary shader.
+  shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
+  RETURN_ON_VULKAN_ERROR(vkCreateShaderModule(device, &shaderModuleCreateInfo,
+                                              nullptr, &shaderModule),
+                         "vkCreateShaderModule");
+  return success();
+}
+
+void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
+  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
+    SmallVector<VkDescriptorSetLayoutBinding, 1> descriptorSetLayoutBindings;
+    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
+    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
+
+    // Create a layout binding for each descriptor.
+    for (const auto &memBuffer : deviceMemoryBuffers) {
+      VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
+      descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
+      descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
+      descriptorSetLayoutBinding.descriptorCount = 1;
+      descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+      descriptorSetLayoutBinding.pImmutableSamplers = nullptr;
+      descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
+    }
+    descriptorSetLayoutBindingMap[descriptorSetIndex] =
+        descriptorSetLayoutBindings;
+  }
+}
+
+LogicalResult VulkanRuntime::createDescriptorSetLayout() {
+  for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
+    const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
+    const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
+    // Each descriptor in a descriptor set must be the same type.
+ VkDescriptorType descriptorType = + deviceMemoryBuffers.front().descriptorType; + const uint32_t descriptorSize = deviceMemoryBuffers.size(); + const auto descriptorSetLayoutBindingIt = + descriptorSetLayoutBindingMap.find(descriptorSetIndex); + + if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) { + llvm::errs() << "cannot find layout bindings for the set with number: " + << descriptorSetIndex; + return failure(); + } + + const auto &descriptorSetLayoutBindings = + descriptorSetLayoutBindingIt->second; + // Create descriptor set layout. + VkDescriptorSetLayout descriptorSetLayout = {}; + VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {}; + + descriptorSetLayoutCreateInfo.sType = + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; + descriptorSetLayoutCreateInfo.pNext = nullptr; + descriptorSetLayoutCreateInfo.flags = 0; + // Amount of descriptor bindings in a layout set. + descriptorSetLayoutCreateInfo.bindingCount = + descriptorSetLayoutBindings.size(); + descriptorSetLayoutCreateInfo.pBindings = + descriptorSetLayoutBindings.data(); + RETURN_ON_VULKAN_ERROR( + vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0, + &descriptorSetLayout), + "vkCreateDescriptorSetLayout"); + + descriptorSetLayouts.push_back(descriptorSetLayout); + descriptorSetInfoPool.push_back( + {descriptorSetIndex, descriptorSize, descriptorType}); + } + return success(); +} + +LogicalResult VulkanRuntime::createPipelineLayout() { + // Associate descriptor sets with a pipeline layout. 
+ VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {}; + pipelineLayoutCreateInfo.sType = + VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; + pipelineLayoutCreateInfo.pNext = nullptr; + pipelineLayoutCreateInfo.flags = 0; + pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size(); + pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data(); + pipelineLayoutCreateInfo.pushConstantRangeCount = 0; + pipelineLayoutCreateInfo.pPushConstantRanges = 0; + RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device, + &pipelineLayoutCreateInfo, 0, + &pipelineLayout), + "vkCreatePipelineLayout"); + return success(); +} + +LogicalResult VulkanRuntime::createComputePipeline() { + VkPipelineShaderStageCreateInfo stageInfo = {}; + stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + stageInfo.pNext = nullptr; + stageInfo.flags = 0; + stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT; + stageInfo.module = shaderModule; + // Set entry point. + stageInfo.pName = entryPoint; + stageInfo.pSpecializationInfo = 0; + + VkComputePipelineCreateInfo computePipelineCreateInfo = {}; + computePipelineCreateInfo.sType = + VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + computePipelineCreateInfo.pNext = nullptr; + computePipelineCreateInfo.flags = 0; + computePipelineCreateInfo.stage = stageInfo; + computePipelineCreateInfo.layout = pipelineLayout; + computePipelineCreateInfo.basePipelineHandle = 0; + computePipelineCreateInfo.basePipelineIndex = 0; + RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1, + &computePipelineCreateInfo, 0, + &pipeline), + "vkCreateComputePipelines"); + return success(); +} + +LogicalResult VulkanRuntime::createDescriptorPool() { + llvm::SmallVector descriptorPoolSizes; + for (const auto &descriptorSetInfo : descriptorSetInfoPool) { + // For each descriptor set populate descriptor pool size. 
+ VkDescriptorPoolSize descriptorPoolSize = {}; + descriptorPoolSize.type = descriptorSetInfo.descriptorType; + descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize; + descriptorPoolSizes.push_back(descriptorPoolSize); + } + + VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {}; + descriptorPoolCreateInfo.sType = + VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; + descriptorPoolCreateInfo.pNext = nullptr; + descriptorPoolCreateInfo.flags = 0; + descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size(); + descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size(); + descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data(); + RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device, + &descriptorPoolCreateInfo, 0, + &descriptorPool), + "vkCreateDescriptorPool"); + return success(); +} + +LogicalResult VulkanRuntime::allocateDescriptorSets() { + VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {}; + // Size of desciptor sets and descriptor layout sets is the same. + descriptorSets.resize(descriptorSetLayouts.size()); + descriptorSetAllocateInfo.sType = + VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; + descriptorSetAllocateInfo.pNext = nullptr; + descriptorSetAllocateInfo.descriptorPool = descriptorPool; + descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size(); + descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data(); + RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device, + &descriptorSetAllocateInfo, + descriptorSets.data()), + "vkAllocateDescriptorSets"); + return success(); +} + +LogicalResult VulkanRuntime::setWriteDescriptors() { + if (descriptorSets.size() != descriptorSetInfoPool.size()) { + llvm::errs() << "Each descriptor set must have descriptor set information"; + return failure(); + } + // For each descriptor set. + auto descriptorSetIt = descriptorSets.begin(); + // Each descriptor set is associated with descriptor set info. 
+ for (const auto &descriptorSetInfo : descriptorSetInfoPool) { + // For each device memory buffer in the descriptor set. + const auto &deviceMemoryBuffers = + deviceMemoryBufferMap[descriptorSetInfo.descriptorSet]; + for (const auto &memoryBuffer : deviceMemoryBuffers) { + // Structure describing descriptor sets to write to. + VkWriteDescriptorSet wSet = {}; + wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + wSet.pNext = nullptr; + // Descirptor set. + wSet.dstSet = *descriptorSetIt; + wSet.dstBinding = memoryBuffer.bindingIndex; + wSet.dstArrayElement = 0; + wSet.descriptorCount = 1; + wSet.descriptorType = memoryBuffer.descriptorType; + wSet.pImageInfo = nullptr; + wSet.pBufferInfo = &memoryBuffer.bufferInfo; + wSet.pTexelBufferView = nullptr; + vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr); + } + // Increment descriptor set iterator. + ++descriptorSetIt; + } + return success(); +} + +LogicalResult VulkanRuntime::createCommandPool() { + VkCommandPoolCreateInfo commandPoolCreateInfo = {}; + commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; + commandPoolCreateInfo.pNext = nullptr; + commandPoolCreateInfo.flags = 0; + commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex; + RETURN_ON_VULKAN_ERROR( + vkCreateCommandPool(device, &commandPoolCreateInfo, 0, &commandPool), + "vkCreateCommandPool"); + return success(); +} + +LogicalResult VulkanRuntime::createComputeCommandBuffer() { + VkCommandBufferAllocateInfo commandBufferAllocateInfo = {}; + VkCommandBuffer commandBuffer; + commandBufferAllocateInfo.sType = + VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; + commandBufferAllocateInfo.pNext = nullptr; + commandBufferAllocateInfo.commandPool = commandPool; + commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; + commandBufferAllocateInfo.commandBufferCount = 1; + RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device, + &commandBufferAllocateInfo, + &commandBuffer), + "vkAllocateCommandBuffers"); + + 
VkCommandBufferBeginInfo commandBufferBeginInfo = {}; + commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; + commandBufferBeginInfo.pNext = nullptr; + commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT; + commandBufferBeginInfo.pInheritanceInfo = nullptr; + + // Commands begin. + RETURN_ON_VULKAN_ERROR( + vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo), + "vkBeginCommandBuffer"); + + // Commands. + vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); + vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, + pipelineLayout, 0, descriptorSets.size(), + descriptorSets.data(), 0, 0); + vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y, + numWorkGroups.z); + + // Commands end. + RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer), + "vkEndCommandBuffer"); + + commandBuffers.push_back(commandBuffer); + return success(); +} + +LogicalResult VulkanRuntime::submitCommandBuffersToQueue() { + VkSubmitInfo submitInfo = {}; + submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO; + submitInfo.pNext = nullptr; + submitInfo.waitSemaphoreCount = 0; + submitInfo.pWaitSemaphores = 0; + submitInfo.pWaitDstStageMask = 0; + submitInfo.commandBufferCount = commandBuffers.size(); + submitInfo.pCommandBuffers = commandBuffers.data(); + submitInfo.signalSemaphoreCount = 0; + submitInfo.pSignalSemaphores = nullptr; + RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0), + "vkQueueSubmit"); + return success(); +} + +LogicalResult VulkanRuntime::updateHostMemoryBuffers() { + // For each descriptor set. + for (auto &resourceDataMapPair : resourceData) { + auto &resourceDataMap = resourceDataMapPair.second; + auto &deviceMemoryBuffers = + deviceMemoryBufferMap[resourceDataMapPair.first]; + // For each device memory buffer in the set. 
+ for (auto &deviceMemoryBuffer : deviceMemoryBuffers) { + if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) { + void *payload; + auto &hostMemoryBuffer = + resourceDataMap[deviceMemoryBuffer.bindingIndex]; + RETURN_ON_VULKAN_ERROR(vkMapMemory(device, + deviceMemoryBuffer.deviceMemory, 0, + hostMemoryBuffer.size, 0, + reinterpret_cast(&payload)), + "vkMapMemory"); + std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size); + vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory); + } + } + } + return success(); +} diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -0,0 +1,45 @@ +//===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a command line utility that executes an MLIR file on the Vulkan by +// translating MLIR GPU module to SPIR-V and host part to LLVM IR before +// JIT-compiling and executing the latter. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
+#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/SPIRV/Passes.h"
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/JitRunner.h"
+
+using namespace mlir;
+
+static LogicalResult runMLIRPasses(ModuleOp module) {
+  PassManager passManager(module.getContext());
+  applyPassManagerCLOptions(passManager);
+  // GPU kernels -> SPIR-V; host launches -> Vulkan runtime calls -> LLVM.
+  passManager.addPass(createGpuKernelOutliningPass());
+  passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
+  passManager.addPass(createConvertGPUToSPIRVPass());
+  OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
+  modulePM.addPass(spirv::createLowerABIAttributesPass());
+  passManager.addPass(createConvertGpuLaunchFuncToVulkanCallsPass());
+  passManager.addPass(createLowerToLLVMPass());
+  return passManager.run(module);
+}
+
+int main(int argc, char **argv) {
+  registerPassManagerCLOptions();
+  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+}
diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
@@ -0,0 +1,114 @@
+//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements C runtime wrappers around the VulkanRuntime.
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <cstdint>
+#include <mutex>
+
+#include "VulkanRuntime.h"
+#include "llvm/Support/ManagedStatic.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace {
+
+// TODO(denis0x0D): This static machinery should be replaced by `initVulkan` and
+// `deinitVulkan` to be more explicit and to avoid static initialization and
+// destruction.
+class VulkanRuntimeManager;
+static llvm::ManagedStatic<VulkanRuntimeManager> vkRuntimeManager;
+
+/// Owns the single VulkanRuntime used by the C wrappers below. Every method
+/// holds `mutex` for its whole duration, so individual calls are serialized;
+/// a sequence of calls (configure, then runOnVulkan) is not made atomic here.
+class VulkanRuntimeManager {
+public:
+  VulkanRuntimeManager() = default;
+  VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
+  VulkanRuntimeManager &operator=(const VulkanRuntimeManager &) = delete;
+  ~VulkanRuntimeManager() = default;
+
+  void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
+                       const VulkanHostMemoryBuffer &memBuffer) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
+  }
+
+  void setEntryPoint(const char *entryPoint) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setEntryPoint(entryPoint);
+  }
+
+  void setNumWorkGroups(NumWorkGroups numWorkGroups) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setNumWorkGroups(numWorkGroups);
+  }
+
+  void setShaderModule(uint8_t *shader, uint32_t size) {
+    std::lock_guard<std::mutex> lock(mutex);
+    vulkanRuntime.setShaderModule(shader, size);
+  }
+
+  /// Initializes the runtime, runs it, copies device results back into the
+  /// registered host buffers and destroys the runtime, stopping at the first
+  /// step that fails.
+  /// NOTE(review): after a failure `destroy()` is skipped, so objects created
+  /// so far are not released; confirm destroy() tolerates a partially
+  /// initialized runtime before calling it unconditionally.
+  void runOnVulkan() {
+    std::lock_guard<std::mutex> lock(mutex);
+    if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
+        failed(vulkanRuntime.updateHostMemoryBuffers()) ||
+        failed(vulkanRuntime.destroy())) {
+      llvm::errs() << "runOnVulkan failed\n";
+    }
+  }
+
+private:
+  VulkanRuntime vulkanRuntime;
+  std::mutex mutex;
+};
+
+} // namespace
+
+extern "C" {
+/// Fills the given memref with the given value.
+/// Binds the given memref to the given descriptor set and descriptor index.
+///
+/// The memref is passed as its expanded 1-D descriptor; its elements start at
+/// `aligned + offset` (not at `allocated`, the allocator's base address), so
+/// that is the address that gets filled and registered with the runtime.
+/// NOTE(review): the data is treated as contiguous and `stride` is ignored;
+/// confirm callers only pass unit-stride memrefs.
+void setResourceData(const DescriptorSetIndex setIndex, BindingIndex bindIndex,
+                     float *allocated, float *aligned, int64_t offset,
+                     int64_t size, int64_t stride, float value) {
+  float *data = aligned + offset;
+  std::fill_n(data, size, value);
+  VulkanHostMemoryBuffer memBuffer{data,
+                                   static_cast<uint32_t>(size * sizeof(float))};
+  vkRuntimeManager->setResourceData(setIndex, bindIndex, memBuffer);
+}
+
+void setEntryPoint(const char *entryPoint) {
+  vkRuntimeManager->setEntryPoint(entryPoint);
+}
+
+void setNumWorkGroups(uint32_t x, uint32_t y, uint32_t z) {
+  vkRuntimeManager->setNumWorkGroups({x, y, z});
+}
+
+void setBinaryShader(uint8_t *shader, uint32_t size) {
+  vkRuntimeManager->setShaderModule(shader, size);
+}
+
+void runOnVulkan() { vkRuntimeManager->runOnVulkan(); }
+} // extern "C"