diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/CMakeLists.txt
@@ -6,8 +6,10 @@
   LINK_LIBS PUBLIC
   MLIRIR
-  MLIRGPUDialect
   MLIRLLVMDialect
   MLIRSupport
   MLIRTargetLLVMIRExport
+
+  PRIVATE
+  MLIRGPUDialect
   )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.cpp
@@ -12,10 +12,25 @@
 #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
 namespace {
+LogicalResult launchKernel(gpu::LaunchFuncOp launchOp,
+                           llvm::IRBuilderBase &builder,
+                           LLVM::ModuleTranslation &moduleTranslation) {
+  auto kernelBinary = SymbolTable::lookupNearestSymbolFrom<gpu::BinaryOp>(
+      launchOp, launchOp.getKernelModuleName());
+  if (!kernelBinary) {
+    launchOp.emitError("Couldn't find the binary holding the kernel: ")
+        << launchOp.getKernelModuleName();
+    return failure();
+  }
+  return dyn_cast<gpu::ObjectManagerAttrInterface>(
+             kernelBinary.getObjectManagerAttr())
+      .launchKernel(launchOp, kernelBinary, builder, moduleTranslation);
+}
 
 class GPUDialectLLVMIRTranslationInterface
     : public LLVMTranslationDialectInterface {
@@ -23,9 +38,21 @@
   using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
 
   LogicalResult
-  convertOperation(Operation *op, llvm::IRBuilderBase &builder,
+  convertOperation(Operation *operation, llvm::IRBuilderBase &builder,
                    LLVM::ModuleTranslation &moduleTranslation) const override {
-    return isa<gpu::GPUModuleOp>(op) ? success() : failure();
+    return llvm::TypeSwitch<Operation *, LogicalResult>(operation)
+        .Case([&](gpu::GPUModuleOp) { return success(); })
+        .Case([&](gpu::BinaryOp op) {
+          return dyn_cast<gpu::ObjectManagerAttrInterface>(
+                     op.getObjectManagerAttr())
+              .embedBinary(op, builder, moduleTranslation);
+        })
+        .Case([&](gpu::LaunchFuncOp op) {
+          return launchKernel(op, builder, moduleTranslation);
+        })
+        .Default([&](Operation *op) {
+          return op->emitError("unsupported GPU operation: ") << op->getName();
+        });
   }
 };
diff --git a/mlir/test/Target/LLVMIR/gpu.mlir b/mlir/test/Target/LLVMIR/gpu.mlir
new file
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/gpu.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+module attributes {gpu.container_module} {
+  // CHECK: [[STRUCTTY:%.*]] = type { i32, i32 }
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  // CHECK: @kernel_module_kernel_kernel_name = private unnamed_addr constant [7 x i8] c"kernel\00", align 1
+  gpu.binary @kernel_module [#gpu.object<#gpu.nvptx, "BLOB">]
+  llvm.func @foo() {
+    // CHECK: [[STRUCT:%.*]] = alloca %{{.*}}, align 8
+    // CHECK: [[ARRAY:%.*]] = alloca ptr, i64 2, align 8
+    // CHECK: [[ARG0:%.*]] = getelementptr inbounds [[STRUCTTY]], ptr [[STRUCT]], i32 0, i32 0
+    // CHECK: store i32 32, ptr [[ARG0]], align 4
+    // CHECK: %{{.*}} = getelementptr ptr, ptr [[ARRAY]], i32 0
+    // CHECK: store ptr [[ARG0]], ptr %{{.*}}, align 8
+    // CHECK: [[ARG1:%.*]] = getelementptr inbounds [[STRUCTTY]], ptr [[STRUCT]], i32 0, i32 1
+    // CHECK: store i32 32, ptr [[ARG1]], align 4
+    // CHECK: %{{.*}} = getelementptr ptr, ptr [[ARRAY]], i32 1
+    // CHECK: store ptr [[ARG1]], ptr %{{.*}}, align 8
+    // CHECK: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+    // CHECK: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
+    // CHECK: [[STREAM:%.*]] = call ptr @mgpuStreamCreate()
+    // CHECK: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 256, ptr [[STREAM]], ptr [[ARRAY]], ptr null)
+    // CHECK: call void @mgpuStreamSynchronize(ptr [[STREAM]])
+    // CHECK: call void @mgpuStreamDestroy(ptr [[STREAM]])
+    // CHECK: call void @mgpuModuleUnload(ptr [[MODULE]])
+    %0 = llvm.mlir.constant(8 : index) : i64
+    %1 = llvm.mlir.constant(32 : i32) : i32
+    %2 = llvm.mlir.constant(256 : i32) : i32
+    gpu.launch_func @kernel_module::@kernel blocks in (%0, %0, %0) : i64 threads in (%0, %0, %0) : i64 dynamic_shared_memory_size %2 args(%1 : i32, %1 : i32)
+    llvm.return
+  }
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [1 x i8] c"1", align 8
+  gpu.binary @kernel_module <#gpu.select_object<1>> [#gpu.object<#gpu.nvptx, "0">, #gpu.object<#gpu.nvptx, "1">]
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [6 x i8] c"AMDGPU", align 8
+  gpu.binary @kernel_module <#gpu.select_object<#gpu.amdgpu>> [#gpu.object<#gpu.nvptx, "NVPTX">, #gpu.object<#gpu.amdgpu, "AMDGPU">]
+}
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: @kernel_module_bin_cst = internal constant [4 x i8] c"BLOB", align 8
+  gpu.binary @kernel_module [#gpu.object<#gpu.amdgpu, "BLOB">]
+  llvm.func @foo() {
+    %0 = llvm.mlir.constant(8 : index) : i64
+    // CHECK: = call ptr @mgpuStreamCreate()
+    // CHECK-NEXT: = alloca {{.*}}, align 8
+    // CHECK-NEXT: [[ARGS:%.*]] = alloca ptr, i64 0, align 8
+    // CHECK-NEXT: [[MODULE:%.*]] = call ptr @mgpuModuleLoad(ptr @kernel_module_bin_cst)
+    // CHECK-NEXT: [[FUNC:%.*]] = call ptr @mgpuModuleGetFunction(ptr [[MODULE]], ptr @kernel_module_kernel_kernel_name)
+    // CHECK-NEXT: call void @mgpuLaunchKernel(ptr [[FUNC]], i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i32 0, ptr {{.*}}, ptr [[ARGS]], ptr null)
+    // CHECK-NEXT: call void @mgpuModuleUnload(ptr [[MODULE]])
+    // CHECK-NEXT: call void @mgpuStreamSynchronize(ptr %{{.*}})
+    // CHECK-NEXT: call void @mgpuStreamDestroy(ptr %{{.*}})
+    %1 = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr
+    gpu.launch_func <%1 !llvm.ptr> @kernel_module::@kernel blocks in (%0, %0, %0) : i64 threads in (%0, %0, %0) : i64
+    llvm.call @mgpuStreamSynchronize(%1) : (!llvm.ptr) -> ()
+    llvm.call @mgpuStreamDestroy(%1) : (!llvm.ptr) -> ()
+    llvm.return
+  }
+  llvm.func @mgpuStreamCreate() -> !llvm.ptr
+  llvm.func @mgpuStreamSynchronize(!llvm.ptr)
+  llvm.func @mgpuStreamDestroy(!llvm.ptr)
+}
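
For context on how this translation is exercised outside of `mlir-translate -mlir-to-llvmir` (which already registers the interface): the new `gpu.binary` / `gpu.launch_func` cases only run if the GPU dialect translation interface is attached to the `MLIRContext` before export. Below is a minimal sketch of such a driver; the helper name `translateToLLVM` is an illustrative assumption, while `registerGPUDialectTranslation` (declared in the `GPUToLLVMIRTranslation.h` header touched by this patch) and `translateModuleToLLVMIR` are existing MLIR entry points.

// Sketch of a standalone driver (names are illustrative, not part of the
// patch): register the per-dialect translation interfaces, including the GPU
// one extended above, then export the module to LLVM IR.
#include <memory>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/IR/Module.h"

std::unique_ptr<llvm::Module> translateToLLVM(mlir::ModuleOp module,
                                              llvm::LLVMContext &llvmCtx) {
  // Attach the translation interfaces to the module's context; without the
  // GPU registration, convertOperation above is never reached and the
  // gpu.binary / gpu.launch_func ops fail to translate.
  mlir::registerBuiltinDialectTranslation(*module.getContext());
  mlir::registerLLVMDialectTranslation(*module.getContext());
  mlir::registerGPUDialectTranslation(*module.getContext());

  // During this call, gpu.binary ops are embedded and gpu.launch_func ops are
  // lowered to the runtime calls checked in the test above (mgpuModuleLoad,
  // mgpuLaunchKernel, ...).
  return mlir::translateModuleToLLVMIR(module, llvmCtx);
}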