diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -101,10 +101,7 @@ } /// Support for isa/cast. - static bool classof(Type type) { - return type.getKind() >= FIRST_NEW_LLVM_TYPE && - type.getKind() <= LAST_NEW_LLVM_TYPE; - } + static bool classof(Type type); LLVMDialect &getDialect(); diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -71,10 +71,7 @@ int64_t storageTypeMax); /// Support method to enable LLVM-style type casting. - static bool classof(Type type) { - return type.getKind() >= Type::FIRST_QUANTIZATION_TYPE && - type.getKind() <= QuantizationTypes::LAST_USED_QUANTIZATION_TYPE; - } + static bool classof(Type type); /// Gets the minimum possible stored by a storageType. storageTypeMin must /// be greater than or equal to this value. diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -294,13 +294,7 @@ int64_t getSizeInBits() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type) { - return type.getKind() == StandardTypes::Vector || - type.getKind() == StandardTypes::RankedTensor || - type.getKind() == StandardTypes::UnrankedTensor || - type.getKind() == StandardTypes::UnrankedMemRef || - type.getKind() == StandardTypes::MemRef; - } + static bool classof(Type type); /// Whether the given dimension size indicates a dynamic dimension. static constexpr bool isDynamic(int64_t dSize) { @@ -358,20 +352,10 @@ using ShapedType::ShapedType; /// Return true if the specified element type is ok in a tensor. - static bool isValidElementType(Type type) { - // Note: Non standard/builtin types are allowed to exist within tensor - // types. Dialects are expected to verify that tensor types have a valid - // element type within that dialect. - return type.isa() || - (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); - } + static bool isValidElementType(Type type); /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type) { - return type.getKind() == StandardTypes::RankedTensor || - type.getKind() == StandardTypes::UnrankedTensor; - } + static bool classof(Type type); }; //===----------------------------------------------------------------------===// @@ -443,10 +427,7 @@ using ShapedType::ShapedType; /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type) { - return type.getKind() == StandardTypes::MemRef || - type.getKind() == StandardTypes::UnrankedMemRef; - } + static bool classof(Type type); }; //===----------------------------------------------------------------------===// @@ -629,6 +610,23 @@ } }; +//===----------------------------------------------------------------------===// +// Deferred Method Definitions +//===----------------------------------------------------------------------===// + +inline bool BaseMemRefType::classof(Type type) { + return type.isa(); +} + +inline bool ShapedType::classof(Type type) { + return type.isa(); +} + +inline bool TensorType::classof(Type type) { + return type.isa(); +} + //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -27,6 +27,10 @@ // LLVMType. //===----------------------------------------------------------------------===// +bool LLVMType::classof(Type type) { + return llvm::isa(type.getDialect()); +} + LLVMDialect &LLVMType::getDialect() { return static_cast(Type::getDialect()); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -55,11 +55,5 @@ void mlir::linalg::LinalgDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - default: - llvm_unreachable("Unhandled Linalg type"); - case LinalgTypes::Range: - print(type.cast(), os); - break; - } + print(type.cast(), os); } diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Quant/QuantTypes.h" #include "TypeDetail.h" +#include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/StandardTypes.h" @@ -23,6 +24,10 @@ return static_cast(impl)->flags; } +bool QuantizedType::classof(Type type) { + return llvm::isa(type.getDialect()); +} + LogicalResult QuantizedType::verifyConstructionInvariants( Location loc, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -365,18 +365,12 @@ /// Print a type registered to this dialect. void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - default: + if (auto anyType = type.dyn_cast()) + printAnyQuantizedType(anyType, os); + else if (auto uniformType = type.dyn_cast()) + printUniformQuantizedType(uniformType, os); + else if (auto perAxisType = type.dyn_cast()) + printUniformQuantizedPerAxisType(perAxisType, os); + else llvm_unreachable("Unhandled quantized type"); - case QuantizationTypes::Any: - printAnyQuantizedType(type.cast(), os); - break; - case QuantizationTypes::UniformQuantized: - printUniformQuantizedType(type.cast(), os); - break; - case QuantizationTypes::UniformQuantizedPerAxis: - printUniformQuantizedPerAxisType(type.cast(), - os); - break; - } } diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp --- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp @@ -19,48 +19,33 @@ const ExpressedToQuantizedConverter ExpressedToQuantizedConverter::forInputType(Type inputType) { - switch (inputType.getKind()) { - default: - if (isQuantizablePrimitiveType(inputType)) { - // Supported primitive type (which just is the expressed type). - return ExpressedToQuantizedConverter{inputType, inputType}; - } - // Unsupported. - return ExpressedToQuantizedConverter{inputType, nullptr}; - case StandardTypes::RankedTensor: - case StandardTypes::UnrankedTensor: - case StandardTypes::Vector: { + if (inputType.isa()) { Type elementType = inputType.cast().getElementType(); - if (!isQuantizablePrimitiveType(elementType)) { - // Unsupported. + if (!isQuantizablePrimitiveType(elementType)) return ExpressedToQuantizedConverter{inputType, nullptr}; - } - return ExpressedToQuantizedConverter{ - inputType, inputType.cast().getElementType()}; - } + return ExpressedToQuantizedConverter{inputType, elementType}; } + // Supported primitive type (which just is the expressed type). + if (isQuantizablePrimitiveType(inputType)) + return ExpressedToQuantizedConverter{inputType, inputType}; + // Unsupported. + return ExpressedToQuantizedConverter{inputType, nullptr}; } Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const { assert(expressedType && "convert() on unsupported conversion"); - - switch (inputType.getKind()) { - default: - if (elementalType.getExpressedType() == expressedType) { - // If the expressed types match, just use the new elemental type. - return elementalType; - } - // Unsupported. - return nullptr; - case StandardTypes::RankedTensor: - return RankedTensorType::get(inputType.cast().getShape(), - elementalType); - case StandardTypes::UnrankedTensor: + if (auto tensorType = inputType.dyn_cast()) + return RankedTensorType::get(tensorType.getShape(), elementalType); + if (auto tensorType = inputType.dyn_cast()) return UnrankedTensorType::get(elementalType); - case StandardTypes::Vector: - return VectorType::get(inputType.cast().getShape(), - elementalType); - } + if (auto vectorType = inputType.dyn_cast()) + return VectorType::get(vectorType.getShape(), elementalType); + + // If the expressed types match, just use the new elemental type. + if (elementalType.getExpressedType() == expressedType) + return elementalType; + // Unsupported. + return nullptr; } ElementsAttr diff --git a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp --- a/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/LayoutUtils.cpp @@ -78,20 +78,17 @@ size = alignment; return type; } - - switch (type.getKind()) { - case spirv::TypeKind::Struct: - return decorateType(type.cast(), size, alignment); - case spirv::TypeKind::Array: - return decorateType(type.cast(), size, alignment); - case StandardTypes::Vector: - return decorateType(type.cast(), size, alignment); - case spirv::TypeKind::RuntimeArray: + if (auto structType = type.dyn_cast()) + return decorateType(structType, size, alignment); + if (auto arrayType = type.dyn_cast()) + return decorateType(arrayType, size, alignment); + if (auto vectorType = type.dyn_cast()) + return decorateType(vectorType, size, alignment); + if (auto arrayType = type.dyn_cast()) { size = std::numeric_limits().max(); - return decorateType(type.cast(), alignment); - default: - llvm_unreachable("unhandled SPIR-V type"); + return decorateType(arrayType, alignment); } + llvm_unreachable("unhandled SPIR-V type"); } Type VulkanLayoutUtils::decorateType(VectorType vectorType, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/raw_ostream.h" namespace mlir { @@ -727,31 +728,11 @@ } void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - case TypeKind::Array: - print(type.cast(), os); - return; - case TypeKind::CooperativeMatrix: - print(type.cast(), os); - return; - case TypeKind::Pointer: - print(type.cast(), os); - return; - case TypeKind::RuntimeArray: - print(type.cast(), os); - return; - case TypeKind::Image: - print(type.cast(), os); - return; - case TypeKind::Struct: - print(type.cast(), os); - return; - case TypeKind::Matrix: - print(type.cast(), os); - return; - default: - llvm_unreachable("unhandled SPIR-V type"); - } + TypeSwitch(type) + .Case( + [&](auto type) { print(type, os); }) + .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); }); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1534,8 +1534,7 @@ if (!type.isa()) return false; - if (type.getKind() >= Type::FIRST_SPIRV_TYPE && - type.getKind() <= spirv::TypeKind::LAST_SPIRV_TYPE) { + if (isa(type.getDialect())) { // TODO: support constant struct return type.isa(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::spirv; @@ -163,18 +164,11 @@ //===----------------------------------------------------------------------===// bool CompositeType::classof(Type type) { - switch (type.getKind()) { - case TypeKind::Array: - case TypeKind::CooperativeMatrix: - case TypeKind::Matrix: - case TypeKind::RuntimeArray: - case TypeKind::Struct: - return true; - case StandardTypes::Vector: - return isValid(type.cast()); - default: - return false; - } + if (auto vectorType = type.dyn_cast()) + return isValid(vectorType); + return type + .isa(); } bool CompositeType::isValid(VectorType type) { @@ -183,22 +177,14 @@ } Type CompositeType::getElementType(unsigned index) const { - switch (getKind()) { - case spirv::TypeKind::Array: - return cast().getElementType(); - case spirv::TypeKind::CooperativeMatrix: - return cast().getElementType(); - case spirv::TypeKind::Matrix: - return cast().getColumnType(); - case spirv::TypeKind::RuntimeArray: - return cast().getElementType(); - case spirv::TypeKind::Struct: - return cast().getElementType(index); - case StandardTypes::Vector: - return cast().getElementType(); - default: - llvm_unreachable("invalid composite type"); - } + return TypeSwitch(*this) + .Case( + [](auto type) { return type.getElementType(); }) + .Case([](MatrixType type) { return type.getColumnType(); }) + .Case( + [index](StructType type) { return type.getElementType(index); }) + .Default( + [](Type) -> Type { llvm_unreachable("invalid composite type"); }); } unsigned CompositeType::getNumElements() const { diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp --- a/mlir/lib/Dialect/Traits.cpp +++ b/mlir/lib/Dialect/Traits.cpp @@ -123,16 +123,16 @@ // Returns the type kind if the given type is a vector or ranked tensor type. // Returns llvm::None otherwise. - auto getCompositeTypeKind = [](Type type) -> Optional { + auto getCompositeTypeKind = [](Type type) -> Optional { if (type.isa()) - return static_cast(type.getKind()); + return type.getTypeID(); return llvm::None; }; // Make sure the composite type, if has, is consistent. - auto compositeKind1 = getCompositeTypeKind(type1); - auto compositeKind2 = getCompositeTypeKind(type2); - Optional resultCompositeKind; + Optional compositeKind1 = getCompositeTypeKind(type1); + Optional compositeKind2 = getCompositeTypeKind(type2); + Optional resultCompositeKind; if (compositeKind1 && compositeKind2) { // Disallow mixing vector and tensor. @@ -151,9 +151,9 @@ return {}; // Compose the final broadcasted type - if (resultCompositeKind == StandardTypes::Vector) + if (resultCompositeKind == VectorType::getTypeID()) return VectorType::get(resultShape, elementType); - if (resultCompositeKind == StandardTypes::RankedTensor) + if (resultCompositeKind == RankedTensorType::getTypeID()) return RankedTensorType::get(resultShape, elementType); return elementType; } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Twine.h" @@ -244,16 +245,11 @@ } ArrayRef ShapedType::getShape() const { - switch (getKind()) { - case StandardTypes::Vector: - return cast().getShape(); - case StandardTypes::RankedTensor: - return cast().getShape(); - case StandardTypes::MemRef: - return cast().getShape(); - default: - llvm_unreachable("not a ShapedType or not ranked"); - } + if (auto vectorType = dyn_cast()) + return vectorType.getShape(); + if (auto tensorType = dyn_cast()) + return tensorType.getShape(); + return cast().getShape(); } int64_t ShapedType::getNumDynamicDims() const { @@ -305,13 +301,23 @@ // Check if "elementType" can be an element type of a tensor. Emit errors if // location is not nullptr. Returns failure if check failed. -static inline LogicalResult checkTensorElementType(Location location, - Type elementType) { +static LogicalResult checkTensorElementType(Location location, + Type elementType) { if (!TensorType::isValidElementType(elementType)) return emitError(location, "invalid tensor element type"); return success(); } +/// Return true if the specified element type is ok in a tensor. +bool TensorType::isValidElementType(Type type) { + // Note: Non standard/builtin types are allowed to exist within tensor + // types. Dialects are expected to verify that tensor types have a valid + // element type within that dialect. + return type.isa() || + !type.getDialect().getNamespace().empty(); +} + //===----------------------------------------------------------------------===// // RankedTensorType //===----------------------------------------------------------------------===//