diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -526,6 +526,29 @@
   let hasCanonicalizer = 1;
 }
 
+def GPU_PrintfOp : GPU_Op<"printf", [MemoryEffects<[MemWrite]>]>,
+  Arguments<(ins StrAttr:$format,
+                Variadic<AnyType>:$args)> {
+  let summary = "Device-side printf, as in CUDA or OpenCL, for debugging";
+  let description = [{
+    `gpu.printf` takes a literal format string `format` and an arbitrary number
+    of scalar (integer, index, or floating-point) arguments to be printed.
+
+    The format string is a C-style printf string, subject to any restrictions
+    imposed by one's target platform.
+
+    Example:
+
+    ```mlir
+    gpu.printf {format = "Value: %d"} %arg0 : i32
+    ```
+  }];
+  let assemblyFormat = [{
+    attr-dict ($args^ `:` type($args))?
+  }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
 def GPU_ReturnOp : GPU_Op<"return", [HasParent<"GPUFuncOp">, NoSideEffect,
                                      Terminator]>,
     Arguments<(ins Variadic<AnyType>:$operands)>, Results<(outs)> {
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -44,6 +44,18 @@
   }
 };
 
+/// Lowers `gpu.printf` to a call to an external `printf` function. The format
+/// string is emitted as a null-terminated internal constant global in the
+/// enclosing `gpu.module`, where `printf` is also declared if it is missing.
+struct GPUPrintfOpToLLVMCallLowering
+    : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
+  using ConvertOpToLLVMPattern<gpu::PrintfOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_GPUCOMMON_GPUOPSLOWERING_H_
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -144,3 +144,74 @@
   rewriter.eraseOp(gpuFuncOp);
   return success();
 }
+
+LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = gpuPrintfOp->getLoc();
+  gpu::PrintfOpAdaptor op(operands, gpuPrintfOp->getAttrDictionary());
+
+  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
+  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
+  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
+  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
+  if (!moduleOp)
+    return rewriter.notifyMatchFailure(gpuPrintfOp,
+                                       "expected to be inside a gpu.module");
+
+  // Reuse an existing `printf` declaration or declare `i32 printf(i8*, ...)`
+  // at the start of the module. If the name is already taken by something
+  // other than an LLVM function, declaring another `printf` would clash with
+  // it, so bail out instead.
+  auto printfFunc = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>("printf");
+  if (!printfFunc) {
+    if (moduleOp.lookupSymbol("printf"))
+      return rewriter.notifyMatchFailure(
+          gpuPrintfOp, "symbol 'printf' exists but is not an LLVM function");
+    auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(),
+                                                  {i8Ptr}, /*isVarArg=*/true);
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    printfFunc = rewriter.create<LLVM::LLVMFuncOp>(loc, "printf", printfType,
+                                                   LLVM::Linkage::External);
+  }
+
+  // Pick a `printfFormat_N` symbol name that is not yet used in the module.
+  unsigned stringNumber = 0;
+  std::string stringConstName;
+  do {
+    stringConstName = llvm::formatv("printfFormat_{0}", stringNumber++);
+  } while (moduleOp.lookupSymbol(stringConstName));
+
+  llvm::SmallString<20> formatString(op.format().getValue());
+  formatString.push_back('\0'); // Null terminate for C
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+  LLVM::GlobalOp global;
+  {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+    global = rewriter.create<LLVM::GlobalOp>(
+        loc, globalType,
+        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
+        rewriter.getStringAttr(formatString));
+  }
+
+  // Address the first character of the format string: gep %global[0, 0].
+  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
+  Value zero = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
+  Value stringStart = rewriter.create<LLVM::GEPOp>(
+      loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
+
+  // Call printf(format, args...) with the already-converted arguments.
+  auto &&argsRange = op.args();
+  SmallVector<Value> printfArgs;
+  printfArgs.reserve(argsRange.size() + 1);
+  printfArgs.push_back(stringStart);
+  printfArgs.append(argsRange.begin(), argsRange.end());
+
+  rewriter.create<LLVM::CallOp>(loc, printfFunc, printfArgs);
+  rewriter.eraseOp(gpuPrintfOp);
+  return success();
+}
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -111,7 +111,7 @@
                                   ROCDL::BlockIdYOp, ROCDL::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, ROCDL::GridDimXOp,
                                   ROCDL::GridDimYOp, ROCDL::GridDimZOp>,
-      GPUReturnOpLowering>(converter);
+      GPUReturnOpLowering, GPUPrintfOpToLLVMCallLowering>(converter);
   patterns.add<GPUFuncOpLowering>(
       converter, /*allocaAddrSpace=*/5,
       Identifier::get(ROCDL::ROCDLDialect::getKernelFuncAttrName(),
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -901,6 +901,20 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// PrintfOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that every printf argument is an integer, index, or float scalar.
+static LogicalResult verify(gpu::PrintfOp op) {
+  for (Type ty : op.getOperandTypes()) {
+    if (!ty.isIntOrIndexOrFloat())
+      return op.emitOpError("arguments to printf() must be scalars, but got ")
+             << ty;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ReturnOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
--- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir
@@ -386,3 +386,34 @@
     gpu.return
   }
 }
+
+// -----
+
+gpu.module @test_module {
+  // CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello, World\0A\00")
+  // CHECK: llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
+  // CHECK-LABEL: func @test_printf
+  gpu.func @test_printf() {
+    // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<array<14 x i8>>
+    // CHECK-NEXT: %[[IMM1:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][%[[IMM1]], %[[IMM1]]] : (!llvm.ptr<array<14 x i8>>, i64, i64) -> !llvm.ptr<i8>
+    // CHECK-NEXT: %{{.*}} = llvm.call @printf(%[[IMM2]]) : (!llvm.ptr<i8>) -> i32
+    gpu.printf { format = "Hello, World\n" }
+    gpu.return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+  // CHECK: llvm.mlir.global internal constant @{{[A-Za-z0-9_]+}}("Value: %d\0A\00")
+  // CHECK: llvm.func @printf(!llvm.ptr<i8>, ...) -> i32
+  // CHECK-LABEL: func @test_printf_args
+  // CHECK-SAME: (%[[ARG0:.*]]: i32)
+  gpu.func @test_printf_args(%arg0 : i32) {
+    // CHECK: %[[FMT:.*]] = llvm.getelementptr
+    // CHECK: llvm.call @printf(%[[FMT]], %[[ARG0]]) : (!llvm.ptr<i8>, i32) -> i32
+    gpu.printf { format = "Value: %d\n" } %arg0 : i32
+    gpu.return
+  }
+}
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -109,6 +109,14 @@
     gpu.return
   }
 
+  // CHECK-LABEL: gpu.func @printf_test
+  // CHECK-SAME: (%[[ARG0:.*]]: i32)
+  // CHECK: gpu.printf {format = "Value: %d"} %[[ARG0]] : i32
+  gpu.func @printf_test(%arg0 : i32) {
+    gpu.printf {format = "Value: %d"} %arg0 : i32
+    gpu.return
+  }
+
   // CHECK-LABEL: @no_attribution_attrs
   // CHECK: attributes
   // CHECK: {