// Patch summary: lets the vector print lowering handle f16 and bf16 elements.
//
// What the hunks below change, file by file:
//   - FunctionCallUtils.h / FunctionCallUtils.cpp: declare and define
//     lookupOrCreatePrintF16Fn and lookupOrCreatePrintBF16Fn. Each forwards to
//     lookupOrCreateFn with the name "printF16" / "printBF16", a 16-bit
//     IntegerType and LLVMVoidType.
//   - CRunnerUtils.h: declares extern "C" printF16(uint16_t bits) and
//     printBF16(uint16_t bits), both marked MLIR_CRUNNERUTILS_EXPORT.
//   - ConvertVectorToLLVM.cpp: adds the PrintConversion::Bitcast16 enumerator.
//     f16 and bf16 element types select it together with the matching printer,
//     and the new switch case replaces `value` with an op created with a
//     16-bit IntegerType.
//   - Float16bits.cpp: defines printF16 / printBF16, which memcpy the 16 raw
//     bits into an f16 / bf16 object and stream it to std::cout.
//   - test-print-fp.mlir (new file): prints 5-element f64, f32, f16 and bf16
//     vectors through mlir-opt | mlir-cpu-runner and checks the text with
//     FileCheck.
//
// The repeated "bits!" comments appear to mark every place where the
// half-precision value travels as a raw 16-bit pattern rather than as a
// floating-point argument.
//
// NOTE(review): the text below looks whitespace-collapsed (many patch lines
// joined onto one), and getParentOfType(), rewriter.create( and static_cast(d)
// seem to have lost their angle-bracket template arguments -- presumably an
// extraction artifact. The op created for Bitcast16 is presumably a bitcast,
// judging by the enumerator name. Confirm against the upstream change before
// trying to apply this patch.
// NOTE(review): the declarations carry MLIR_CRUNNERUTILS_EXPORT, while the
// definitions in Float16bits.cpp show no export macro and sit just before
// "#endif // MLIR_FLOAT16_DEFINE_FUNCTIONS" -- verify the symbols are exported
// and defined in every build configuration.
// Keep any notes above the first "diff --git" header: the hunk headers carry
// fixed line counts, so text inserted inside a hunk would invalidate the patch.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -34,6 +34,8 @@ /// of the libc). LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp); +LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp); LLVM::LLVMFuncOp lookupOrCreatePrintStrFn(ModuleOp moduleOp, diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -469,6 +469,8 @@ extern "C" MLIR_CRUNNERUTILS_EXPORT void printClose(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printComma(); extern "C" MLIR_CRUNNERUTILS_EXPORT void printNewline(); +extern "C" MLIR_CRUNNERUTILS_EXPORT void printF16(uint16_t bits); // bits! +extern "C" MLIR_CRUNNERUTILS_EXPORT void printBF16(uint16_t bits); // bits! //===----------------------------------------------------------------------===// // Small runtime support library for timing execution and printing GFLOPS diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1473,6 +1473,14 @@ } else if (eltType.isF64()) { printer = LLVM::lookupOrCreatePrintF64Fn(printOp->getParentOfType()); + } else if (eltType.isF16()) { + conversion = PrintConversion::Bitcast16; // bits! 
+ printer = + LLVM::lookupOrCreatePrintF16Fn(printOp->getParentOfType()); + } else if (eltType.isBF16()) { + conversion = PrintConversion::Bitcast16; // bits! + printer = + LLVM::lookupOrCreatePrintBF16Fn(printOp->getParentOfType()); } else if (eltType.isIndex()) { printer = LLVM::lookupOrCreatePrintU64Fn(printOp->getParentOfType()); @@ -1526,7 +1534,8 @@ // clang-format off None, ZeroExt64, - SignExt64 + SignExt64, + Bitcast16 // clang-format on }; @@ -1546,6 +1555,10 @@ value = rewriter.create( loc, IntegerType::get(rewriter.getContext(), 64), value); break; + case PrintConversion::Bitcast16: + value = rewriter.create( + loc, IntegerType::get(rewriter.getContext(), 16), value); + break; case PrintConversion::None: break; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -26,6 +26,8 @@ /// part of the libc). static constexpr llvm::StringRef kPrintI64 = "printI64"; static constexpr llvm::StringRef kPrintU64 = "printU64"; +static constexpr llvm::StringRef kPrintF16 = "printF16"; +static constexpr llvm::StringRef kPrintBF16 = "printBF16"; static constexpr llvm::StringRef kPrintF32 = "printF32"; static constexpr llvm::StringRef kPrintF64 = "printF64"; static constexpr llvm::StringRef kPrintStr = "puts"; @@ -67,6 +69,18 @@ LLVM::LLVMVoidType::get(moduleOp->getContext())); } +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + +LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) { + return lookupOrCreateFn(moduleOp, kPrintBF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! 
+ LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { return lookupOrCreateFn(moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), diff --git a/mlir/lib/ExecutionEngine/Float16bits.cpp b/mlir/lib/ExecutionEngine/Float16bits.cpp --- a/mlir/lib/ExecutionEngine/Float16bits.cpp +++ b/mlir/lib/ExecutionEngine/Float16bits.cpp @@ -192,4 +192,16 @@ return __truncsfbf2(static_cast(d)); } +// Provide these to the CRunner with the local float16 knowledge. +extern "C" void printF16(uint16_t bits) { + f16 f; + std::memcpy(&f, &bits, sizeof(f16)); + std::cout << f; +} +extern "C" void printBF16(uint16_t bits) { + bf16 f; + std::memcpy(&f, &bits, sizeof(bf16)); + std::cout << f; +} + #endif // MLIR_FLOAT16_DEFINE_FUNCTIONS diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-print-fp.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-scf-to-cf -convert-vector-to-llvm -convert-func-to-llvm -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_c_runner_utils | \ +// RUN: FileCheck %s + +// +// Test various floating-point types. 
+// +func.func @entry() { + %0 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf64> + vector.print %0 : vector<5xf64> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %1 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf32> + vector.print %1 : vector<5xf32> + // CHECK: ( -1000, -1.1, 0, 1.1, 1000 ) + + %2 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xf16> + vector.print %2 : vector<5xf16> + // CHECK: ( -1000, -1.09961, 0, 1.09961, 1000 ) + + %3 = arith.constant dense<[-1000.0, -1.1, 0.0, 1.1, 1000.0]> : vector<5xbf16> + vector.print %3 : vector<5xbf16> + // CHECK: ( -1000, -1.10156, 0, 1.10156, 1000 ) + + return +}