diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h @@ -67,6 +67,16 @@ int addressSpace; }; +/// Lowering of gpu.printf to a call to the vprintf standard library function. +struct GPUPrintfOpToVPrintfLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + struct GPUReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -172,7 +172,17 @@ return success(); } -static const char formatStringPrefix[] = "printfFormat_"; +static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) { + const char formatStringPrefix[] = "printfFormat_"; + // Probe printfFormat_0, printfFormat_1, ... until a symbol name is unused. + unsigned stringNumber = 0; + SmallString<16> stringConstName; + do { + stringConstName.clear(); + (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); + } while (moduleOp.lookupSymbol(stringConstName)); + return stringConstName; +} template static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, @@ -225,13 +235,8 @@ auto printfBeginCall = rewriter.create(loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(); - // Create a global constant for the format string - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + // Get a unique global name for the format. 
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); llvm::SmallString<20> formatString(adaptor.getFormat()); formatString.push_back('\0'); // Null terminate for C @@ -320,13 +325,8 @@ LLVM::LLVMFuncOp printfDecl = getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); - // Create a global constant for the format string - unsigned stringNumber = 0; - SmallString<16> stringConstName; - do { - stringConstName.clear(); - (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); - } while (moduleOp.lookupSymbol(stringConstName)); + // Get a unique global name for the format. + SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); llvm::SmallString<20> formatString(adaptor.getFormat()); formatString.push_back('\0'); // Null terminate for C @@ -359,6 +359,80 @@ return success(); } +LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite( + gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Location loc = gpuPrintfOp->getLoc(); + + mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); + mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); + + // Note: this is the GPUModule op, not the ModuleOp that surrounds it. + // This ensures that global constants and declarations are placed within + // the device code, not the host code. + auto moduleOp = gpuPrintfOp->getParentOfType(); + + auto vprintfType = + LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr, i8Ptr}); + LLVM::LLVMFuncOp vprintfDecl = + getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType); + + // Get a unique global name for the format. 
+ SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp); + + llvm::SmallString<20> formatString(adaptor.getFormat()); + formatString.push_back('\0'); // Null terminate for C + auto globalType = + LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); + LLVM::GlobalOp global; + { + ConversionPatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(moduleOp.getBody()); + global = rewriter.create( + loc, globalType, + /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, + rewriter.getStringAttr(formatString), /*alignment=*/0); + } + + // Get a pointer to the format string's first element. + Value globalPtr = rewriter.create(loc, global); + Value stringStart = rewriter.create( + loc, i8Ptr, globalPtr, ArrayRef{0, 0}); + SmallVector types; + SmallVector args; + // Promote floats to f64 and pack the arguments into a stack allocation. + for (Value arg : adaptor.getArgs()) { + Type type = arg.getType(); + Value promotedArg = arg; + assert(type.isIntOrFloat()); + if (type.isa()) { + type = rewriter.getF64Type(); + promotedArg = rewriter.create(loc, type, arg); + } + types.push_back(type); + args.push_back(promotedArg); + } + Type structType = + LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types); + Type structPtrType = LLVM::LLVMPointerType::get(structType); + Value one = rewriter.create(loc, rewriter.getI64Type(), + rewriter.getIndexAttr(1)); + Value tempAlloc = rewriter.create(loc, structPtrType, one, + /*alignment=*/0); + for (auto [index, arg] : llvm::enumerate(args)) { + Value ptr = rewriter.create( + loc, LLVM::LLVMPointerType::get(arg.getType()), tempAlloc, + ArrayRef{0, index}); + rewriter.create(loc, arg, ptr); + } + tempAlloc = rewriter.create(loc, i8Ptr, tempAlloc); + std::array printfArgs = {stringStart, tempAlloc}; + + rewriter.create(loc, vprintfDecl, printfArgs); + rewriter.eraseOp(gpuPrintfOp); + return success(); +} + /// Unrolls op if it's operating on vectors. 
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -239,6 +239,7 @@ void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { populateWithGenerated(patterns); + patterns.add(converter); patterns .add, diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -501,3 +501,42 @@ gpu.return } } + +// ----- + +gpu.module @test_module { + // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL0:[A-Za-z0-9_]+]]("Hello, world\0A\00") + // CHECK-DAG: llvm.mlir.global internal constant @[[$PRINT_GLOBAL1:[A-Za-z0-9_]+]]("Hello: %d\0A\00") + // CHECK-DAG: llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32 + + // CHECK-LABEL: func @test_const_printf + gpu.func @test_const_printf() { + // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr> + // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr> + // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr> to !llvm.ptr + // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + gpu.printf "Hello, world\n" + gpu.return + } + + // CHECK-LABEL: func @test_printf + // CHECK: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: f32) + gpu.func @test_printf(%arg0: i32, %arg1: f32) { + // CHECK-NEXT: %[[FORMATSTR:.*]] = llvm.mlir.addressof 
@[[$PRINT_GLOBAL1]] : !llvm.ptr> + // CHECK-NEXT: %[[FORMATSTART:.*]] = llvm.getelementptr %[[FORMATSTR]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: %[[EXT:.+]] = llvm.fpext %[[ARG1]] : f32 to f64 + // CHECK-NEXT: %[[O:.*]] = llvm.mlir.constant(1 : index) : i64 + // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<(i32, f64)> : (i64) -> !llvm.ptr> + // CHECK-NEXT: %[[EL0:.*]] = llvm.getelementptr %[[ALLOC]][0, 0] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[ARG0]], %[[EL0]] : !llvm.ptr + // CHECK-NEXT: %[[EL1:.*]] = llvm.getelementptr %[[ALLOC]][0, 1] : (!llvm.ptr>) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[EXT]], %[[EL1]] : !llvm.ptr + // CHECK-NEXT: %[[ARGPTR:.*]] = llvm.bitcast %[[ALLOC]] : !llvm.ptr> to !llvm.ptr + // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ARGPTR]]) : (!llvm.ptr, !llvm.ptr) -> i32 + gpu.printf "Hello: %d\n" %arg0, %arg1 : i32, f32 + gpu.return + } +} + diff --git a/mlir/test/Integration/GPU/CUDA/printf.mlir b/mlir/test/Integration/GPU/CUDA/printf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/printf.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt %s \ +// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,gpu-to-cubin))' \ +// RUN: | mlir-opt -gpu-to-llvm \ +// RUN: | mlir-cpu-runner \ +// RUN: --shared-libs=%mlir_lib_dir/libmlir_cuda_runtime%shlibext \ +// RUN: --shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: --entry-point-result=void \ +// RUN: | FileCheck %s + +// CHECK: Hello from 0, 2, 3.000000 +// CHECK: Hello from 1, 2, 3.000000 +module attributes {gpu.container_module} { + gpu.module @kernels { + gpu.func @hello() kernel { + %0 = gpu.thread_id x + %csti8 = arith.constant 2 : i8 + %cstf32 = arith.constant 3.0 : f32 + gpu.printf "Hello from %lld, %d, %f\n" %0, %csti8, %cstf32 : index, i8, f32 + gpu.return + } + } + + func.func @main() { + %c2 = arith.constant 2 : index + %c1 = arith.constant 1 : index + 
gpu.launch_func @kernels::@hello + blocks in (%c1, %c1, %c1) + threads in (%c2, %c1, %c1) + return + } +}