diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -163,6 +163,16 @@ /// context. MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx); +/// Returns the typeID of a TF32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void); + +/// Checks whether the given type is a TF32 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type); + +/// Creates a TF32 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx); + //===----------------------------------------------------------------------===// // None type. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -67,6 +67,7 @@ FloatType getFloat8E4M3B11FNUZType(); FloatType getBF16Type(); FloatType getF16Type(); + FloatType getTF32Type(); FloatType getF32Type(); FloatType getF64Type(); FloatType getF80Type(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -43,6 +43,7 @@ // Convenience factories. 
static FloatType getBF16(MLIRContext *ctx); static FloatType getF16(MLIRContext *ctx); + static FloatType getTF32(MLIRContext *ctx); static FloatType getF32(MLIRContext *ctx); static FloatType getF64(MLIRContext *ctx); static FloatType getF80(MLIRContext *ctx); @@ -404,8 +405,8 @@ inline bool FloatType::classof(Type type) { return llvm::isa(type); + Float16Type, FloatTF32Type, Float32Type, Float64Type, + Float80Type, Float128Type>(type); } inline FloatType FloatType::getFloat8E5M2(MLIRContext *ctx) { @@ -436,6 +437,10 @@ return Float16Type::get(ctx); } +inline FloatType FloatType::getTF32(MLIRContext *ctx) { + return FloatTF32Type::get(ctx); +} + inline FloatType FloatType::getF32(MLIRContext *ctx) { return Float32Type::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -198,6 +198,13 @@ let summary = "16-bit floating-point type"; } +//===----------------------------------------------------------------------===// +// FloatTF32Type + +def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32"> { + let summary = "TF32 floating-point type"; +} + //===----------------------------------------------------------------------===// // Float32Type diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -570,6 +570,8 @@ def BF16 : Type, "bfloat16 type">, BuildableType<"$_builder.getBF16Type()">; +def TF32 : Type, "tf32 type">, + BuildableType<"$_builder.getTF32Type()">; def F8E4M3FN : Type, "f8E4M3FN type">, BuildableType<"$_builder.getFloat8E4M3FNType()">; def F8E5M2 : Type, "f8E5M2 type">, diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -127,6 +127,7 @@ bool isFloat8E4M3B11FNUZ() const; bool isBF16() const; bool isF16() const; + bool isTF32()
const; bool isF32() const; bool isF64() const; bool isF80() const; diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -117,6 +117,7 @@ TOK_KEYWORD(strided) TOK_KEYWORD(symbol) TOK_KEYWORD(tensor) +TOK_KEYWORD(tf32) TOK_KEYWORD(to) TOK_KEYWORD(true) TOK_KEYWORD(tuple) diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -38,6 +38,7 @@ case Token::kw_f8E4M3B11FNUZ: case Token::kw_bf16: case Token::kw_f16: + case Token::kw_tf32: case Token::kw_f32: case Token::kw_f64: case Token::kw_f80: @@ -313,6 +314,9 @@ case Token::kw_f16: consumeToken(Token::kw_f16); return builder.getF16Type(); + case Token::kw_tf32: + consumeToken(Token::kw_tf32); + return builder.getTF32Type(); case Token::kw_f32: consumeToken(Token::kw_f32); return builder.getF32Type(); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -247,6 +247,26 @@ } }; +/// Floating Point Type subclass - TF32Type. +class PyTF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatTF32TypeGetTypeID; + static constexpr const char *pyClassName = "FloatTF32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirTF32TypeGet(context->get()); + return PyTF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a tf32 type."); + } +}; + /// Floating Point Type subclass - F32Type. 
class PyF32Type : public PyConcreteType { public: @@ -754,6 +774,7 @@ PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); + PyTF32Type::bind(m); PyF32Type::bind(m); PyF64Type::bind(m); PyNoneType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -152,6 +152,16 @@ return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloatTF32TypeGetTypeID() { + return wrap(FloatTF32Type::getTypeID()); +} + +bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); } + +MlirType mlirTF32TypeGet(MlirContext ctx) { + return wrap(FloatType::getTF32(unwrap(ctx))); +} + MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2433,6 +2433,7 @@ .Case([&](Type) { os << "f8E4M3B11FNUZ"; }) .Case([&](Type) { os << "bf16"; }) .Case([&](Type) { os << "f16"; }) + .Case([&](Type) { os << "tf32"; }) .Case([&](Type) { os << "f32"; }) .Case([&](Type) { os << "f64"; }) .Case([&](Type) { os << "f80"; }) diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -58,6 +58,8 @@ FloatType Builder::getF16Type() { return FloatType::getF16(context); } +FloatType Builder::getTF32Type() { return FloatType::getTF32(context); } + FloatType Builder::getF32Type() { return FloatType::getF32(context); } FloatType Builder::getF64Type() { return FloatType::getF64(context); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -93,7 +93,10 @@ return 8; if (llvm::isa(*this)) return 16; + // TF32 is 19 bits wide (1 sign, 8 exponent, 10 mantissa bits). + if (llvm::isa<FloatTF32Type>(*this)) + return 19; if (llvm::isa(*this)) return 32; if (llvm::isa(*this)) 
return 64; @@ -120,6 +123,8 @@ return APFloat::BFloat(); if (llvm::isa(*this)) return APFloat::IEEEhalf(); + if (llvm::isa(*this)) + return APFloat::FloatTF32(); if (llvm::isa(*this)) return APFloat::IEEEsingle(); if (llvm::isa(*this)) diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -219,6 +219,7 @@ Float8E4M3B11FNUZType f8E4M3B11FNUZTy; BFloat16Type bf16Ty; Float16Type f16Ty; + FloatTF32Type tf32Ty; Float32Type f32Ty; Float64Type f64Ty; Float80Type f80Ty; @@ -294,6 +295,7 @@ impl->f8E4M3B11FNUZTy = TypeUniquer::get(this); impl->bf16Ty = TypeUniquer::get(this); impl->f16Ty = TypeUniquer::get(this); + impl->tf32Ty = TypeUniquer::get(this); impl->f32Ty = TypeUniquer::get(this); impl->f64Ty = TypeUniquer::get(this); impl->f80Ty = TypeUniquer::get(this); @@ -960,6 +962,9 @@ Float16Type Float16Type::get(MLIRContext *context) { return context->getImpl().f16Ty; } +FloatTF32Type FloatTF32Type::get(MLIRContext *context) { + return context->getImpl().tf32Ty; +} Float32Type Float32Type::get(MLIRContext *context) { return context->getImpl().f32Ty; } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -47,6 +47,7 @@ } bool Type::isBF16() const { return llvm::isa(*this); } bool Type::isF16() const { return llvm::isa(*this); } +bool Type::isTF32() const { return llvm::isa(*this); } bool Type::isF32() const { return llvm::isa(*this); } bool Type::isF64() const { return llvm::isa(*this); } bool Type::isF80() const { return llvm::isa(*this); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -56,6 +56,7 @@ "Float8E4M3B11FNUZType", "Float8E5M2FNUZType", "F16Type", + "FloatTF32Type", "F32Type", "F64Type", "FlatSymbolRefAttr", @@ -627,6 +628,14 @@ @staticmethod def isinstance(arg: Any) 
-> bool: ... +# TODO: Auto-generated. Audit and fix. +class FloatTF32Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> FloatTF32Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F32Type(Type): def __init__(self, cast_from_type: Type) -> None: ... diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -64,6 +64,10 @@ // CHECK: float_attr = 2.000000e+00 : bf16 float_attr = 2. : bf16 } : () -> () + "test.float_attrs"() { + // CHECK: float_attr = 2.000000e+00 : tf32 + float_attr = 2. : tf32 + } : () -> () "test.float_attrs"() { // CHECK: float_attr = 2.000000e+00 : f32 float_attr = 2. : f32 diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -212,6 +212,8 @@ print("float:", BF16Type.get()) # CHECK: float: f16 print("float:", F16Type.get()) + # CHECK: float: tf32 + print("float:", FloatTF32Type.get()) # CHECK: float: f32 print("float:", F32Type.get()) # CHECK: float: f64 diff --git a/mlir/utils/gdb-scripts/prettyprinters.py b/mlir/utils/gdb-scripts/prettyprinters.py --- a/mlir/utils/gdb-scripts/prettyprinters.py +++ b/mlir/utils/gdb-scripts/prettyprinters.py @@ -166,6 +166,7 @@ "IndexType", "IntegerType", "Float16Type", + "FloatTF32Type", "Float32Type", "Float64Type", "Float80Type", diff --git a/mlir/utils/lldb-scripts/mlirDataFormatters.py b/mlir/utils/lldb-scripts/mlirDataFormatters.py --- a/mlir/utils/lldb-scripts/mlirDataFormatters.py +++ b/mlir/utils/lldb-scripts/mlirDataFormatters.py @@ -57,6 +57,7 @@ "mlir::Float8E4M3B11FNUZType": '"f8E4M3B11FNUZ"', "mlir::BFloat16Type": '"bf16"', "mlir::Float16Type": '"f16"', + "mlir::FloatTF32Type": '"tf32"', "mlir::Float32Type": '"f32"', "mlir::Float64Type": '"f64"', "mlir::Float80Type": 
'"f80"',