diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -34,6 +34,8 @@ /// of the libc). LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp, diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // raw bits +extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // raw bits //===----------------------------------------------------------------------===// // Small runtime support library for timing execution and printing GFLOPS diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1466,16 +1466,20 @@ PrintConversion conversion = PrintConversion::None; VectorType vectorType = printType.dyn_cast(); Type eltType = vectorType ? 
vectorType.getElementType() : printType; + auto parent = printOp->getParentOfType(); Operation *printer; if (eltType.isF32()) { - printer = - LLVM::lookupOrCreatePrintF32Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintF32Fn(parent); } else if (eltType.isF64()) { - printer = - LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintF64Fn(parent); + } else if (eltType.isF16()) { + conversion = PrintConversion::Bitcast16; // pass as raw i16 bits + printer = LLVM::lookupOrCreatePrintF16Fn(parent); + } else if (eltType.isBF16()) { + conversion = PrintConversion::Bitcast16; // pass as raw i16 bits + printer = LLVM::lookupOrCreatePrintBF16Fn(parent); } else if (eltType.isIndex()) { - printer = - LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintU64Fn(parent); } else if (auto intTy = eltType.dyn_cast()) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or @@ -1485,8 +1489,7 @@ if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; - printer = LLVM::lookupOrCreatePrintU64Fn( - printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintU64Fn(parent); } else { return failure(); } @@ -1499,8 +1502,7 @@ conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; - printer = LLVM::lookupOrCreatePrintI64Fn( - printOp->getParentOfType()); + printer = LLVM::lookupOrCreatePrintI64Fn(parent); } else { return failure(); } @@ -1515,8 +1517,7 @@ emitRanks(rewriter, printOp, adaptor.getSource(), type, printer, rank, conversion); emitCall(rewriter, printOp->getLoc(), - LLVM::lookupOrCreatePrintNewlineFn( - printOp->getParentOfType())); + LLVM::lookupOrCreatePrintNewlineFn(parent)); rewriter.eraseOp(printOp); return success(); } @@ -1526,7 +1527,8 @@ // clang-format off None, ZeroExt64, - SignExt64 + SignExt64, + Bitcast16 // clang-format on }; @@ -1546,6 +1548,10 @@ value = 
rewriter.create( loc, IntegerType::get(rewriter.getContext(), 64), value); break; + case PrintConversion::Bitcast16: + value = rewriter.create( + loc, IntegerType::get(rewriter.getContext(), 16), value); + break; case PrintConversion::None: break; } @@ -1553,10 +1559,9 @@ return; } - emitCall(rewriter, loc, - LLVM::lookupOrCreatePrintOpenFn(op->getParentOfType())); - Operation *printComma = - LLVM::lookupOrCreatePrintCommaFn(op->getParentOfType()); + auto parent = op->getParentOfType(); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintOpenFn(parent)); + Operation *printComma = LLVM::lookupOrCreatePrintCommaFn(parent); if (rank <= 1) { auto reducedType = vectorType.getElementType(); @@ -1570,9 +1575,7 @@ if (d != dim - 1) emitCall(rewriter, loc, printComma); } - emitCall( - rewriter, loc, - LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType())); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); return; } @@ -1587,8 +1590,7 @@ if (d != dim - 1) emitCall(rewriter, loc, printComma); } - emitCall(rewriter, loc, - LLVM::lookupOrCreatePrintCloseFn(op->getParentOfType())); + emitCall(rewriter, loc, LLVM::lookupOrCreatePrintCloseFn(parent)); } // Helper to emit a call. diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -26,6 +26,8 @@ /// part of the libc). 
static constexpr llvm::StringRef kPrintI64 = "printI64"; static constexpr llvm::StringRef kPrintU64 = "printU64"; +static constexpr llvm::StringRef kPrintF16 = "printF16"; +static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintStr = "puts"; @@ -67,6 +69,18 @@ LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintBF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), diff --git a/mlir/lib/ExecutionEngine/Float16bits.cpp b/mlir/lib/ExecutionEngine/Float16bits.cpp --- a/mlir/lib/ExecutionEngine/Float16bits.cpp +++ b/mlir/lib/ExecutionEngine/Float16bits.cpp @@ -192,4 +192,16 @@ return __truncsfbf2(static_cast(d)); } +// Print hooks for the CRunner: reinterpret the raw 16 bits as f16/bf16 locally. 
+extern "C" void printF16(uint16_t bits) { + f16 f; + std::memcpy(&f, &bits, sizeof(f16)); + std::cout << f; +} +extern "C" void printBF16(uint16_t bits) { + bf16 f; + std::memcpy(&f, &bits, sizeof(bf16)); + std::cout << f; +} + #endif // MLIR_FLOAT16_DEFINE_FUNCTIONS diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +// +// Test printing vectors of various floating-point types (f64, f32, f16, bf16). +// +func.func @entry() { + %0 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf64> + vector.print %0 : vector<5xf64> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %1 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf32> + vector.print %1 : vector<5xf32> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %2 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf16> + vector.print %2 : vector<5xf16> + // CHECK: ( -1000, -1.09961, 0, 1.09961, 1000 ) + + %3 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xbf16> + vector.print %3 : vector<5xbf16> + // CHECK: ( -1000, -1.10156, 0, 1.10156, 1000 ) + + return +}