diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
--- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
+++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h
@@ -64,10 +64,10 @@
 
 /// Collect a set of patterns to convert from the GPU dialect to LLVM and
 /// populate converter for gpu types.
-void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-                                         RewritePatternSet &patterns,
-                                         StringRef gpuBinaryAnnotation = {},
-                                         bool kernelBarePtrCallConv = false);
+void populateGpuToLLVMConversionPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns,
+    StringRef gpuBinaryAnnotation = {}, bool kernelBarePtrCallConv = false,
+    SymbolTable *cachedModuleTable = nullptr);
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -410,10 +410,12 @@
 public:
   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                              StringRef gpuBinaryAnnotation,
-                                             bool kernelBarePtrCallConv)
+                                             bool kernelBarePtrCallConv,
+                                             SymbolTable *cachedModuleTable)
       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
         gpuBinaryAnnotation(gpuBinaryAnnotation),
-        kernelBarePtrCallConv(kernelBarePtrCallConv) {}
+        kernelBarePtrCallConv(kernelBarePtrCallConv),
+        cachedModuleTable(cachedModuleTable) {}
 
 private:
   Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
@@ -427,6 +429,7 @@
 
   llvm::SmallString<32> gpuBinaryAnnotation;
   bool kernelBarePtrCallConv;
+  SymbolTable *cachedModuleTable;
 };
 
 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
@@ -663,7 +666,23 @@
   RewritePatternSet patterns(&getContext());
   LLVMConversionTarget target(getContext());
 
-  target.addIllegalDialect<gpu::GPUDialect>();
+  SymbolTable symbolTable = SymbolTable(getOperation());
+  // Preserve GPU modules if they have target attributes.
+  target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
+      [](gpu::GPUModuleOp module) -> bool {
+        return module.getTargetsAttr() != nullptr;
+      });
+  // Accept as legal LaunchFuncOps if they refer to GPU Modules with targets and
+  // the operands have been lowered.
+  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
+      [&](gpu::LaunchFuncOp op) -> bool {
+        auto module =
+            symbolTable.lookup<gpu::GPUModuleOp>(op.getKernelModuleName());
+        return converter.isLegal(op->getOperandTypes()) &&
+               converter.isLegal(op->getResultTypes()) &&
+               (module && module.getTargetsAttr() &&
+                module.getTargetsAttr().size());
+      });
 
   mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
   mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
@@ -673,7 +692,7 @@
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
-                                      kernelBarePtrCallConv);
+                                      kernelBarePtrCallConv, &symbolTable);
 
   if (failed(
           applyPartialConversion(getOperation(), target, std::move(patterns))))
@@ -1161,10 +1180,58 @@
 
   // Create an LLVM global with CUBIN extracted from the kernel annotation and
   // obtain a pointer to the first byte in it.
-  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
-      launchOp, launchOp.getKernelModuleName());
+  gpu::GPUModuleOp kernelModule;
+  if (cachedModuleTable)
+    kernelModule = cachedModuleTable->lookup<gpu::GPUModuleOp>(
+        launchOp.getKernelModuleName());
+  else
+    kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
+        launchOp, launchOp.getKernelModuleName());
   assert(kernelModule && "expected a kernel module");
 
+  // If the module has Targets then just update the op operands.
+  if (ArrayAttr targets = kernelModule.getTargetsAttr()) {
+    Value stream = Value();
+    if (adaptor.getAsyncDependencies().size())
+      stream = adaptor.getAsyncDependencies().front();
+    // If the async keyword is present and there are no dependencies, then a
+    // stream must be created to pass to subsequent operations.
+    else if (launchOp.getAsyncToken())
+      stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult();
+
+    // Lower the kernel operands to match kernel parameters.
+    SmallVector<Value, 4> arguments;
+    if (kernelBarePtrCallConv) {
+      // Hack the bare pointer value on just for the argument promotion
+      LLVMTypeConverter *converter = getTypeConverter();
+      LowerToLLVMOptions options = converter->getOptions();
+      LowerToLLVMOptions overrideToMatchKernelOpts = options;
+      overrideToMatchKernelOpts.useBarePtrCallConv = true;
+      converter->dangerousSetOptions(overrideToMatchKernelOpts);
+      arguments =
+          converter->promoteOperands(loc, launchOp.getKernelOperands(),
+                                     adaptor.getKernelOperands(), rewriter);
+      converter->dangerousSetOptions(options);
+    } else {
+      arguments = getTypeConverter()->promoteOperands(
+          loc, launchOp.getKernelOperands(), adaptor.getKernelOperands(),
+          rewriter);
+    }
+
+    rewriter.create<gpu::LaunchFuncOp>(
+        launchOp.getLoc(), launchOp.getKernelAttr(),
+        gpu::KernelDim3{adaptor.getGridSizeX(), adaptor.getGridSizeY(),
+                        adaptor.getGridSizeZ()},
+        gpu::KernelDim3{adaptor.getBlockSizeX(), adaptor.getBlockSizeY(),
+                        adaptor.getBlockSizeZ()},
+        adaptor.getDynamicSharedMemorySize(), arguments, stream);
+    if (launchOp.getAsyncToken())
+      rewriter.replaceOp(launchOp, {stream});
+    else
+      rewriter.eraseOp(launchOp);
+    return success();
+  }
+
   auto binaryAttr =
       kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
   if (!binaryAttr) {
@@ -1775,7 +1842,8 @@
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
-                                               bool kernelBarePtrCallConv) {
+                                               bool kernelBarePtrCallConv,
+                                               SymbolTable *cachedModuleTable) {
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
@@ -1804,6 +1872,6 @@
                ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
                ConvertSDDMMOpToGpuRuntimeCallPattern>(converter);
   patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
-      converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
+      converter, gpuBinaryAnnotation, kernelBarePtrCallConv, cachedModuleTable);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
 }
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-bare-ptr.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s --gpu-to-llvm="use-bare-pointers-for-kernels=1" -split-input-file | FileCheck %s
+
+module attributes {gpu.container_module} {
+  gpu.module @kernels [#nvvm.target] {
+    llvm.func @kernel_1(%arg0: f32, %arg1: !llvm.ptr<1>) attributes {gpu.kernel, nvvm.kernel} {
+      %0 = llvm.mlir.undef : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      %1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      %3 = llvm.mlir.constant(0 : index) : i64
+      %4 = llvm.insertvalue %3, %2[2] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      %5 = llvm.mlir.constant(10 : index) : i64
+      %6 = llvm.insertvalue %5, %4[3, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      %7 = llvm.mlir.constant(1 : index) : i64
+      %8 = llvm.insertvalue %7, %6[4, 0] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+      llvm.return
+    }
+  }
+  func.func @foo() {
+    // CHECK: [[MEMREF:%.*]] = gpu.alloc () : memref<10xf32, 1>
+    // CHECK: [[DESCRIPTOR:%.*]] = builtin.unrealized_conversion_cast [[MEMREF]] : memref<10xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: [[PTR:%.*]] = llvm.extractvalue [[DESCRIPTOR]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<1 x i64>, array<1 x i64>)>
+    // CHECK: gpu.launch_func @kernels::@kernel_1 blocks in ({{.*}}) threads in ({{.*}}) : i64
+    // CHECK: args(%{{.*}} : f32, [[PTR]] : !llvm.ptr<1>)
+    %0 = arith.constant 0. : f32
+    %1 = gpu.alloc () : memref<10xf32, 1>
+    %c8 = arith.constant 8 : index
+    gpu.launch_func @kernels::@kernel_1 blocks in (%c8, %c8, %c8) threads in (%c8, %c8, %c8) args(%0 : f32, %1 : memref<10xf32, 1>)
+    return
+  }
+}
diff --git a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
--- a/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-launch-func-to-gpu-runtime-calls.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=nvvm.cubin use-opaque-pointers=1" | FileCheck %s
-// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco use-opaque-pointers=1" | FileCheck %s --check-prefix=ROCDL
+// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=nvvm.cubin use-opaque-pointers=1" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco use-opaque-pointers=1" -split-input-file | FileCheck %s --check-prefix=ROCDL
 
 module attributes {gpu.container_module} {
 
@@ -61,3 +61,37 @@
 // CHECK: llvm.call @mgpuStreamDestroy
 // CHECK: llvm.call @mgpuModuleUnload
 }
+
+// -----
+
+module attributes {gpu.container_module} {
+  // CHECK: gpu.module
+  // ROCDL: gpu.module
+  gpu.module @kernel_module [#nvvm.target] {
+    llvm.func @kernel(%arg0: i32, %arg1: !llvm.ptr,
+                      %arg2: !llvm.ptr, %arg3: i64, %arg4: i64,
+                      %arg5: i64) attributes {gpu.kernel} {
+      llvm.return
+    }
+  }
+
+  func.func @foo(%buffer: memref<?xf32>) {
+    // CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : index) : i64
+    // CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
+    // CHECK: [[C256:%.*]] = llvm.mlir.constant(256 : i32) : i32
+    %c8 = arith.constant 8 : index
+    %c32 = arith.constant 32 : i32
+    %c256 = arith.constant 256 : i32
+
+    // CHECK: gpu.launch_func @kernel_module::@kernel
+    // CHECK: blocks in ([[C8]], [[C8]], [[C8]]) threads in ([[C8]], [[C8]], [[C8]]) : i64
+    // CHECK: dynamic_shared_memory_size [[C256]]
+    // CHECK: args([[C32]] : i32, %{{.*}} : !llvm.ptr, %{{.*}} : !llvm.ptr, %{{.*}} : i64, %{{.*}} : i64, %{{.*}} : i64)
+    gpu.launch_func @kernel_module::@kernel
+        blocks in (%c8, %c8, %c8)
+        threads in (%c8, %c8, %c8)
+        dynamic_shared_memory_size %c256
+        args(%c32 : i32, %buffer : memref<?xf32>)
+    return
+  }
+}