Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
//===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// | //===- ConvertLaunchFuncToVulkanCalls.cpp - MLIR Vulkan conversion passes -===// | ||||
Lint: Lint: clang-format-diff not found in user's PATH; not linting file. | |||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// | // | ||||
▲ Show 20 Lines • Show All 43 Lines • ▼ Show 20 Lines | |||||
/// * bindMemRef -- binds memref | /// * bindMemRef -- binds memref | ||||
/// * setBinaryShader -- sets the binary shader data | /// * setBinaryShader -- sets the binary shader data | ||||
/// * setEntryPoint -- sets the entry point name | /// * setEntryPoint -- sets the entry point name | ||||
/// * setNumWorkGroups -- sets the number of a local workgroups | /// * setNumWorkGroups -- sets the number of a local workgroups | ||||
/// * runOnVulkan -- runs vulkan runtime | /// * runOnVulkan -- runs vulkan runtime | ||||
/// * deinitVulkan -- deinitializes vulkan runtime | /// * deinitVulkan -- deinitializes vulkan runtime | ||||
/// | /// | ||||
class VulkanLaunchFuncToVulkanCallsPass | class VulkanLaunchFuncToVulkanCallsPass | ||||
: public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { | : public OperationPass<VulkanLaunchFuncToVulkanCallsPass, ModuleOp> { | ||||
private: | private: | ||||
/// Include the generated pass utilities. | /// Include the generated pass utilities. | ||||
#define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls | #define GEN_PASS_ConvertVulkanLaunchFuncToVulkanCalls | ||||
#include "mlir/Conversion/Passes.h.inc" | #include "mlir/Conversion/Passes.h.inc" | ||||
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } | LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } | ||||
llvm::LLVMContext &getLLVMContext() { | llvm::LLVMContext &getLLVMContext() { | ||||
▲ Show 20 Lines • Show All 79 Lines • ▼ Show 20 Lines | #include "mlir/Conversion/Passes.h.inc" | ||||
/// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. | /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. | ||||
void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); | void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); | ||||
/// Deduces a rank from the given 'ptrToMemRefDescriptor`. | /// Deduces a rank from the given 'ptrToMemRefDescriptor`. | ||||
LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); | LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); | ||||
public: | public: | ||||
void runOnModule() override; | void runOnOperation() override; | ||||
private: | private: | ||||
LLVM::LLVMDialect *llvmDialect; | LLVM::LLVMDialect *llvmDialect; | ||||
LLVM::LLVMType llvmFloatType; | LLVM::LLVMType llvmFloatType; | ||||
LLVM::LLVMType llvmVoidType; | LLVM::LLVMType llvmVoidType; | ||||
LLVM::LLVMType llvmPointerType; | LLVM::LLVMType llvmPointerType; | ||||
LLVM::LLVMType llvmInt32Type; | LLVM::LLVMType llvmInt32Type; | ||||
LLVM::LLVMType llvmInt64Type; | LLVM::LLVMType llvmInt64Type; | ||||
LLVM::LLVMType llvmMemRef1DFloat; | LLVM::LLVMType llvmMemRef1DFloat; | ||||
LLVM::LLVMType llvmMemRef2DFloat; | LLVM::LLVMType llvmMemRef2DFloat; | ||||
// TODO: Use an associative array to support multiple vulkan launch calls. | // TODO: Use an associative array to support multiple vulkan launch calls. | ||||
std::pair<StringAttr, StringAttr> spirvAttributes; | std::pair<StringAttr, StringAttr> spirvAttributes; | ||||
}; | }; | ||||
} // anonymous namespace | } // anonymous namespace | ||||
void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { | void VulkanLaunchFuncToVulkanCallsPass::runOnOperation() { | ||||
initializeCachedTypes(); | initializeCachedTypes(); | ||||
// Collect SPIR-V attributes such as `spirv_blob` and | // Collect SPIR-V attributes such as `spirv_blob` and | ||||
// `spirv_entry_point_name`. | // `spirv_entry_point_name`. | ||||
getModule().walk([this](LLVM::CallOp op) { | getOperation().walk([this](LLVM::CallOp op) { | ||||
if (isVulkanLaunchCallOp(op)) | if (isVulkanLaunchCallOp(op)) | ||||
collectSPIRVAttributes(op); | collectSPIRVAttributes(op); | ||||
}); | }); | ||||
// Convert vulkan launch call op into a sequence of Vulkan runtime calls. | // Convert vulkan launch call op into a sequence of Vulkan runtime calls. | ||||
getModule().walk([this](LLVM::CallOp op) { | getOperation().walk([this](LLVM::CallOp op) { | ||||
if (isCInterfaceVulkanLaunchCallOp(op)) | if (isCInterfaceVulkanLaunchCallOp(op)) | ||||
translateVulkanLaunchCall(op); | translateVulkanLaunchCall(op); | ||||
}); | }); | ||||
} | } | ||||
void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( | void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( | ||||
LLVM::CallOp vulkanLaunchCallOp) { | LLVM::CallOp vulkanLaunchCallOp) { | ||||
// Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes | // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes | ||||
▲ Show 20 Lines • Show All 81 Lines • ▼ Show 20 Lines | if (llvmDescriptorTy.getStructNumElements() == 3) { | ||||
return success(); | return success(); | ||||
} | } | ||||
rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); | rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); | ||||
return success(); | return success(); | ||||
} | } | ||||
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { | void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { | ||||
ModuleOp module = getModule(); | ModuleOp module = getOperation(); | ||||
OpBuilder builder(module.getBody()->getTerminator()); | OpBuilder builder(module.getBody()->getTerminator()); | ||||
if (!module.lookupSymbol(kSetEntryPoint)) { | if (!module.lookupSymbol(kSetEntryPoint)) { | ||||
builder.create<LLVM::LLVMFuncOp>( | builder.create<LLVM::LLVMFuncOp>( | ||||
loc, kSetEntryPoint, | loc, kSetEntryPoint, | ||||
LLVM::LLVMType::getFunctionTy(getVoidType(), | LLVM::LLVMType::getFunctionTy(getVoidType(), | ||||
{getPointerType(), getPointerType()}, | {getPointerType(), getPointerType()}, | ||||
/*isVarArg=*/false)); | /*isVarArg=*/false)); | ||||
▲ Show 20 Lines • Show All 143 Lines • Show Last 20 Lines |
clang-format-diff not found in user's PATH; not linting file.