diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -55,7 +55,8 @@ bool isSupportedType(Type type) { // TODO(denis0x0D): Handle other types. if (auto memRefType = type.dyn_cast_or_null()) - return memRefType.hasRank() && memRefType.getRank() == 1; + return memRefType.hasRank() && + (memRefType.getRank() == 1 || memRefType.getRank() == 2); return false; } @@ -98,7 +99,8 @@ // Check that all operands have supported types except those for the launch // configuration. - for (auto type : llvm::drop_begin(vulkanLaunchTypes, 6)) { + for (auto type : + llvm::drop_begin(vulkanLaunchTypes, gpu::LaunchOp::kNumConfigOperands)) { if (!isSupportedType(type)) return launchOp.emitError() << type << " is unsupported to run on Vulkan"; } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -24,10 +24,12 @@ #include "mlir/Pass/Pass.h" #include "llvm/ADT/SmallString.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; static constexpr const char *kBindMemRef1DFloat = "bindMemRef1DFloat"; +static constexpr const char *kBindMemRef2DFloat = "bindMemRef2DFloat"; static constexpr const char *kCInterfaceVulkanLaunch = "_mlir_ciface_vulkanLaunch"; static constexpr const char *kDeinitVulkan = "deinitVulkan"; @@ -87,12 +89,20 @@ auto llvmPtrToFloatType = getFloatType().getPointerTo(); auto llvmArrayOneElementSizeType = LLVM::LLVMType::getArrayTy(getInt64Type(), 1); + auto llvmArrayTwoElementSizeType = + LLVM::LLVMType::getArrayTy(getInt64Type(), 2); // Create a type 
`!llvm<"{ float*, float*, i64, [1 x i64], [1 x i64]}">`. llvmMemRef1DFloat = LLVM::LLVMType::getStructTy( llvmDialect, {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), llvmArrayOneElementSizeType, llvmArrayOneElementSizeType}); + + // Create a type `!llvm<"{ float*, float*, i64, [2 x i64], [2 x i64]}">`. + llvmMemRef2DFloat = LLVM::LLVMType::getStructTy( + llvmDialect, + {llvmPtrToFloatType, llvmPtrToFloatType, getInt64Type(), + llvmArrayTwoElementSizeType, llvmArrayTwoElementSizeType}); } LLVM::LLVMType getFloatType() { return llvmFloatType; } @@ -101,6 +111,7 @@ LLVM::LLVMType getInt32Type() { return llvmInt32Type; } LLVM::LLVMType getInt64Type() { return llvmInt64Type; } LLVM::LLVMType getMemRef1DFloat() { return llvmMemRef1DFloat; } + LLVM::LLVMType getMemRef2DFloat() { return llvmMemRef2DFloat; } /// Creates a LLVM global for the given `name`. Value createEntryPointNameConstant(StringRef name, Location loc, @@ -134,6 +145,9 @@ /// Collects SPIRV attributes from the given `vulkanLaunchCallOp`. void collectSPIRVAttributes(LLVM::CallOp vulkanLaunchCallOp); + /// Deduces a rank from the given `ptrToMemRefDescriptor`. + LogicalResult deduceMemRefRank(Value ptrToMemRefDescriptor, uint32_t &rank); + public: void runOnModule() override; @@ -145,6 +159,7 @@ LLVM::LLVMType llvmInt32Type; LLVM::LLVMType llvmInt64Type; LLVM::LLVMType llvmMemRef1DFloat; + LLVM::LLVMType llvmMemRef2DFloat; // TODO: Use an associative array to support multiple vulkan launch calls. std::pair spirvAttributes; @@ -212,16 +227,54 @@ // Create LLVM constant for the descriptor binding index. 
Value descriptorBinding = builder.create( loc, getInt32Type(), builder.getI32IntegerAttr(en.index())); + + auto ptrToMemRefDescriptor = en.value(); + uint32_t rank = 0; + if (failed(deduceMemRefRank(ptrToMemRefDescriptor, rank))) { + cInterfaceVulkanLaunchCallOp.emitError() + << "invalid memref descriptor " << ptrToMemRefDescriptor.getType(); + return signalPassFailure(); + } + + auto symbolName = llvm::formatv("bindMemRef{0}DFloat", rank).str(); // Create call to `bindMemRef`. builder.create( loc, ArrayRef{getVoidType()}, - // TODO: Add support for memref with other ranks. - builder.getSymbolRefAttr(kBindMemRef1DFloat), + builder.getSymbolRefAttr( + StringRef(symbolName.data(), symbolName.size())), ArrayRef{vulkanRuntime, descriptorSet, descriptorBinding, - en.value()}); + ptrToMemRefDescriptor}); } } +LogicalResult +VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRank(Value ptrToMemRefDescriptor, + uint32_t &rank) { + auto llvmPtrDescriptorTy = + ptrToMemRefDescriptor.getType().dyn_cast(); + if (!llvmPtrDescriptorTy) + return failure(); + + auto llvmDescriptorTy = llvmPtrDescriptorTy.getPointerElementTy(); + // template <typename Elem, size_t Rank> + // struct { + // Elem *allocated; + // Elem *aligned; + // int64_t offset; + // int64_t sizes[Rank]; // omitted when rank == 0 + // int64_t strides[Rank]; // omitted when rank == 0 + // }; + if (!llvmDescriptorTy || !llvmDescriptorTy.isStructTy()) + return failure(); + if (llvmDescriptorTy.getStructNumElements() == 3) { + rank = 0; + return success(); + } + + rank = llvmDescriptorTy.getStructElementType(3).getArrayNumElements(); + return success(); +} + void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) { ModuleOp module = getModule(); OpBuilder builder(module.getBody()->getTerminator()); @@ -268,6 +321,16 @@ /*isVarArg=*/false)); } + if (!module.lookupSymbol(kBindMemRef2DFloat)) { + builder.create( + loc, kBindMemRef2DFloat, + LLVM::LLVMType::getFunctionTy(getVoidType(), + {getPointerType(), getInt32Type(), + 
getInt32Type(), + getMemRef2DFloat().getPointerTo()}, + /*isVarArg=*/false)); + } + if (!module.lookupSymbol(kInitVulkan)) { builder.create( loc, kInitVulkan, diff --git a/mlir/test/mlir-vulkan-runner/mulf.mlir b/mlir/test/mlir-vulkan-runner/mulf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-vulkan-runner/mulf.mlir @@ -0,0 +1,52 @@ +// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +// CHECK-COUNT-4: [6, 6, 6, 6] +module attributes { + gpu.container_module, + spv.target_env = #spv.target_env< + #spv.vce, + {max_compute_workgroup_invocations = 128 : i32, + max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> +} { + gpu.module @kernels { + gpu.func @kernel_mul(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>, %arg2 : memref<4x4xf32>) + attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} { + %x = "gpu.block_id"() {dimension = "x"} : () -> index + %y = "gpu.block_id"() {dimension = "y"} : () -> index + %1 = load %arg0[%x, %y] : memref<4x4xf32> + %2 = load %arg1[%x, %y] : memref<4x4xf32> + %3 = mulf %1, %2 : f32 + store %3, %arg2[%x, %y] : memref<4x4xf32> + gpu.return + } + } + + func @main() { + %arg0 = alloc() : memref<4x4xf32> + %arg1 = alloc() : memref<4x4xf32> + %arg2 = alloc() : memref<4x4xf32> + %0 = constant 0 : i32 + %1 = constant 1 : i32 + %2 = constant 2 : i32 + %value0 = constant 0.0 : f32 + %value1 = constant 2.0 : f32 + %value2 = constant 3.0 : f32 + %arg3 = memref_cast %arg0 : memref<4x4xf32> to memref + %arg4 = memref_cast %arg1 : memref<4x4xf32> to memref + %arg5 = memref_cast %arg2 : memref<4x4xf32> to memref + call @fillResource2DFloat(%arg3, %value1) : (memref, f32) -> () + call @fillResource2DFloat(%arg4, %value2) : (memref, f32) -> () + call @fillResource2DFloat(%arg5, %value0) : (memref, f32) -> () + + %cst1 = constant 1 : 
index + %cst4 = constant 4 : index + "gpu.launch_func"(%cst4, %cst4, %cst1, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = "kernel_mul", kernel_module = @kernels } + : (index, index, index, index, index, index, memref<4x4xf32>, memref<4x4xf32>, memref<4x4xf32>) -> () + %arg6 = memref_cast %arg5 : memref to memref<*xf32> + call @print_memref_f32(%arg6) : (memref<*xf32>) -> () + return + } + func @fillResource2DFloat(%0 : memref, %1 : f32) + func @print_memref_f32(%ptr : memref<*xf32>) +} + diff --git a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp --- a/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp +++ b/mlir/tools/mlir-vulkan-runner/vulkan-runtime-wrappers.cpp @@ -111,9 +111,27 @@ ->setResourceData(setIndex, bindIndex, memBuffer); } +/// Binds the given 2D float memref to the given descriptor set and descriptor +/// index. +void bindMemRef2DFloat(void *vkRuntimeManager, DescriptorSetIndex setIndex, + BindingIndex bindIndex, + MemRefDescriptor *ptr) { + VulkanHostMemoryBuffer memBuffer{ + ptr->allocated, + static_cast(ptr->sizes[0] * ptr->sizes[1] * sizeof(float))}; + reinterpret_cast(vkRuntimeManager) + ->setResourceData(setIndex, bindIndex, memBuffer); +} + /// Fills the given 1D float memref with the given float value. void _mlir_ciface_fillResource1DFloat(MemRefDescriptor *ptr, // NOLINT float value) { std::fill_n(ptr->allocated, ptr->sizes[0], value); } + +/// Fills the given 2D float memref with the given float value. +void _mlir_ciface_fillResource2DFloat(MemRefDescriptor *ptr, // NOLINT + float value) { + std::fill_n(ptr->allocated, ptr->sizes[0] * ptr->sizes[1], value); +} }