# Patch summary (leading commentary; patch tools skip text before the first
# `diff --git` header).
#
# What the hunks below do:
#   * LLVMCommon/TypeConverter.h: extends the doc comment of
#     `convertFloatType` to state that 8-bit float types are converted to
#     8-bit integers.
#   * AMDGPUToROCDL/AMDGPUToROCDL.cpp: removes the conversion callback that
#     `populateAMDGPUToROCDLConversionPatterns` registered to map
#     f8E5M2FNUZ / f8E4M3FNUZ to an 8-bit integer type.
#   * LLVMCommon/TypeConverter.cpp: `LLVMTypeConverter::convertFloatType` now
#     returns an integer type of the same width for f8E5M2, f8E4M3FN,
#     f8E5M2FNUZ, f8E4M3FNUZ and f8E4M3B11FNUZ; every other float type is
#     still returned unchanged.
#   * Target/LLVMIR/ModuleTranslation.cpp: when a float attribute's semantics
#     are 8 bits wide and the requested LLVM type is i8, the constant is
#     built as an integer from the float's bit pattern (`bitcastToAPInt`)
#     before the existing float-type check runs.
#   * test/Target/LLVMIR/llvmir.mlir: adds five globals initialised with 1.5
#     in each of the f8 formats above, with CHECK lines expecting the i8
#     values 60, 62, 68, 66 and 92.
#
# NOTE(review): this patch was received with its line breaks collapsed; the
# layout below is reconstructed so that each hunk matches its `@@` line
# counts. Indentation and blank context lines are inferred - confirm against
# the target tree.
# NOTE(review): a few lines look like they lost angle-bracketed template
# arguments in transit (`std::optional {`, `dyn_cast(attr)`,
# `patterns.add(converter)`, `RawBufferOpLowering,`, and `complex` in the
# header comment). They are reproduced as received - verify against the
# upstream commit before applying.
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -184,7 +184,8 @@
 
   /// Convert a floating point type: `f16` to `f16`, `f32` to
   /// `f32` and `f64` to `f64`. `bf16` is not supported
-  /// by LLVM.
+  /// by LLVM. 8-bit float types are converted to 8-bit integers as this is how
+  /// all LLVM backends that support them currently represent them.
   Type convertFloatType(FloatType type);
 
   /// Convert complex number type: `complex` to `!llvm<"{ half, half }">`,
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -539,14 +539,6 @@
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
-  // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
-  // type. Therefore, for this target, declare f8 to be equal to i8.
-  converter.addConversion([](FloatType type) -> std::optional {
-    if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
-      return IntegerType::get(type.getContext(), 8);
-    return std::nullopt;
-  });
-
   patterns.add(converter);
   patterns.add<
       RawBufferOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -193,7 +193,12 @@
   return IntegerType::get(&getContext(), type.getWidth());
 }
 
-Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; }
+Type LLVMTypeConverter::convertFloatType(FloatType type) {
+  if (type.isFloat8E5M2() || type.isFloat8E4M3FN() || type.isFloat8E5M2FNUZ() ||
+      type.isFloat8E4M3FNUZ() || type.isFloat8E4M3B11FNUZ())
+    return IntegerType::get(&getContext(), type.getWidth());
+  return type;
+}
 
 // Convert a `ComplexType` to an LLVM type. The result is a complex number
 // struct with entries for the
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -360,6 +360,12 @@
         llvmType,
         intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
   if (auto floatAttr = dyn_cast(attr)) {
+    const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
+    // Special case for 8-bit floats, which are represented by integers due to
+    // the lack of native fp8 types in LLVM at the moment.
+    if (APFloat::getSizeInBits(sem) == 8 && llvmType->isIntegerTy(8))
+      return llvm::ConstantInt::get(llvmType,
+                                    floatAttr.getValue().bitcastToAPInt());
     if (llvmType !=
         llvm::Type::getFloatingPointTy(llvmType->getContext(),
                                        floatAttr.getValue().getSemantics())) {
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -55,6 +55,21 @@
 // CHECK: @int_global_undef = internal global i64 undef
 llvm.mlir.global internal @int_global_undef() : i64
 
+// CHECK: @f8E4M3FN_global_as_i8 = internal global i8 60
+llvm.mlir.global internal @f8E4M3FN_global_as_i8(1.5 : f8E4M3FN) : i8
+
+// CHECK: @f8E5M2_global_as_i8 = internal global i8 62
+llvm.mlir.global internal @f8E5M2_global_as_i8(1.5 : f8E5M2) : i8
+
+// CHECK: @f8E4M3FNUZ_global_as_i8 = internal global i8 68
+llvm.mlir.global internal @f8E4M3FNUZ_global_as_i8(1.5 : f8E4M3FNUZ) : i8
+
+// CHECK: @f8E5M2FNUZ_global_as_i8 = internal global i8 66
+llvm.mlir.global internal @f8E5M2FNUZ_global_as_i8(1.5 : f8E5M2FNUZ) : i8
+
+// CHECK: @f8E4M3B11FNUZ_global_as_i8 = internal global i8 92
+llvm.mlir.global internal @f8E4M3B11FNUZ_global_as_i8(1.5 : f8E4M3B11FNUZ) : i8
+
 // CHECK: @explicit_undef = global i32 undef
 llvm.mlir.global external @explicit_undef() : i32 {
   %0 = llvm.mlir.undef : i32