diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -33,10 +33,10 @@ namespace { -// A pass to convert gpu launch op to vulkan launch call op, by creating a -// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize` -// function and attaching binary data and entry point name as an attributes to -// created vulkan launch call op. +/// A pass to convert gpu launch op to vulkan launch call op, by creating a +/// SPIR-V binary shader from `spirv::ModuleOp` using `spirv::serialize` +/// function and attaching binary data and entry point name as an attributes to +/// created vulkan launch call op. class ConvertGpuLaunchFuncToVulkanLaunchFunc : public ModulePass { public: diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -26,7 +26,8 @@ using namespace mlir; -static constexpr const char *kBindResource = "bindResource"; +static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; +static constexpr const char *kCiFaceVulkanLaunch = "_mlir_ciface_vulkanLaunch"; static constexpr const char *kDeinitVulkan = "deinitVulkan"; static constexpr const char *kRunOnVulkan = "runOnVulkan"; static constexpr const char *kInitVulkan = "initVulkan"; @@ -40,11 +41,11 @@ namespace { -/// A pass to convert vulkan launch func into a sequence of Vulkan +/// A pass to convert vulkan launch call op into a sequence of Vulkan /// runtime calls in the following order: /// /// * initVulkan -- initializes vulkan runtime -/// * bindResource -- binds resource +/// * bindMemRef -- binds memref /// * setBinaryShader -- sets the binary shader data /// * setEntryPoint -- sets the entry point name /// * setNumWorkGroups -- sets the number of a local workgroups @@ -67,6 +68,29 @@ llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); + initializeMemRefTypes(); + } + + void initializeMemRefTypes() { + // According to the MLIR doc memref argument is converted into a + // pointer-to-struct argument of type: + // template + // struct { + // Elem *allocated; + // Elem *aligned; + // int64_t offset; + // int64_t sizes[Rank]; // omitted when rank == 0 + // int64_t strides[Rank]; // omitted when rank == 0 + // }; + auto llvmPtrToFloatType = getFloatType().getPointerTo(); + auto llvmArrayOneElementSizeType = + LLVM::LLVMType::getArrayTy(getInt64Type(), 1); + + // Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. + llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( + llvmDialect, + {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), + llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); } LLVM::LLVMType getFloatType() { return llvmFloatType; } @@ -74,6 +98,7 @@ LLVM::LLVMType getPointerType() { return llvmPointerType; } LLVM::LLVMType getInt32Type() { return llvmInt32Type; } LLVM::LLVMType getInt64Type() { return llvmInt64Type; } + LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } /// Creates a LLVM global for the given `name`. Value createEntryPointNameConstant(StringRef name, Location loc, @@ -85,16 +110,27 @@ /// Checks whether the given LLVM::CallOp is a vulkan launch call op. bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && - callOp.getNumOperands() >= 6); + callOp.getNumOperands() >= kNumConfigOps); + } + + /// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call + /// op. + bool isCiFaceVulkanLaunchCallOp(LLVM::CallOp callOp) { + return (callOp.callee() && + callOp.callee().getValue() == kCiFaceVulkanLaunch && + callOp.getNumOperands() >= kNumConfigOps); } /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan /// runtime calls. void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); - /// Creates call to `bindResource` for each resource operand. - void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp, - Value vulkanRuntiem); + /// Creates call to `bindMemRef` for each memref operand. + void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, + Value vulkanRuntime); + + /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. + void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); public: void runOnModule() override; @@ -106,89 +142,83 @@ LLVM::LLVMType llvmPointerType; LLVM::LLVMType llvmInt32Type; LLVM::LLVMType llvmInt64Type; -}; - -/// Represents operand adaptor for vulkan launch call operation, to simplify an -/// access to the lowered memref. -// TODO: We should use 'emit-c-wrappers' option to lower memref type: -// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission. -struct VulkanLaunchOpOperandAdaptor { - VulkanLaunchOpOperandAdaptor(ArrayRef values) { operands = values; } - VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete; - VulkanLaunchOpOperandAdaptor - operator=(const VulkanLaunchOpOperandAdaptor &) = delete; - - /// Returns a tuple with a pointer to the memory and the size for the index-th - /// resource. - std::tuple getResourceDescriptor1D(uint32_t index) { - assert(index < getResourceCount1D()); - // 1D memref calling convention according to "ConversionToLLVMDialect.md": - // 0. Allocated pointer. - // 1. Aligned pointer. - // 2. Offset. - // 3. Size in dim 0. - // 4. Stride in dim 0. - auto offset = numConfigOps + index * loweredMemRefNumOps1D; - return std::make_tuple(operands[offset], operands[offset + 3]); - } + LLVM::LLVMType llvmMemRef1DFloat; - /// Returns the number of resources assuming all operands lowered from - /// 1D memref. - uint32_t getResourceCount1D() { - return (operands.size() - numConfigOps) / loweredMemRefNumOps1D; - } - -private: - /// The number of operands of lowered 1D memref. - static constexpr const uint32_t loweredMemRefNumOps1D = 5; - /// The number of the first config operands. - static constexpr const uint32_t numConfigOps = 6; - ArrayRef operands; + // TODO: Use an associative array to support multiple vulkan launch calls. + std::pair spirvAttributes; + static constexpr const uint32_t kNumConfigOps = 6; }; } // anonymous namespace void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { initializeCachedTypes(); + + // Collect SPIRV attributes such as `spirv_blob` and `spirv_entry_point_name`. getModule().walk([this](LLVM::CallOp op) { if (isVulkanLaunchCallOp(op)) + collectSPIRVAttributes(op); + }); + + // Convert vulkan launch call op into a sequence of Vulkan runtime calls. + getModule().walk([this](LLVM::CallOp op) { + if (isCiFaceVulkanLaunchCallOp(op)) translateVulkanLaunchCall(op); }); } -void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls( - LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) { - if (vulkanLaunchCallOp.getNumOperands() == 6) +void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( + LLVM::CallOp vulkanLaunchCallOp) { + // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes + // for the given vulkan launch call. + auto spirvBlobAttr = + vulkanLaunchCallOp.getAttrOfType(kSPIRVBlobAttrName); + if (!spirvBlobAttr) { + vulkanLaunchCallOp.emitError() + << "missing " << kSPIRVBlobAttrName << " attribute"; + return signalPassFailure(); + } + + auto spirvEntryPointNameAttr = + vulkanLaunchCallOp.getAttrOfType(kSPIRVEntryPointAttrName); + if (!spirvEntryPointNameAttr) { + vulkanLaunchCallOp.emitError() + << "missing " << kSPIRVEntryPointAttrName << " attribute"; + return signalPassFailure(); + } + + spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); +} + +void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( + LLVM::CallOp ciFaceVulkanLaunchCallOp, Value vulkanRuntime) { + if (ciFaceVulkanLaunchCallOp.getNumOperands() == kNumConfigOps) return; - OpBuilder builder(vulkanLaunchCallOp); - Location loc = vulkanLaunchCallOp.getLoc(); + OpBuilder builder(ciFaceVulkanLaunchCallOp); + Location loc = ciFaceVulkanLaunchCallOp.getLoc(); // Create LLVM constant for the descriptor set index. - // Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV` + // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` // pass does. Value descriptorSet = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(0)); + auto operands = + SmallVector{ciFaceVulkanLaunchCallOp.getOperands()}; - auto operands = SmallVector{vulkanLaunchCallOp.getOperands()}; - VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands); - - for (auto resourceIdx : - llvm::seq(0, vkLaunchOperandAdaptor.getResourceCount1D())) { + uint32_t operandIdx = 0; + for (const auto ptrToMemRefDescriptor : + llvm::drop_begin(operands, kNumConfigOps)) { // Create LLVM constant for the descriptor binding index. Value descriptorBinding = builder.create( - loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx)); - // Get a pointer to the memory and size of that memory. - auto resourceDescriptor = - vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx); - // Create call to `bindResource`. + loc, getInt32Type(), builder.getI32IntegerAttr(operandIdx)); + // Create call to `bindMemRef`. builder.create( loc, ArrayRef{getVoidType()}, - builder.getSymbolRefAttr(kBindResource), + // TODO: Add support for memref with other ranks. + builder.getSymbolRefAttr(kBindMemRef1DFloat), ArrayRef{vulkanRuntime, descriptorSet, descriptorBinding, - // Pointer to the memory. - std::get<0>(resourceDescriptor), - // Size of the memory. - std::get<1>(resourceDescriptor)}); + ptrToMemRefDescriptor}); + ++operandIdx; } } @@ -228,14 +258,14 @@ /*isVarArg=*/false)); } - if (!module.lookupSymbol(kBindResource)) { + if (!module.lookupSymbol(kBindMemRef1DFloat)) { builder.create( - loc, kBindResource, - LLVM::LLVMType::getFunctionTy( - getVoidType(), - {getPointerType(), getInt32Type(), getInt32Type(), - getFloatType().getPointerTo(), getInt64Type()}, - /*isVarArg=*/false)); + loc, kBindMemRef1DFloat, + LLVM::LLVMType::getFunctionTy(getVoidType(), + {getPointerType(), getInt32Type(), + getInt32Type(), + getMemRef1DFloat().getPointerTo()}, + /*isVarArg=*/false)); } if (!module.lookupSymbol(kInitVulkan)) { @@ -267,28 +297,9 @@ } void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( - LLVM::CallOp vulkanLaunchCallOp) { - OpBuilder builder(vulkanLaunchCallOp); - Location loc = vulkanLaunchCallOp.getLoc(); - - // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes - // for the given vulkan launch call. - auto spirvBlobAttr = - vulkanLaunchCallOp.getAttrOfType(kSPIRVBlobAttrName); - if (!spirvBlobAttr) { - vulkanLaunchCallOp.emitError() - << "missing " << kSPIRVBlobAttrName << " attribute"; - return signalPassFailure(); - } - - auto entryPointNameAttr = - vulkanLaunchCallOp.getAttrOfType(kSPIRVEntryPointAttrName); - if (!entryPointNameAttr) { - vulkanLaunchCallOp.emitError() - << "missing " << kSPIRVEntryPointAttrName << " attribute"; - return signalPassFailure(); - } - + LLVM::CallOp ciFaceVulkanLaunchCallOp) { + OpBuilder builder(ciFaceVulkanLaunchCallOp); + Location loc = ciFaceVulkanLaunchCallOp.getLoc(); // Create call to `initVulkan`. auto initVulkanCall = builder.create( loc, ArrayRef{getPointerType()}, @@ -300,16 +311,16 @@ // Create LLVM global with SPIR-V binary data, so we can pass a pointer with // that data to runtime call. Value ptrToSPIRVBinary = LLVM::createGlobalString( - loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(), + loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), LLVM::Linkage::Internal, getLLVMDialect()); // Create LLVM constant for the size of SPIR-V binary shader. Value binarySize = builder.create( loc, getInt32Type(), - builder.getI32IntegerAttr(spirvBlobAttr.getValue().size())); + builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); - // Create call to `bindResource` for each resource operand. - createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime); + // Create call to `bindMemRef` for each memref operand. + createBindMemRefCalls(ciFaceVulkanLaunchCallOp, vulkanRuntime); // Create call to `setBinaryShader` runtime function with the given pointer to // SPIR-V binary and binary size. @@ -318,8 +329,8 @@ builder.getSymbolRefAttr(kSetBinaryShader), ArrayRef{vulkanRuntime, ptrToSPIRVBinary, binarySize}); // Create LLVM global with entry point name. - Value entryPointName = - createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder); + Value entryPointName = createEntryPointNameConstant( + spirvAttributes.second.getValue(), loc, builder); // Create call to `setEntryPoint` runtime function with the given pointer to // entry point name. builder.create(loc, ArrayRef{getVoidType()}, @@ -330,9 +341,9 @@ builder.create( loc, ArrayRef{getVoidType()}, builder.getSymbolRefAttr(kSetNumWorkGroups), - ArrayRef{vulkanRuntime, vulkanLaunchCallOp.getOperand(0), - vulkanLaunchCallOp.getOperand(1), - vulkanLaunchCallOp.getOperand(2)}); + ArrayRef{vulkanRuntime, ciFaceVulkanLaunchCallOp.getOperand(0), + ciFaceVulkanLaunchCallOp.getOperand(1), + ciFaceVulkanLaunchCallOp.getOperand(2)}); // Create call to `runOnVulkan` runtime function. builder.create(loc, ArrayRef{getVoidType()}, @@ -347,7 +358,7 @@ // Declare runtime functions. declareVulkanFunctions(loc); - vulkanLaunchCallOp.erase(); + ciFaceVulkanLaunchCallOp.erase(); } std::unique_ptr> diff --git a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir --- a/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir +++ b/mlir/test/Conversion/GPUToVulkan/invoke-vulkan.mlir @@ -6,7 +6,7 @@ // CHECK: %[[addressof_SPIRV_BIN:.*]] = llvm.mlir.addressof @SPIRV_BIN // CHECK: %[[SPIRV_BIN_ptr:.*]] = llvm.getelementptr %[[addressof_SPIRV_BIN]] // CHECK: %[[SPIRV_BIN_size:.*]] = llvm.mlir.constant -// CHECK: llvm.call @bindResource(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"float*">, !llvm.i64) -> !llvm.void +// CHECK: llvm.call @bindMemRef1DFloat(%[[Vulkan_Runtime_ptr]], %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"i8*">, !llvm.i32, !llvm.i32, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> !llvm.void // CHECK: llvm.call @setBinaryShader(%[[Vulkan_Runtime_ptr]], %[[SPIRV_BIN_ptr]], %[[SPIRV_BIN_size]]) : (!llvm<"i8*">, !llvm<"i8*">, !llvm.i32) -> !llvm.void // CHECK: %[[addressof_entry_point:.*]] = llvm.mlir.addressof @kernel_spv_entry_point_name // CHECK: %[[entry_point_ptr:.*]] = llvm.getelementptr %[[addressof_entry_point]] @@ -44,5 +44,18 @@ : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) -> () llvm.return } - llvm.func @vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"float*">, !llvm<"float*">, !llvm.i64, !llvm.i64, !llvm.i64) + llvm.func @vulkanLaunch(%arg0: !llvm.i64, %arg1: !llvm.i64, %arg2: !llvm.i64, %arg3: !llvm.i64, %arg4: !llvm.i64, %arg5: !llvm.i64, %arg6: !llvm<"float*">, %arg7: !llvm<"float*">, %arg8: !llvm.i64, %arg9: !llvm.i64, %arg10: !llvm.i64) { + %0 = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %1 = llvm.insertvalue %arg6, %0[0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %2 = llvm.insertvalue %arg7, %1[1] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %3 = llvm.insertvalue %arg8, %2[2] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %4 = llvm.insertvalue %arg9, %3[3, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %5 = llvm.insertvalue %arg10, %4[4, 0] : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> + %6 = llvm.mlir.constant(1 : index) : !llvm.i64 + %7 = llvm.alloca %6 x !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }"> : (!llvm.i64) -> !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*"> + llvm.store %5, %7 : !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*"> + llvm.call @_mlir_ciface_vulkanLaunch(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %7) : (!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) -> () + llvm.return + } + llvm.func @_mlir_ciface_vulkanLaunch(!llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm.i64, !llvm<"{ float*, float*, i64, [1 x i64], [1 x i64] }*">) } diff --git a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h --- a/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h +++ b/mlir/tools/mlir-vulkan-runner/VulkanRuntime.h @@ -22,7 +22,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Support/ToolOutputFile.h" -#include // NOLINT +#include using namespace mlir; diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -39,7 +39,9 @@ OpPassManager &modulePM = passManager.nest(); modulePM.addPass(spirv::createLowerABIAttributesPass()); passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass()); - passManager.addPass(createLowerToLLVMPass()); + passManager.addPass(createLowerToLLVMPass(/*useAlloca=*/false, + /*useBarePtrCallConv=*/false, + /*emitCWrappers=*/true)); passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass()); return passManager.run(module); } diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp --- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp +++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp @@ -62,34 +62,28 @@ } // namespace +template +struct MemRefDescriptor { + T *allocated; + T *aligned; + int64_t offset; + int64_t sizes[N]; + int64_t strides[N]; +}; + extern "C" { -// Initializes `VulkanRuntimeManager` and returns a pointer to it. +/// Initializes `VulkanRuntimeManager` and returns a pointer to it. void *initVulkan() { return new VulkanRuntimeManager(); } -// Deinitializes `VulkanRuntimeManager` by the given pointer. +/// Deinitializes `VulkanRuntimeManager` by the given pointer. void deinitVulkan(void *vkRuntimeManager) { delete reinterpret_cast(vkRuntimeManager); } -/// Binds the given memref to the given descriptor set and descriptor index. -void bindResource(void *vkRuntimeManager, DescriptorSetIndex setIndex, - BindingIndex bindIndex, float *ptr, int64_t size) { - VulkanHostMemoryBuffer memBuffer{ptr, - static_cast(size * sizeof(float))}; - reinterpret_cast(vkRuntimeManager) - ->setResourceData(setIndex, bindIndex, memBuffer); -} - void runOnVulkan(void *vkRuntimeManager) { reinterpret_cast(vkRuntimeManager)->runOnVulkan(); } -/// Fills the given 1D float memref with the given float value. -void fillResource1DFloat(float *allocated, float *aligned, int64_t offset, - int64_t size, int64_t stride, float value) { - std::fill_n(allocated, size, value); -} - void setEntryPoint(void *vkRuntimeManager, const char *entryPoint) { reinterpret_cast(vkRuntimeManager) ->setEntryPoint(entryPoint); @@ -105,4 +99,21 @@ reinterpret_cast(vkRuntimeManager) ->setShaderModule(shader, size); } + +/// Binds the given 1D float memref to the given descriptor set and descriptor +/// index. +void bindMemRef1DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex, + BindingIndex bindIndex, + MemRefDescriptor *ptr) { + VulkanHostMemoryBuffer memBuffer{ + ptr->allocated, static_cast(ptr->sizes[0] * sizeof(float))}; + reinterpret_cast(vkRuntimeManager) + ->setResourceData(setIndex, bindIndex, memBuffer); +} + +/// Fills the given 1D float memref with the given float value. +void _mlir_ciface_fillResource1DFloat(MemRefDescriptor *ptr, // NOLINT + float value) { + std::fill_n(ptr->allocated, ptr->sizes[0], value); +} }