diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -180,25 +180,18 @@ // FloatType //===----------------------------------------------------------------------===// -class FloatType : public Type::TypeBase { +class FloatType : public Type { public: - using Base::Base; - - static FloatType get(StandardTypes::Kind kind, MLIRContext *context); + using Type::Type; // Convenience factories. - static FloatType getBF16(MLIRContext *ctx) { - return get(StandardTypes::BF16, ctx); - } - static FloatType getF16(MLIRContext *ctx) { - return get(StandardTypes::F16, ctx); - } - static FloatType getF32(MLIRContext *ctx) { - return get(StandardTypes::F32, ctx); - } - static FloatType getF64(MLIRContext *ctx) { - return get(StandardTypes::F64, ctx); - } + static FloatType getBF16(MLIRContext *ctx); + static FloatType getF16(MLIRContext *ctx); + static FloatType getF32(MLIRContext *ctx); + static FloatType getF64(MLIRContext *ctx); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(Type type); /// Return the bitwidth of this float type. unsigned getWidth(); @@ -207,6 +200,67 @@ const llvm::fltSemantics &getFloatSemantics(); }; +//===----------------------------------------------------------------------===// +// BFloat16Type + +class BFloat16Type + : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the bfloat16 type. + static BFloat16Type get(MLIRContext *context); +}; + +inline FloatType FloatType::getBF16(MLIRContext *ctx) { + return BFloat16Type::get(ctx); +} + +//===----------------------------------------------------------------------===// +// Float16Type + +class Float16Type : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the float16 type. + static Float16Type get(MLIRContext *context); +}; + +inline FloatType FloatType::getF16(MLIRContext *ctx) { + return Float16Type::get(ctx); +} + +//===----------------------------------------------------------------------===// +// Float32Type + +class Float32Type : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the float32 type. + static Float32Type get(MLIRContext *context); +}; + +inline FloatType FloatType::getF32(MLIRContext *ctx) { + return Float32Type::get(ctx); +} + +//===----------------------------------------------------------------------===// +// Float64Type + +class Float64Type : public Type::TypeBase { +public: + using Base::Base; + + /// Return an instance of the float64 type. + static Float64Type get(MLIRContext *context); +}; + +inline FloatType FloatType::getF64(MLIRContext *ctx) { + return Float64Type::get(ctx); +} + //===----------------------------------------------------------------------===// // NoneType //===----------------------------------------------------------------------===// @@ -618,6 +672,10 @@ return type.isa(); } +inline bool FloatType::classof(Type type) { + return type.isa(); +} + inline bool ShapedType::classof(Type type) { return type.isa(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -215,19 +215,15 @@ } Type LLVMTypeConverter::convertFloatType(FloatType type) { - switch (type.getKind()) { - case mlir::StandardTypes::F32: + if (type.isa()) return LLVM::LLVMType::getFloatTy(&getContext()); - case mlir::StandardTypes::F64: + if (type.isa()) return LLVM::LLVMType::getDoubleTy(&getContext()); - case mlir::StandardTypes::F16: + if (type.isa()) return LLVM::LLVMType::getHalfTy(&getContext()); - case mlir::StandardTypes::BF16: { + if (type.isa()) return LLVM::LLVMType::getBFloatTy(&getContext()); - } - default: - llvm_unreachable("non-float type in convertFloatType"); - } + llvm_unreachable("non-float type in convertFloatType"); } // Convert a `ComplexType` to an LLVM type. The result is a complex number diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::LLVM; @@ -23,46 +24,28 @@ /// Returns the keyword to use for the given type. static StringRef getTypeKeyword(LLVMType type) { - switch (type.getKind()) { - case LLVMType::VoidType: - return "void"; - case LLVMType::HalfType: - return "half"; - case LLVMType::BFloatType: - return "bfloat"; - case LLVMType::FloatType: - return "float"; - case LLVMType::DoubleType: - return "double"; - case LLVMType::FP128Type: - return "fp128"; - case LLVMType::X86FP80Type: - return "x86_fp80"; - case LLVMType::PPCFP128Type: - return "ppc_fp128"; - case LLVMType::X86MMXType: - return "x86_mmx"; - case LLVMType::TokenType: - return "token"; - case LLVMType::LabelType: - return "label"; - case LLVMType::MetadataType: - return "metadata"; - case LLVMType::FunctionType: - return "func"; - case LLVMType::IntegerType: - return "i"; - case LLVMType::PointerType: - return "ptr"; - case LLVMType::FixedVectorType: - case LLVMType::ScalableVectorType: - return "vec"; - case LLVMType::ArrayType: - return "array"; - case LLVMType::StructType: - return "struct"; - } - llvm_unreachable("unhandled type kind"); + return TypeSwitch(type) + .Case([&](Type) { return "void"; }) + .Case([&](Type) { return "half"; }) + .Case([&](Type) { return "bfloat"; }) + .Case([&](Type) { return "float"; }) + .Case([&](Type) { return "double"; }) + .Case([&](Type) { return "fp128"; }) + .Case([&](Type) { return "x86_fp80"; }) + .Case([&](Type) { return "ppc_fp128"; }) + .Case([&](Type) { return "x86_mmx"; }) + .Case([&](Type) { return "token"; }) + .Case([&](Type) { return "label"; }) + .Case([&](Type) { return "metadata"; }) + .Case([&](Type) { return "func"; }) + .Case([&](Type) { return "i"; }) + .Case([&](Type) { return "ptr"; }) + .Case([&](Type) { return "vec"; }) + .Case([&](Type) { return "array"; }) + .Case([&](Type) { return "struct"; }) + .Default([](Type) -> StringRef { + llvm_unreachable("unexpected 'llvm' type kind"); + }); } /// Prints the body of a structure type. Uses `stack` to avoid printing @@ -153,14 +136,8 @@ return; } - unsigned kind = type.getKind(); os << getTypeKeyword(type); - // Trivial types only consist of their keyword. - if (LLVMType::FIRST_TRIVIAL_TYPE <= kind && - kind <= LLVMType::LAST_TRIVIAL_TYPE) - return; - if (auto intType = type.dyn_cast()) { os << intType.getBitWidth(); return; @@ -190,7 +167,8 @@ if (auto structType = type.dyn_cast()) return printStructType(os, structType, stack); - printFunctionType(os, type.cast(), stack); + if (auto funcType = type.dyn_cast()) + return printFunctionType(os, funcType, stack); } void mlir::LLVM::detail::printType(LLVMType type, DialectAsmPrinter &printer) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/StandardTypes.h" @@ -188,108 +189,70 @@ } unsigned CompositeType::getNumElements() const { - switch (getKind()) { - case spirv::TypeKind::Array: - return cast().getNumElements(); - case spirv::TypeKind::CooperativeMatrix: + if (auto arrayType = dyn_cast()) + return arrayType.getNumElements(); + if (auto matrixType = dyn_cast()) + return matrixType.getNumColumns(); + if (auto structType = dyn_cast()) + return structType.getNumElements(); + if (auto vectorType = dyn_cast()) + return vectorType.getNumElements(); + if (isa()) { llvm_unreachable( "invalid to query number of elements of spirv::CooperativeMatrix type"); - case spirv::TypeKind::Matrix: - return cast().getNumColumns(); - case spirv::TypeKind::RuntimeArray: + } + if (isa()) { llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); - case spirv::TypeKind::Struct: - return cast().getNumElements(); - case StandardTypes::Vector: - return cast().getNumElements(); - default: - llvm_unreachable("invalid composite type"); } + llvm_unreachable("invalid composite type"); } bool CompositeType::hasCompileTimeKnownNumElements() const { - switch (getKind()) { - case TypeKind::CooperativeMatrix: - case TypeKind::RuntimeArray: - return false; - default: - return true; - } + return !isa(); } void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { - switch (getKind()) { - case spirv::TypeKind::Array: - cast().getExtensions(extensions, storage); - break; - case spirv::TypeKind::CooperativeMatrix: - cast().getExtensions(extensions, storage); - break; - case spirv::TypeKind::Matrix: - cast().getExtensions(extensions, storage); - break; - case spirv::TypeKind::RuntimeArray: - cast().getExtensions(extensions, storage); - break; - case spirv::TypeKind::Struct: - cast().getExtensions(extensions, storage); - break; - case StandardTypes::Vector: - cast().getElementType().cast().getExtensions( - extensions, storage); - break; - default: - llvm_unreachable("invalid composite type"); - } + TypeSwitch(*this) + .Case( + [&](auto type) { type.getExtensions(extensions, storage); }) + .Case([&](VectorType type) { + return type.getElementType().cast().getExtensions( + extensions, storage); + }) + .Default([](Type) { llvm_unreachable("invalid composite type"); }); } void CompositeType::getCapabilities( SPIRVType::CapabilityArrayRefVector &capabilities, Optional storage) { - switch (getKind()) { - case spirv::TypeKind::Array: - cast().getCapabilities(capabilities, storage); - break; - case spirv::TypeKind::CooperativeMatrix: - cast().getCapabilities(capabilities, storage); - break; - case spirv::TypeKind::Matrix: - cast().getCapabilities(capabilities, storage); - break; - case spirv::TypeKind::RuntimeArray: - cast().getCapabilities(capabilities, storage); - break; - case spirv::TypeKind::Struct: - cast().getCapabilities(capabilities, storage); - break; - case StandardTypes::Vector: - cast().getElementType().cast().getCapabilities( - capabilities, storage); - break; - default: - llvm_unreachable("invalid composite type"); - } + TypeSwitch(*this) + .Case( + [&](auto type) { type.getCapabilities(capabilities, storage); }) + .Case([&](VectorType type) { + return type.getElementType().cast().getCapabilities( + capabilities, storage); + }) + .Default([](Type) { llvm_unreachable("invalid composite type"); }); } Optional CompositeType::getSizeInBytes() { - switch (getKind()) { - case spirv::TypeKind::Array: - return cast().getSizeInBytes(); - case spirv::TypeKind::Struct: - return cast().getSizeInBytes(); - case StandardTypes::Vector: { - auto elementSize = - cast().getElementType().cast().getSizeInBytes(); + if (auto arrayType = dyn_cast()) + return arrayType.getSizeInBytes(); + if (auto structType = dyn_cast()) + return structType.getSizeInBytes(); + if (auto vectorType = dyn_cast()) { + Optional elementSize = + vectorType.getElementType().cast().getSizeInBytes(); if (!elementSize) return llvm::None; - return *elementSize * cast().getNumElements(); - } - default: - return llvm::None; + return *elementSize * vectorType.getNumElements(); } + return llvm::None; } //===----------------------------------------------------------------------===// @@ -741,8 +704,7 @@ bool SPIRVType::classof(Type type) { // Allow SPIR-V dialect types - if (type.getKind() >= Type::FIRST_SPIRV_TYPE && - type.getKind() <= TypeKind::LAST_SPIRV_TYPE) + if (llvm::isa(type.getDialect())) return true; if (type.isa()) return true; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -114,28 +115,14 @@ /// Print a type registered to this dialect. void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - case ShapeTypes::Component: - os << "component"; - return; - case ShapeTypes::Element: - os << "element"; - return; - case ShapeTypes::Size: - os << "size"; - return; - case ShapeTypes::Shape: - os << "shape"; - return; - case ShapeTypes::ValueShape: - os << "value_shape"; - return; - case ShapeTypes::Witness: - os << "witness"; - return; - default: - llvm_unreachable("unexpected 'shape' type kind"); - } + TypeSwitch(type) + .Case([&](Type) { os << "component"; }) + .Case([&](Type) { os << "element"; }) + .Case([&](Type) { os << "shape"; }) + .Case([&](Type) { os << "size"; }) + .Case([&](Type) { os << "value_shape"; }) + .Case([&](Type) { os << "witness"; }) + .Default([](Type) { llvm_unreachable("unexpected 'shape' type kind"); }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1576,128 +1576,95 @@ } } - switch (type.getKind()) { - default: - return printDialectType(type); - - case Type::Kind::Opaque: { - auto opaqueTy = type.cast(); - printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), - opaqueTy.getTypeData()); - return; - } - case StandardTypes::Index: - os << "index"; - return; - case StandardTypes::BF16: - os << "bf16"; - return; - case StandardTypes::F16: - os << "f16"; - return; - case StandardTypes::F32: - os << "f32"; - return; - case StandardTypes::F64: - os << "f64"; - return; - - case StandardTypes::Integer: { - auto integer = type.cast(); - if (integer.isSigned()) - os << 's'; - else if (integer.isUnsigned()) - os << 'u'; - os << 'i' << integer.getWidth(); - return; - } - case Type::Kind::Function: { - auto func = type.cast(); - os << '('; - interleaveComma(func.getInputs(), [&](Type type) { printType(type); }); - os << ") -> "; - auto results = func.getResults(); - if (results.size() == 1 && !results[0].isa()) - os << results[0]; - else { - os << '('; - interleaveComma(results, [&](Type type) { printType(type); }); - os << ')'; - } - return; - } - case StandardTypes::Vector: { - auto v = type.cast(); - os << "vector<"; - for (auto dim : v.getShape()) - os << dim << 'x'; - os << v.getElementType() << '>'; - return; - } - case StandardTypes::RankedTensor: { - auto v = type.cast(); - os << "tensor<"; - for (auto dim : v.getShape()) { - if (dim < 0) - os << '?'; - else - os << dim; - os << 'x'; - } - os << v.getElementType() << '>'; - return; - } - case StandardTypes::UnrankedTensor: { - auto v = type.cast(); - os << "tensor<*x"; - printType(v.getElementType()); - os << '>'; - return; - } - case StandardTypes::MemRef: { - auto v = type.cast(); - os << "memref<"; - for (auto dim : v.getShape()) { - if (dim < 0) - os << '?'; - else - os << dim; - os << 'x'; - } - printType(v.getElementType()); - for (auto map : v.getAffineMaps()) { - os << ", "; - printAttribute(AffineMapAttr::get(map)); - } - // Only print the memory space if it is the non-default one. - if (v.getMemorySpace()) - os << ", " << v.getMemorySpace(); - os << '>'; - return; - } - case StandardTypes::UnrankedMemRef: { - auto v = type.cast(); - os << "memref<*x"; - printType(v.getElementType()); - os << '>'; - return; - } - case StandardTypes::Complex: - os << "complex<"; - printType(type.cast().getElementType()); - os << '>'; - return; - case StandardTypes::Tuple: { - auto tuple = type.cast(); - os << "tuple<"; - interleaveComma(tuple.getTypes(), [&](Type type) { printType(type); }); - os << '>'; - return; - } - case StandardTypes::None: - os << "none"; - return; - } + TypeSwitch(type) + .Case([&](OpaqueType opaqueTy) { + printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(), + opaqueTy.getTypeData()); + }) + .Case([&](Type) { os << "index"; }) + .Case([&](Type) { os << "bf16"; }) + .Case([&](Type) { os << "f16"; }) + .Case([&](Type) { os << "f32"; }) + .Case([&](Type) { os << "f64"; }) + .Case([&](IntegerType integerTy) { + if (integerTy.isSigned()) + os << 's'; + else if (integerTy.isUnsigned()) + os << 'u'; + os << 'i' << integerTy.getWidth(); + }) + .Case([&](FunctionType funcTy) { + os << '('; + interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); }); + os << ") -> "; + ArrayRef results = funcTy.getResults(); + if (results.size() == 1 && !results[0].isa()) { + os << results[0]; + } else { + os << '('; + interleaveComma(results, [&](Type ty) { printType(ty); }); + os << ')'; + } + }) + .Case([&](VectorType vectorTy) { + os << "vector<"; + for (int64_t dim : vectorTy.getShape()) + os << dim << 'x'; + os << vectorTy.getElementType() << '>'; + }) + .Case([&](RankedTensorType tensorTy) { + os << "tensor<"; + for (int64_t dim : tensorTy.getShape()) { + if (ShapedType::isDynamic(dim)) + os << '?'; + else + os << dim; + os << 'x'; + } + os << tensorTy.getElementType() << '>'; + }) + .Case([&](UnrankedTensorType tensorTy) { + os << "tensor<*x"; + printType(tensorTy.getElementType()); + os << '>'; + }) + .Case([&](MemRefType memrefTy) { + os << "memref<"; + for (int64_t dim : memrefTy.getShape()) { + if (ShapedType::isDynamic(dim)) + os << '?'; + else + os << dim; + os << 'x'; + } + printType(memrefTy.getElementType()); + for (auto map : memrefTy.getAffineMaps()) { + os << ", "; + printAttribute(AffineMapAttr::get(map)); + } + // Only print the memory space if it is the non-default one. + if (memrefTy.getMemorySpace()) + os << ", " << memrefTy.getMemorySpace(); + os << '>'; + }) + .Case([&](UnrankedMemRefType memrefTy) { + os << "memref<*x"; + printType(memrefTy.getElementType()); + os << '>'; + }) + .Case([&](ComplexType complexTy) { + os << "complex<"; + printType(complexTy.getElementType()); + os << '>'; + }) + .Case([&](TupleType tupleTy) { + os << "tuple<"; + interleaveComma(tupleTy.getTypes(), + [&](Type type) { printType(type); }); + os << '>'; + }) + .Case([&](Type) { os << "none"; }) + .Default([&](Type type) { return printDialectType(type); }); } void ModulePrinter::printOptionalAttrDict(ArrayRef attrs, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -747,21 +747,13 @@ for (unsigned i = 0, e = values.size(); i < e; ++i) { assert(eltType == values[i].getType() && "expected attribute value to have element type"); - - switch (eltType.getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: - case StandardTypes::F32: - case StandardTypes::F64: + if (eltType.isa()) intVal = values[i].cast().getValue().bitcastToAPInt(); - break; - case StandardTypes::Integer: - case StandardTypes::Index: + else if (eltType.isa()) intVal = values[i].cast().getValue(); - break; - default: + else llvm_unreachable("unexpected element type"); - } + assert(intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type"); writeBits(data.data(), i * storageBitWidth, intVal); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -268,25 +268,17 @@ } Attribute Builder::getZeroAttr(Type type) { - switch (type.getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: - case StandardTypes::F32: - case StandardTypes::F64: + if (type.isa()) return getFloatAttr(type, 0.0); - case StandardTypes::Integer: + if (auto integerType = type.dyn_cast()) return getIntegerAttr(type, APInt(type.cast().getWidth(), 0)); - case StandardTypes::Vector: - case StandardTypes::RankedTensor: { + if (type.isa()) { auto vtType = type.cast(); auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; return DenseElementsAttr::get(vtType, element); } - default: - break; - } return {}; } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -95,9 +95,10 @@ addAttributes(); - addTypes(); + addTypes(); // TODO: These operations should be moved to a different dialect when they // have been fully decoupled from the core. @@ -316,7 +317,10 @@ StorageUniquer typeUniquer; /// Cached Type Instances. - FloatType bf16Ty, f16Ty, f32Ty, f64Ty; + BFloat16Type bf16Ty; + Float16Type f16Ty; + Float32Type f32Ty; + Float64Type f64Ty; IndexType indexTy; IntegerType int1Ty, int8Ty, int16Ty, int32Ty, int64Ty, int128Ty; NoneType noneType; @@ -362,10 +366,10 @@ //// Types. /// Floating-point Types. - impl->bf16Ty = TypeUniquer::get(this, StandardTypes::BF16); - impl->f16Ty = TypeUniquer::get(this, StandardTypes::F16); - impl->f32Ty = TypeUniquer::get(this, StandardTypes::F32); - impl->f64Ty = TypeUniquer::get(this, StandardTypes::F64); + impl->bf16Ty = TypeUniquer::get(this, StandardTypes::BF16); + impl->f16Ty = TypeUniquer::get(this, StandardTypes::F16); + impl->f32Ty = TypeUniquer::get(this, StandardTypes::F32); + impl->f64Ty = TypeUniquer::get(this, StandardTypes::F64); /// Index Type. impl->indexTy = TypeUniquer::get(this, StandardTypes::Index); /// Integer Types. @@ -665,19 +669,17 @@ /// This should not be used directly. StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; } -FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) { - switch (kind) { - case StandardTypes::BF16: - return context->getImpl().bf16Ty; - case StandardTypes::F16: - return context->getImpl().f16Ty; - case StandardTypes::F32: - return context->getImpl().f32Ty; - case StandardTypes::F64: - return context->getImpl().f64Ty; - default: - llvm_unreachable("unexpected floating-point kind"); - } +BFloat16Type BFloat16Type::get(MLIRContext *context) { + return context->getImpl().bf16Ty; +} +Float16Type Float16Type::get(MLIRContext *context) { + return context->getImpl().f16Ty; +} +Float32Type Float32Type::get(MLIRContext *context) { + return context->getImpl().f32Ty; +} +Float64Type Float64Type::get(MLIRContext *context) { + return context->getImpl().f64Ty; } /// Get an instance of the IndexType. diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -22,10 +22,10 @@ // Type //===----------------------------------------------------------------------===// -bool Type::isBF16() { return getKind() == StandardTypes::BF16; } -bool Type::isF16() { return getKind() == StandardTypes::F16; } -bool Type::isF32() { return getKind() == StandardTypes::F32; } -bool Type::isF64() { return getKind() == StandardTypes::F64; } +bool Type::isBF16() { return isa(); } +bool Type::isF16() { return isa(); } +bool Type::isF32() { return isa(); } +bool Type::isF64() { return isa(); } bool Type::isIndex() { return isa(); } @@ -90,6 +90,13 @@ bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); } +unsigned Type::getIntOrFloatBitWidth() { + assert(isIntOrFloat() && "only integers and floats have a bitwidth"); + if (auto intType = dyn_cast()) + return intType.getWidth(); + return cast().getWidth(); +} + //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// @@ -142,39 +149,28 @@ //===----------------------------------------------------------------------===// unsigned FloatType::getWidth() { - switch (getKind()) { - case StandardTypes::BF16: - case StandardTypes::F16: + if (isa()) return 16; - case StandardTypes::F32: + if (isa()) return 32; - case StandardTypes::F64: + if (isa()) return 64; - default: - llvm_unreachable("unexpected type"); - } + llvm_unreachable("unexpected float type"); } /// Returns the floating semantics for the given type. const llvm::fltSemantics &FloatType::getFloatSemantics() { - if (isBF16()) + if (isa()) return APFloat::BFloat(); - if (isF16()) + if (isa()) return APFloat::IEEEhalf(); - if (isF32()) + if (isa()) return APFloat::IEEEsingle(); - if (isF64()) + if (isa()) return APFloat::IEEEdouble(); llvm_unreachable("non-floating point type used"); } -unsigned Type::getIntOrFloatBitWidth() { - assert(isIntOrFloat() && "only integers and floats have a bitwidth"); - if (auto intType = dyn_cast()) - return intType.getWidth(); - return cast().getWidth(); -} - //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===//