Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp
Show All 20 Lines | |||||
#include "mlir/IR/Function.h" | #include "mlir/IR/Function.h" | ||||
#include "mlir/IR/Module.h" | #include "mlir/IR/Module.h" | ||||
#include "mlir/Pass/Pass.h" | #include "mlir/Pass/Pass.h" | ||||
#include "llvm/ADT/SmallString.h" | #include "llvm/ADT/SmallString.h" | ||||
using namespace mlir; | using namespace mlir; | ||||
static constexpr const char *kBindResource = "bindResource"; | static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; | ||||
static constexpr const char *kCiFaceVulkanLaunch = "_mlir_ciface_vulkanLaunch"; | |||||
static constexpr const char *kDeinitVulkan = "deinitVulkan"; | static constexpr const char *kDeinitVulkan = "deinitVulkan"; | ||||
static constexpr const char *kRunOnVulkan = "runOnVulkan"; | static constexpr const char *kRunOnVulkan = "runOnVulkan"; | ||||
static constexpr const char *kInitVulkan = "initVulkan"; | static constexpr const char *kInitVulkan = "initVulkan"; | ||||
static constexpr const char *kSetBinaryShader = "setBinaryShader"; | static constexpr const char *kSetBinaryShader = "setBinaryShader"; | ||||
static constexpr const char *kSetEntryPoint = "setEntryPoint"; | static constexpr const char *kSetEntryPoint = "setEntryPoint"; | ||||
static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; | static constexpr const char *kSetNumWorkGroups = "setNumWorkGroups"; | ||||
static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; | static constexpr const char *kSPIRVBinary = "SPIRV_BIN"; | ||||
static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; | static constexpr const char *kSPIRVBlobAttrName = "spirv_blob"; | ||||
static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; | static constexpr const char *kSPIRVEntryPointAttrName = "spirv_entry_point"; | ||||
static constexpr const char *kVulkanLaunch = "vulkanLaunch"; | static constexpr const char *kVulkanLaunch = "vulkanLaunch"; | ||||
namespace { | namespace { | ||||
/// A pass to convert vulkan launch func into a sequence of Vulkan | /// A pass to convert vulkan launch call op into a sequence of Vulkan | ||||
/// runtime calls in the following order: | /// runtime calls in the following order: | ||||
/// | /// | ||||
/// * initVulkan -- initializes vulkan runtime | /// * initVulkan -- initializes vulkan runtime | ||||
/// * bindResource -- binds resource | /// * bindMemRef -- binds memref | ||||
/// * setBinaryShader -- sets the binary shader data | /// * setBinaryShader -- sets the binary shader data | ||||
/// * setEntryPoint -- sets the entry point name | /// * setEntryPoint -- sets the entry point name | ||||
/// * setNumWorkGroups -- sets the number of a local workgroups | /// * setNumWorkGroups -- sets the number of a local workgroups | ||||
/// * runOnVulkan -- runs vulkan runtime | /// * runOnVulkan -- runs vulkan runtime | ||||
/// * deinitVulkan -- deinitializes vulkan runtime | /// * deinitVulkan -- deinitializes vulkan runtime | ||||
/// | /// | ||||
class VulkanLaunchFuncToVulkanCallsPass | class VulkanLaunchFuncToVulkanCallsPass | ||||
: public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { | : public ModulePass<VulkanLaunchFuncToVulkanCallsPass> { | ||||
private: | private: | ||||
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } | LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; } | ||||
llvm::LLVMContext &getLLVMContext() { | llvm::LLVMContext &getLLVMContext() { | ||||
return getLLVMDialect()->getLLVMContext(); | return getLLVMDialect()->getLLVMContext(); | ||||
} | } | ||||
void initializeCachedTypes() { | void initializeCachedTypes() { | ||||
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); | llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>(); | ||||
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); | llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect); | ||||
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); | llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect); | ||||
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); | llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect); | ||||
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); | llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect); | ||||
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); | llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect); | ||||
initializeMemRefTypes(); | |||||
} | |||||
void initializeMemRefTypes() { | |||||
// According to the MLIR doc memref argument is converted into a | |||||
// pointer-to-struct argument of type: | |||||
// template <typename Elem, size_t Rank> | |||||
// struct { | |||||
// Elem *allocated; | |||||
// Elem *aligned; | |||||
// int64_t offset; | |||||
// int64_t sizes[Rank]; // omitted when rank == 0 | |||||
// int64_t strides[Rank]; // omitted when rank == 0 | |||||
// }; | |||||
auto llvmPtrToFloatType = getFloatType().getPointerTo(); | |||||
auto llvmArrayOneElementSizeType = | |||||
LLVM::LLVMType::getArrayTy(getInt64Type(), 1); | |||||
// Create a type `!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. | |||||
llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( | |||||
llvmDialect, | |||||
{llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), | |||||
llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); | |||||
} | } | ||||
LLVM::LLVMType getFloatType() { return llvmFloatType; } | LLVM::LLVMType getFloatType() { return llvmFloatType; } | ||||
LLVM::LLVMType getVoidType() { return llvmVoidType; } | LLVM::LLVMType getVoidType() { return llvmVoidType; } | ||||
LLVM::LLVMType getPointerType() { return llvmPointerType; } | LLVM::LLVMType getPointerType() { return llvmPointerType; } | ||||
LLVM::LLVMType getInt32Type() { return llvmInt32Type; } | LLVM::LLVMType getInt32Type() { return llvmInt32Type; } | ||||
LLVM::LLVMType getInt64Type() { return llvmInt64Type; } | LLVM::LLVMType getInt64Type() { return llvmInt64Type; } | ||||
LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } | |||||
/// Creates a LLVM global for the given `name`. | /// Creates a LLVM global for the given `name`. | ||||
Value createEntryPointNameConstant(StringRef name, Location loc, | Value createEntryPointNameConstant(StringRef name, Location loc, | ||||
OpBuilder &builder); | OpBuilder &builder); | ||||
/// Declares all needed runtime functions. | /// Declares all needed runtime functions. | ||||
void declareVulkanFunctions(Location loc); | void declareVulkanFunctions(Location loc); | ||||
/// Checks whether the given LLVM::CallOp is a vulkan launch call op. | /// Checks whether the given LLVM::CallOp is a vulkan launch call op. | ||||
bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { | bool isVulkanLaunchCallOp(LLVM::CallOp callOp) { | ||||
return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && | return (callOp.callee() && callOp.callee().getValue() == kVulkanLaunch && | ||||
callOp.getNumOperands() >= 6); | callOp.getNumOperands() >= kNumConfigOps); | ||||
} | |||||
/// Checks whether the given LLVM::CallOp is a "ci_face" vulkan launch call | |||||
/// op. | |||||
bool isCiFaceVulkanLaunchCallOp(LLVM::CallOp callOp) { | |||||
antiagainst: Nit: `isCInterfaceVulkanLaunchCallOp`? Just spell it out to make it clear. Applies to other… | |||||
Thanks, fixed! denis13: Thanks, fixed! | |||||
return (callOp.callee() && | |||||
callOp.callee().getValue() == kCiFaceVulkanLaunch && | |||||
callOp.getNumOperands() >= kNumConfigOps); | |||||
} | } | ||||
/// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan | /// Translates the given `vulkanLaunchCallOp` to the sequence of Vulkan | ||||
/// runtime calls. | /// runtime calls. | ||||
void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); | void translateVulkanLaunchCall(LLVM::CallOp vulkanLaunchCallOp); | ||||
/// Creates call to `bindResource` for each resource operand. | /// Creates call to `bindMemRef` for each memref operand. | ||||
void createBindResourceCalls(LLVM::CallOp vulkanLaunchCallOp, | void createBindMemRefCalls(LLVM::CallOp vulkanLaunchCallOp, | ||||
Value vulkanRuntiem); | Value vulkanRuntime); | ||||
/// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. | |||||
void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); | |||||
public: | public: | ||||
void runOnModule() override; | void runOnModule() override; | ||||
private: | private: | ||||
LLVM::LLVMDialect *llvmDialect; | LLVM::LLVMDialect *llvmDialect; | ||||
LLVM::LLVMType llvmFloatType; | LLVM::LLVMType llvmFloatType; | ||||
LLVM::LLVMType llvmVoidType; | LLVM::LLVMType llvmVoidType; | ||||
LLVM::LLVMType llvmPointerType; | LLVM::LLVMType llvmPointerType; | ||||
LLVM::LLVMType llvmInt32Type; | LLVM::LLVMType llvmInt32Type; | ||||
LLVM::LLVMType llvmInt64Type; | LLVM::LLVMType llvmInt64Type; | ||||
}; | LLVM::LLVMType llvmMemRef1DFloat; | ||||
/// Represents operand adaptor for vulkan launch call operation, to simplify an | |||||
/// access to the lowered memref. | |||||
// TODO: We should use 'emit-c-wrappers' option to lower memref type: | |||||
// https://mlir.llvm.org/docs/ConversionToLLVMDialect/#c-compatible-wrapper-emission. | |||||
struct VulkanLaunchOpOperandAdaptor { | |||||
VulkanLaunchOpOperandAdaptor(ArrayRef<Value> values) { operands = values; } | |||||
VulkanLaunchOpOperandAdaptor(const VulkanLaunchOpOperandAdaptor &) = delete; | |||||
VulkanLaunchOpOperandAdaptor | |||||
operator=(const VulkanLaunchOpOperandAdaptor &) = delete; | |||||
/// Returns a tuple with a pointer to the memory and the size for the index-th | |||||
/// resource. | |||||
std::tuple<Value, Value> getResourceDescriptor1D(uint32_t index) { | |||||
assert(index < getResourceCount1D()); | |||||
// 1D memref calling convention according to "ConversionToLLVMDialect.md": | |||||
// 0. Allocated pointer. | |||||
// 1. Aligned pointer. | |||||
// 2. Offset. | |||||
// 3. Size in dim 0. | |||||
// 4. Stride in dim 0. | |||||
auto offset = numConfigOps + index * loweredMemRefNumOps1D; | |||||
return std::make_tuple(operands[offset], operands[offset + 3]); | |||||
} | |||||
/// Returns the number of resources assuming all operands lowered from | |||||
/// 1D memref. | |||||
uint32_t getResourceCount1D() { | |||||
return (operands.size() - numConfigOps) / loweredMemRefNumOps1D; | |||||
} | |||||
private: | // TODO: Use an associative array to support multiple vulkan launch calls. | ||||
/// The number of operands of lowered 1D memref. | std::pair<StringAttr, StringAttr> spirvAttributes; | ||||
static constexpr const uint32_t loweredMemRefNumOps1D = 5; | static constexpr const uint32_t kNumConfigOps = 6; | ||||
Not Done ReplyInline ActionsI think we can use gpu::LaunchOp::kNumConfigOperands instead of having our own here. antiagainst: I think we can use `gpu::LaunchOp::kNumConfigOperands` instead of having our own here. | |||||
Thanks for pointing on this! Fixed. denis13: Thanks for pointing on this! Fixed. | |||||
/// The number of the first config operands. | |||||
static constexpr const uint32_t numConfigOps = 6; | |||||
ArrayRef<Value> operands; | |||||
}; | }; | ||||
} // anonymous namespace | } // anonymous namespace | ||||
void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { | void VulkanLaunchFuncToVulkanCallsPass::runOnModule() { | ||||
initializeCachedTypes(); | initializeCachedTypes(); | ||||
// Collect SPIRV attributes such as `spirv_blob` and `spirv_entry_point_name`. | |||||
Not Done ReplyInline ActionsNit: SPIR-V antiagainst: Nit: SPIR-V | |||||
Thanks! denis13: Thanks! | |||||
getModule().walk([this](LLVM::CallOp op) { | getModule().walk([this](LLVM::CallOp op) { | ||||
if (isVulkanLaunchCallOp(op)) | if (isVulkanLaunchCallOp(op)) | ||||
collectSPIRVAttributes(op); | |||||
}); | |||||
// Convert vulkan launch call op into a sequence of Vulkan runtime calls. | |||||
getModule().walk([this](LLVM::CallOp op) { | |||||
if (isCiFaceVulkanLaunchCallOp(op)) | |||||
translateVulkanLaunchCall(op); | translateVulkanLaunchCall(op); | ||||
}); | }); | ||||
} | } | ||||
void VulkanLaunchFuncToVulkanCallsPass::createBindResourceCalls( | void VulkanLaunchFuncToVulkanCallsPass::collectSPIRVAttributes( | ||||
LLVM::CallOp vulkanLaunchCallOp, Value vulkanRuntime) { | LLVM::CallOp vulkanLaunchCallOp) { | ||||
if (vulkanLaunchCallOp.getNumOperands() == 6) | // Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes | ||||
// for the given vulkan launch call. | |||||
auto spirvBlobAttr = | |||||
vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); | |||||
if (!spirvBlobAttr) { | |||||
vulkanLaunchCallOp.emitError() | |||||
<< "missing " << kSPIRVBlobAttrName << " attribute"; | |||||
return signalPassFailure(); | |||||
} | |||||
auto spirvEntryPointNameAttr = | |||||
vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); | |||||
if (!spirvEntryPointNameAttr) { | |||||
vulkanLaunchCallOp.emitError() | |||||
<< "missing " << kSPIRVEntryPointAttrName << " attribute"; | |||||
return signalPassFailure(); | |||||
} | |||||
spirvAttributes = std::make_pair(spirvBlobAttr, spirvEntryPointNameAttr); | |||||
} | |||||
void VulkanLaunchFuncToVulkanCallsPass::createBindMemRefCalls( | |||||
LLVM::CallOp ciFaceVulkanLaunchCallOp, Value vulkanRuntime) { | |||||
if (ciFaceVulkanLaunchCallOp.getNumOperands() == kNumConfigOps) | |||||
return; | return; | ||||
OpBuilder builder(vulkanLaunchCallOp); | OpBuilder builder(ciFaceVulkanLaunchCallOp); | ||||
Location loc = vulkanLaunchCallOp.getLoc(); | Location loc = ciFaceVulkanLaunchCallOp.getLoc(); | ||||
// Create LLVM constant for the descriptor set index. | // Create LLVM constant for the descriptor set index. | ||||
// Bind all resources to the `0` descriptor set, the same way as `GPUToSPIRV` | // Bind all memrefs to the `0` descriptor set, the same way as `GPUToSPIRV` | ||||
// pass does. | // pass does. | ||||
Value descriptorSet = builder.create<LLVM::ConstantOp>( | Value descriptorSet = builder.create<LLVM::ConstantOp>( | ||||
loc, getInt32Type(), builder.getI32IntegerAttr(0)); | loc, getInt32Type(), builder.getI32IntegerAttr(0)); | ||||
auto operands = | |||||
Not Done ReplyInline Actionsnit: You shouldn't need to construct a vector here, just use the operand_range directly. You can then change the llvm::drop_begin to operands.drop_front(...). rriddle: nit: You shouldn't need to construct a vector here, just use the operand_range directly. You… | |||||
Not Done ReplyInline ActionsThanks for point on this! denis13: Thanks for point on this! | |||||
SmallVector<Value, 16>{ciFaceVulkanLaunchCallOp.getOperands()}; | |||||
auto operands = SmallVector<Value, 32>{vulkanLaunchCallOp.getOperands()}; | uint32_t operandIdx = 0; | ||||
VulkanLaunchOpOperandAdaptor vkLaunchOperandAdaptor(operands); | for (const auto ptrToMemRefDescriptor : | ||||
Not Done ReplyInline ActionsWhy not use llvm::enumerate here? That would remove the need for the extra index variable. rriddle: Why not use llvm::enumerate here? That would remove the need for the extra index variable. | |||||
Thanks! denis13: Thanks! | |||||
llvm::drop_begin(operands, kNumConfigOps)) { | |||||
for (auto resourceIdx : | |||||
llvm::seq<uint32_t>(0, vkLaunchOperandAdaptor.getResourceCount1D())) { | |||||
// Create LLVM constant for the descriptor binding index. | // Create LLVM constant for the descriptor binding index. | ||||
Value descriptorBinding = builder.create<LLVM::ConstantOp>( | Value descriptorBinding = builder.create<LLVM::ConstantOp>( | ||||
loc, getInt32Type(), builder.getI32IntegerAttr(resourceIdx)); | loc, getInt32Type(), builder.getI32IntegerAttr(operandIdx)); | ||||
// Get a pointer to the memory and size of that memory. | // Create call to `bindMemRef`. | ||||
auto resourceDescriptor = | |||||
vkLaunchOperandAdaptor.getResourceDescriptor1D(resourceIdx); | |||||
// Create call to `bindResource`. | |||||
builder.create<LLVM::CallOp>( | builder.create<LLVM::CallOp>( | ||||
loc, ArrayRef<Type>{getVoidType()}, | loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kBindResource), | // TODO: Add support for memref with other ranks. | ||||
builder.getSymbolRefAttr(kBindMemRef1DFloat), | |||||
ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, | ArrayRef<Value>{vulkanRuntime, descriptorSet, descriptorBinding, | ||||
// Pointer to the memory. | ptrToMemRefDescriptor}); | ||||
std::get<0>(resourceDescriptor), | ++operandIdx; | ||||
// Size of the memory. | |||||
std::get<1>(resourceDescriptor)}); | |||||
} | } | ||||
} | } | ||||
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { | void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { | ||||
ModuleOp module = getModule(); | ModuleOp module = getModule(); | ||||
OpBuilder builder(module.getBody()->getTerminator()); | OpBuilder builder(module.getBody()->getTerminator()); | ||||
if (!module.lookupSymbol(kSetEntryPoint)) { | if (!module.lookupSymbol(kSetEntryPoint)) { | ||||
Show All 23 Lines | void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { | ||||
if (!module.lookupSymbol(kRunOnVulkan)) { | if (!module.lookupSymbol(kRunOnVulkan)) { | ||||
builder.create<LLVM::LLVMFuncOp>( | builder.create<LLVM::LLVMFuncOp>( | ||||
loc, kRunOnVulkan, | loc, kRunOnVulkan, | ||||
LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, | LLVM::LLVMType::getFunctionTy(getVoidType(), {getPointerType()}, | ||||
/*isVarArg=*/false)); | /*isVarArg=*/false)); | ||||
} | } | ||||
if (!module.lookupSymbol(kBindResource)) { | if (!module.lookupSymbol(kBindMemRef1DFloat)) { | ||||
builder.create<LLVM::LLVMFuncOp>( | builder.create<LLVM::LLVMFuncOp>( | ||||
loc, kBindResource, | loc, kBindMemRef1DFloat, | ||||
LLVM::LLVMType::getFunctionTy( | LLVM::LLVMType::getFunctionTy(getVoidType(), | ||||
getVoidType(), | {getPointerType(), getInt32Type(), | ||||
{getPointerType(), getInt32Type(), getInt32Type(), | getInt32Type(), | ||||
getFloatType().getPointerTo(), getInt64Type()}, | getMemRef1DFloat().getPointerTo()}, | ||||
/*isVarArg=*/false)); | /*isVarArg=*/false)); | ||||
} | } | ||||
if (!module.lookupSymbol(kInitVulkan)) { | if (!module.lookupSymbol(kInitVulkan)) { | ||||
builder.create<LLVM::LLVMFuncOp>( | builder.create<LLVM::LLVMFuncOp>( | ||||
loc, kInitVulkan, | loc, kInitVulkan, | ||||
LLVM::LLVMType::getFunctionTy(getPointerType(), {}, | LLVM::LLVMType::getFunctionTy(getPointerType(), {}, | ||||
/*isVarArg=*/false)); | /*isVarArg=*/false)); | ||||
} | } | ||||
Show All 15 Lines | Value VulkanLaunchFuncToVulkanCallsPass::createEntryPointNameConstant( | ||||
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); | std::string entryPointGlobalName = (name + "_spv_entry_point_name").str(); | ||||
return LLVM::createGlobalString(loc, builder, entryPointGlobalName, | return LLVM::createGlobalString(loc, builder, entryPointGlobalName, | ||||
shaderName, LLVM::Linkage::Internal, | shaderName, LLVM::Linkage::Internal, | ||||
getLLVMDialect()); | getLLVMDialect()); | ||||
} | } | ||||
void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( | void VulkanLaunchFuncToVulkanCallsPass::translateVulkanLaunchCall( | ||||
LLVM::CallOp vulkanLaunchCallOp) { | LLVM::CallOp ciFaceVulkanLaunchCallOp) { | ||||
OpBuilder builder(vulkanLaunchCallOp); | OpBuilder builder(ciFaceVulkanLaunchCallOp); | ||||
Location loc = vulkanLaunchCallOp.getLoc(); | Location loc = ciFaceVulkanLaunchCallOp.getLoc(); | ||||
// Check that `kSPIRVBinary` and `kSPIRVEntryPoint` are present in attributes | |||||
// for the given vulkan launch call. | |||||
auto spirvBlobAttr = | |||||
vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVBlobAttrName); | |||||
if (!spirvBlobAttr) { | |||||
vulkanLaunchCallOp.emitError() | |||||
<< "missing " << kSPIRVBlobAttrName << " attribute"; | |||||
return signalPassFailure(); | |||||
} | |||||
auto entryPointNameAttr = | |||||
vulkanLaunchCallOp.getAttrOfType<StringAttr>(kSPIRVEntryPointAttrName); | |||||
if (!entryPointNameAttr) { | |||||
vulkanLaunchCallOp.emitError() | |||||
<< "missing " << kSPIRVEntryPointAttrName << " attribute"; | |||||
return signalPassFailure(); | |||||
} | |||||
// Create call to `initVulkan`. | // Create call to `initVulkan`. | ||||
auto initVulkanCall = builder.create<LLVM::CallOp>( | auto initVulkanCall = builder.create<LLVM::CallOp>( | ||||
loc, ArrayRef<Type>{getPointerType()}, | loc, ArrayRef<Type>{getPointerType()}, | ||||
builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); | builder.getSymbolRefAttr(kInitVulkan), ArrayRef<Value>{}); | ||||
// The result of `initVulkan` function is a pointer to Vulkan runtime, we | // The result of `initVulkan` function is a pointer to Vulkan runtime, we | ||||
// need to pass that pointer to each Vulkan runtime call. | // need to pass that pointer to each Vulkan runtime call. | ||||
auto vulkanRuntime = initVulkanCall.getResult(0); | auto vulkanRuntime = initVulkanCall.getResult(0); | ||||
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with | // Create LLVM global with SPIR-V binary data, so we can pass a pointer with | ||||
// that data to runtime call. | // that data to runtime call. | ||||
Value ptrToSPIRVBinary = LLVM::createGlobalString( | Value ptrToSPIRVBinary = LLVM::createGlobalString( | ||||
loc, builder, kSPIRVBinary, spirvBlobAttr.getValue(), | loc, builder, kSPIRVBinary, spirvAttributes.first.getValue(), | ||||
LLVM::Linkage::Internal, getLLVMDialect()); | LLVM::Linkage::Internal, getLLVMDialect()); | ||||
// Create LLVM constant for the size of SPIR-V binary shader. | // Create LLVM constant for the size of SPIR-V binary shader. | ||||
Value binarySize = builder.create<LLVM::ConstantOp>( | Value binarySize = builder.create<LLVM::ConstantOp>( | ||||
loc, getInt32Type(), | loc, getInt32Type(), | ||||
builder.getI32IntegerAttr(spirvBlobAttr.getValue().size())); | builder.getI32IntegerAttr(spirvAttributes.first.getValue().size())); | ||||
// Create call to `bindResource` for each resource operand. | // Create call to `bindMemRef` for each memref operand. | ||||
createBindResourceCalls(vulkanLaunchCallOp, vulkanRuntime); | createBindMemRefCalls(ciFaceVulkanLaunchCallOp, vulkanRuntime); | ||||
// Create call to `setBinaryShader` runtime function with the given pointer to | // Create call to `setBinaryShader` runtime function with the given pointer to | ||||
// SPIR-V binary and binary size. | // SPIR-V binary and binary size. | ||||
builder.create<LLVM::CallOp>( | builder.create<LLVM::CallOp>( | ||||
loc, ArrayRef<Type>{getVoidType()}, | loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kSetBinaryShader), | builder.getSymbolRefAttr(kSetBinaryShader), | ||||
ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); | ArrayRef<Value>{vulkanRuntime, ptrToSPIRVBinary, binarySize}); | ||||
// Create LLVM global with entry point name. | // Create LLVM global with entry point name. | ||||
Value entryPointName = | Value entryPointName = createEntryPointNameConstant( | ||||
createEntryPointNameConstant(entryPointNameAttr.getValue(), loc, builder); | spirvAttributes.second.getValue(), loc, builder); | ||||
// Create call to `setEntryPoint` runtime function with the given pointer to | // Create call to `setEntryPoint` runtime function with the given pointer to | ||||
// entry point name. | // entry point name. | ||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kSetEntryPoint), | builder.getSymbolRefAttr(kSetEntryPoint), | ||||
ArrayRef<Value>{vulkanRuntime, entryPointName}); | ArrayRef<Value>{vulkanRuntime, entryPointName}); | ||||
// Create number of local workgroup for each dimension. | // Create number of local workgroup for each dimension. | ||||
builder.create<LLVM::CallOp>( | builder.create<LLVM::CallOp>( | ||||
loc, ArrayRef<Type>{getVoidType()}, | loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kSetNumWorkGroups), | builder.getSymbolRefAttr(kSetNumWorkGroups), | ||||
ArrayRef<Value>{vulkanRuntime, vulkanLaunchCallOp.getOperand(0), | ArrayRef<Value>{vulkanRuntime, ciFaceVulkanLaunchCallOp.getOperand(0), | ||||
vulkanLaunchCallOp.getOperand(1), | ciFaceVulkanLaunchCallOp.getOperand(1), | ||||
vulkanLaunchCallOp.getOperand(2)}); | ciFaceVulkanLaunchCallOp.getOperand(2)}); | ||||
// Create call to `runOnVulkan` runtime function. | // Create call to `runOnVulkan` runtime function. | ||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kRunOnVulkan), | builder.getSymbolRefAttr(kRunOnVulkan), | ||||
ArrayRef<Value>{vulkanRuntime}); | ArrayRef<Value>{vulkanRuntime}); | ||||
// Create call to 'deinitVulkan' runtime function. | // Create call to 'deinitVulkan' runtime function. | ||||
builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | builder.create<LLVM::CallOp>(loc, ArrayRef<Type>{getVoidType()}, | ||||
builder.getSymbolRefAttr(kDeinitVulkan), | builder.getSymbolRefAttr(kDeinitVulkan), | ||||
ArrayRef<Value>{vulkanRuntime}); | ArrayRef<Value>{vulkanRuntime}); | ||||
// Declare runtime functions. | // Declare runtime functions. | ||||
declareVulkanFunctions(loc); | declareVulkanFunctions(loc); | ||||
vulkanLaunchCallOp.erase(); | ciFaceVulkanLaunchCallOp.erase(); | ||||
} | } | ||||
std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> | std::unique_ptr<mlir::OpPassBase<mlir::ModuleOp>> | ||||
mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { | mlir::createConvertVulkanLaunchFuncToVulkanCallsPass() { | ||||
return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); | return std::make_unique<VulkanLaunchFuncToVulkanCallsPass>(); | ||||
} | } | ||||
static PassRegistration<VulkanLaunchFuncToVulkanCallsPass> | static PassRegistration<VulkanLaunchFuncToVulkanCallsPass> | ||||
pass("launch-func-to-vulkan", | pass("launch-func-to-vulkan", | ||||
"Convert vulkanLaunch external call to Vulkan runtime external calls"); | "Convert vulkanLaunch external call to Vulkan runtime external calls"); |
Nit: isCInterfaceVulkanLaunchCallOp? Just spell it out to make it clear. Applies to other parameter names, etc.