diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -62,6 +62,7 @@ llvm::Type *getUnderlyingType() const; /// Utilities to identify types. + bool isBFloatTy() { return getUnderlyingType()->isBFloatTy(); } bool isHalfTy() { return getUnderlyingType()->isHalfTy(); } bool isFloatTy() { return getUnderlyingType()->isFloatTy(); } bool isDoubleTy() { return getUnderlyingType()->isDoubleTy(); } @@ -99,6 +100,7 @@ /// Utilities used to generate floating point types. static LLVMType getDoubleTy(LLVMDialect *dialect); static LLVMType getFloatTy(LLVMDialect *dialect); + static LLVMType getBFloatTy(LLVMDialect *dialect); static LLVMType getHalfTy(LLVMDialect *dialect); static LLVMType getFP128Ty(LLVMDialect *dialect); static LLVMType getX86_FP80Ty(LLVMDialect *dialect); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -201,9 +201,7 @@ case mlir::StandardTypes::F16: return LLVM::LLVMType::getHalfTy(llvmDialect); case mlir::StandardTypes::BF16: { - auto *mlirContext = llvmDialect->getContext(); - return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"), - Type(); + return LLVM::LLVMType::getBFloatTy(llvmDialect); } default: llvm_unreachable("non-float type in convertFloatType"); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -939,8 +939,9 @@ if (auto llvmType = type.dyn_cast()) { if (llvmType.isVectorTy()) llvmType = llvmType.getVectorElementType(); - if (llvmType.isIntegerTy() || llvmType.isHalfTy() || - llvmType.isFloatTy() || llvmType.isDoubleTy()) { + if (llvmType.isIntegerTy() || llvmType.isBFloatTy() || + llvmType.isHalfTy() || llvmType.isFloatTy() || + llvmType.isDoubleTy()) { return success(); } return op.emitOpError("type must be non-index integer types, float " @@ -1500,7 +1501,8 @@ } else if (op.bin_op() == AtomicBinOp::xchg) { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && !valType.isIntegerTy(64) && - !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) + !valType.isBFloatTy() && !valType.isHalfTy() && !valType.isFloatTy() && + !valType.isDoubleTy()) return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) && @@ -1561,8 +1563,8 @@ "match type for all other operands"); if (!valType.isPointerTy() && !valType.isIntegerTy(8) && !valType.isIntegerTy(16) && !valType.isIntegerTy(32) && - !valType.isIntegerTy(64) && !valType.isHalfTy() && !valType.isFloatTy() && - !valType.isDoubleTy()) + !valType.isIntegerTy(64) && !valType.isBFloatTy() && + !valType.isHalfTy() && !valType.isFloatTy() && !valType.isDoubleTy()) return op.emitOpError("unexpected LLVM IR type"); if (op.success_ordering() < AtomicOrdering::monotonic || op.failure_ordering() < AtomicOrdering::monotonic) @@ -1630,7 +1632,7 @@ /// A set of LLVMTypes that are cached on construction to avoid any lookups or /// locking. LLVMType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; - LLVMType doubleTy, floatTy, halfTy, fp128Ty, x86_fp80Ty; + LLVMType doubleTy, floatTy, bfloatTy, halfTy, fp128Ty, x86_fp80Ty; LLVMType voidTy; /// A smart mutex to lock access to the llvm context. Unlike MLIR, LLVM is not @@ -1665,6 +1667,7 @@ /// Float Types. impl->doubleTy = LLVMType::get(context, llvm::Type::getDoubleTy(llvmContext)); impl->floatTy = LLVMType::get(context, llvm::Type::getFloatTy(llvmContext)); + impl->bfloatTy = LLVMType::get(context, llvm::Type::getBFloatTy(llvmContext)); impl->halfTy = LLVMType::get(context, llvm::Type::getHalfTy(llvmContext)); impl->fp128Ty = LLVMType::get(context, llvm::Type::getFP128Ty(llvmContext)); impl->x86_fp80Ty = @@ -1827,6 +1830,9 @@ LLVMType LLVMType::getFloatTy(LLVMDialect *dialect) { return dialect->impl->floatTy; } +LLVMType LLVMType::getBFloatTy(LLVMDialect *dialect) { + return dialect->impl->bfloatTy; +} LLVMType LLVMType::getHalfTy(LLVMDialect *dialect) { return dialect->impl->halfTy; } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -1228,3 +1228,12 @@ // CHECK-NEXT: llvm.return %[[ARG]] return %1 : f16 } + +// ----- + +// CHECK-LABEL: func @bfloat +// CHECK-SAME: !llvm.bfloat) -> !llvm.bfloat +func @bfloat(%arg0: bf16) -> bf16 { + return %arg0 : bf16 +} +// CHECK-NEXT: return %{{.*}} : !llvm.bfloat