Changeset View
Changeset View
Standalone View
Standalone View
mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp
Show First 20 Lines • Show All 56 Lines • ▼ Show 20 Lines | |||||
private: | private: | ||||
VulkanRuntime vulkanRuntime; | VulkanRuntime vulkanRuntime; | ||||
std::mutex mutex; | std::mutex mutex; | ||||
}; | }; | ||||
} // namespace | } // namespace | ||||
template <typename T, int N> | |||||
struct MemRefDescriptor { | |||||
T *allocated; | |||||
T *aligned; | |||||
int64_t offset; | |||||
int64_t sizes[N]; | |||||
int64_t strides[N]; | |||||
}; | |||||
extern "C" { | extern "C" { | ||||
// Initializes `VulkanRuntimeManager` and returns a pointer to it. | /// Initializes `VulkanRuntimeManager` and returns a pointer to it. | ||||
void *initVulkan() { return new VulkanRuntimeManager(); } | void *initVulkan() { return new VulkanRuntimeManager(); } | ||||
// Deinitializes `VulkanRuntimeManager` by the given pointer. | /// Deinitializes `VulkanRuntimeManager` by the given pointer. | ||||
void deinitVulkan(void *vkRuntimeManager) { | void deinitVulkan(void *vkRuntimeManager) { | ||||
delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager); | delete reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager); | ||||
} | } | ||||
/// Binds the given memref to the given descriptor set and descriptor index. | |||||
void bindResource(void *vkRuntimeManager, DescriptorSetIndex setIndex, | |||||
BindingIndex bindIndex, float *ptr, int64_t size) { | |||||
VulkanHostMemoryBuffer memBuffer{ptr, | |||||
static_cast<uint32_t>(size * sizeof(float))}; | |||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | |||||
->setResourceData(setIndex, bindIndex, memBuffer); | |||||
} | |||||
void runOnVulkan(void *vkRuntimeManager) { | void runOnVulkan(void *vkRuntimeManager) { | ||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan(); | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager)->runOnVulkan(); | ||||
} | } | ||||
/// Fills the given 1D float memref with the given float value. | |||||
void fillResource1DFloat(float *allocated, float *aligned, int64_t offset, | |||||
int64_t size, int64_t stride, float value) { | |||||
std::fill_n(allocated, size, value); | |||||
} | |||||
void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) { | void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) { | ||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | ||||
->setEntryPoint(entryPoint); | ->setEntryPoint(entryPoint); | ||||
} | } | ||||
void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, | void setNumWorkGroups(void *vkRuntimeManager, uint32_t x, uint32_t y, | ||||
uint32_t z) { | uint32_t z) { | ||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | ||||
->setNumWorkGroups({x, y, z}); | ->setNumWorkGroups({x, y, z}); | ||||
} | } | ||||
void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) { | void setBinaryShader(void *vkRuntimeManager, uint8_t *shader, uint32_t size) { | ||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | ||||
->setShaderModule(shader, size); | ->setShaderModule(shader, size); | ||||
} | } | ||||
/// Binds the given 1D float memref to the given descriptor set and descriptor | |||||
/// index. | |||||
void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex, | |||||
BindingIndex bindIndex, | |||||
MemRefDescriptor<float, 1> *ptr) { | |||||
VulkanHostMemoryBuffer memBuffer{ | |||||
ptr->allocated, static_cast<uint32_t>(ptr->sizes[0] * sizeof(float))}; | |||||
reinterpret_cast<VulkanRuntimeManager *>(vkRuntimeManager) | |||||
->setResourceData(setIndex, bindIndex, memBuffer); | |||||
} | |||||
/// Fills the given 1D float memref with the given float value. | |||||
void _mlir_ciface_fillResource1DFloat(MemRefDescriptor<float, 1> *ptr, // NOLINT | |||||
float value) { | |||||
std::fill_n(ptr->allocated, ptr->sizes[0], value); | |||||
} | |||||
} | } |