diff --git a/flang/include/flang/Optimizer/Dialect/FIRAttr.h b/flang/include/flang/Optimizer/Dialect/FIRAttr.h --- a/flang/include/flang/Optimizer/Dialect/FIRAttr.h +++ b/flang/include/flang/Optimizer/Dialect/FIRAttr.h @@ -25,17 +25,6 @@ struct TypeAttributeStorage; } // namespace detail -enum AttributeKind { - FIR_ATTR = mlir::Attribute::FIRST_FIR_ATTR, - FIR_EXACTTYPE, // instance_of, precise type relation - FIR_SUBCLASS, // subsumed_by, is-a (subclass) relation - FIR_POINT, - FIR_CLOSEDCLOSED_INTERVAL, - FIR_OPENCLOSED_INTERVAL, - FIR_CLOSEDOPEN_INTERVAL, - FIR_REAL_ATTR, -}; - class ExactTypeAttr : public mlir::Attribute::AttrBase { @@ -47,8 +36,6 @@ static ExactTypeAttr get(mlir::Type value); mlir::Type getType() const; - - static constexpr unsigned getId() { return AttributeKind::FIR_EXACTTYPE; } }; class SubclassAttr @@ -62,8 +49,6 @@ static SubclassAttr get(mlir::Type value); mlir::Type getType() const; - - static constexpr unsigned getId() { return AttributeKind::FIR_SUBCLASS; } }; // Attributes for building SELECT CASE multiway branches @@ -80,9 +65,6 @@ static constexpr llvm::StringRef getAttrName() { return "interval"; } static ClosedIntervalAttr get(mlir::MLIRContext *ctxt); - static constexpr unsigned getId() { - return AttributeKind::FIR_CLOSEDCLOSED_INTERVAL; - } }; /// An upper bound is an open interval (including the bound value) as given as @@ -97,9 +79,6 @@ static constexpr llvm::StringRef getAttrName() { return "upper"; } static UpperBoundAttr get(mlir::MLIRContext *ctxt); - static constexpr unsigned getId() { - return AttributeKind::FIR_OPENCLOSED_INTERVAL; - } }; /// A lower bound is an open interval (including the bound value) as given as @@ -114,9 +93,6 @@ static constexpr llvm::StringRef getAttrName() { return "lower"; } static LowerBoundAttr get(mlir::MLIRContext *ctxt); - static constexpr unsigned getId() { - return AttributeKind::FIR_CLOSEDOPEN_INTERVAL; - } }; /// A pointer interval is a closed interval as given as an ssa-value. 
The @@ -131,7 +107,6 @@ static constexpr llvm::StringRef getAttrName() { return "point"; } static PointIntervalAttr get(mlir::MLIRContext *ctxt); - static constexpr unsigned getId() { return AttributeKind::FIR_POINT; } }; /// A real attribute is used to workaround MLIR's default parsing of a real @@ -150,8 +125,6 @@ int getFKind() const; llvm::APFloat getValue() const; - - static constexpr unsigned getId() { return AttributeKind::FIR_REAL_ATTR; } }; mlir::Attribute parseFirAttribute(FIROpsDialect *dialect, diff --git a/flang/include/flang/Optimizer/Dialect/FIRType.h b/flang/include/flang/Optimizer/Dialect/FIRType.h --- a/flang/include/flang/Optimizer/Dialect/FIRType.h +++ b/flang/include/flang/Optimizer/Dialect/FIRType.h @@ -54,29 +54,6 @@ struct TypeDescTypeStorage; } // namespace detail -/// Integral identifier for all the types comprising the FIR type system -enum TypeKind { - // The enum starts at the range reserved for this dialect. - FIR_TYPE = mlir::Type::FIRST_FIR_TYPE, - FIR_BOX, // (static) descriptor - FIR_BOXCHAR, // CHARACTER pointer and length - FIR_BOXPROC, // procedure with host association - FIR_CHARACTER, // intrinsic type - FIR_COMPLEX, // intrinsic type - FIR_DERIVED, // derived - FIR_DIMS, - FIR_FIELD, - FIR_HEAP, - FIR_INT, // intrinsic type - FIR_LEN, - FIR_LOGICAL, // intrinsic type - FIR_POINTER, // POINTER attr - FIR_REAL, // intrinsic type - FIR_REFERENCE, - FIR_SEQUENCE, // DIMENSION attr - FIR_TYPEDESC, -}; - // These isa_ routines follow the precedent of llvm::isa_or_null<> /// Is `t` any of the FIR dialect types? @@ -111,12 +88,6 @@ /// not a memory reference type, then returns a null `Type`. mlir::Type dyn_cast_ptrEleTy(mlir::Type t); -/// Boilerplate mixin template -template -struct IntrinsicTypeMixin { - static constexpr unsigned getId() { return Id; } -}; - // Intrinsic types /// Model of the Fortran CHARACTER intrinsic type, including the KIND type @@ -124,8 +95,7 @@ /// is thus the type of a single character value. 
class CharacterType : public mlir::Type::TypeBase, - public IntrinsicTypeMixin { + detail::CharacterTypeStorage> { public: using Base::Base; static CharacterType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -136,8 +106,7 @@ /// parameter. COMPLEX is a floating point type with a real and imaginary /// member. class CplxType : public mlir::Type::TypeBase, - public IntrinsicTypeMixin { + detail::CplxTypeStorage> { public: using Base::Base; static CplxType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -151,8 +120,7 @@ /// Model of a Fortran INTEGER intrinsic type, including the KIND type /// parameter. class IntType - : public mlir::Type::TypeBase, - public IntrinsicTypeMixin { + : public mlir::Type::TypeBase { public: using Base::Base; static IntType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -163,8 +131,7 @@ /// parameter. class LogicalType : public mlir::Type::TypeBase, - public IntrinsicTypeMixin { + detail::LogicalTypeStorage> { public: using Base::Base; static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -174,8 +141,7 @@ /// Model of a Fortran REAL (and DOUBLE PRECISION) intrinsic type, including the /// KIND type parameter. 
class RealType : public mlir::Type::TypeBase, - public IntrinsicTypeMixin { + detail::RealTypeStorage> { public: using Base::Base; static RealType get(mlir::MLIRContext *ctxt, KindTy kind); @@ -400,7 +366,6 @@ static RecordType get(mlir::MLIRContext *ctxt, llvm::StringRef name); void finalize(llvm::ArrayRef lenPList, llvm::ArrayRef typeList); - static constexpr unsigned getId() { return TypeKind::FIR_DERIVED; } detail::RecordTypeStorage const *uniqueKey() const; diff --git a/flang/lib/Optimizer/Dialect/FIRAttr.cpp b/flang/lib/Optimizer/Dialect/FIRAttr.cpp --- a/flang/lib/Optimizer/Dialect/FIRAttr.cpp +++ b/flang/lib/Optimizer/Dialect/FIRAttr.cpp @@ -74,13 +74,13 @@ } // namespace detail ExactTypeAttr ExactTypeAttr::get(mlir::Type value) { - return Base::get(value.getContext(), FIR_EXACTTYPE, value); + return Base::get(value.getContext(), value); } mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); } SubclassAttr SubclassAttr::get(mlir::Type value) { - return Base::get(value.getContext(), FIR_SUBCLASS, value); + return Base::get(value.getContext(), value); } mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); } @@ -88,26 +88,26 @@ using AttributeUniquer = mlir::detail::AttributeUniquer; ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) { - return AttributeUniquer::get(ctxt, getId()); + return AttributeUniquer::get(ctxt); } UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) { - return AttributeUniquer::get(ctxt, getId()); + return AttributeUniquer::get(ctxt); } LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) { - return AttributeUniquer::get(ctxt, getId()); + return AttributeUniquer::get(ctxt); } PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) { - return AttributeUniquer::get(ctxt, getId()); + return AttributeUniquer::get(ctxt); } // RealAttr RealAttr RealAttr::get(mlir::MLIRContext *ctxt, const RealAttr::ValueType &key) { - return Base::get(ctxt, getId(), key); + return 
Base::get(ctxt, key); } int RealAttr::getFKind() const { return getImpl()->getFKind(); } diff --git a/flang/lib/Optimizer/Dialect/FIRType.cpp b/flang/lib/Optimizer/Dialect/FIRType.cpp --- a/flang/lib/Optimizer/Dialect/FIRType.cpp +++ b/flang/lib/Optimizer/Dialect/FIRType.cpp @@ -824,13 +824,11 @@ } bool isa_fir_type(mlir::Type t) { - return inbounds(t.getKind(), mlir::Type::FIRST_FIR_TYPE, - mlir::Type::LAST_FIR_TYPE); + return llvm::isa(t.getDialect()); } bool isa_std_type(mlir::Type t) { - return inbounds(t.getKind(), mlir::Type::FIRST_STANDARD_TYPE, - mlir::Type::LAST_STANDARD_TYPE); + return t.getDialect().getNamespace().empty(); } bool isa_fir_or_std_type(mlir::Type t) { @@ -868,7 +866,7 @@ // CHARACTER CharacterType fir::CharacterType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_CHARACTER, kind); + return Base::get(ctxt, kind); } int fir::CharacterType::getFKind() const { return getImpl()->getFKind(); } @@ -876,7 +874,7 @@ // Dims DimsType fir::DimsType::get(mlir::MLIRContext *ctxt, unsigned rank) { - return Base::get(ctxt, FIR_DIMS, rank); + return Base::get(ctxt, rank); } unsigned fir::DimsType::getRank() const { return getImpl()->getRank(); } @@ -884,19 +882,19 @@ // Field FieldType fir::FieldType::get(mlir::MLIRContext *ctxt) { - return Base::get(ctxt, FIR_FIELD, 0); + return Base::get(ctxt, 0); } // Len LenType fir::LenType::get(mlir::MLIRContext *ctxt) { - return Base::get(ctxt, FIR_LEN, 0); + return Base::get(ctxt, 0); } // LOGICAL LogicalType fir::LogicalType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_LOGICAL, kind); + return Base::get(ctxt, kind); } int fir::LogicalType::getFKind() const { return getImpl()->getFKind(); } @@ -904,7 +902,7 @@ // INTEGER IntType fir::IntType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_INT, kind); + return Base::get(ctxt, kind); } int fir::IntType::getFKind() const { return getImpl()->getFKind(); } @@ -912,7 +910,7 @@ // COMPLEX 
CplxType fir::CplxType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_COMPLEX, kind); + return Base::get(ctxt, kind); } mlir::Type fir::CplxType::getElementType() const { @@ -924,7 +922,7 @@ // REAL RealType fir::RealType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_REAL, kind); + return Base::get(ctxt, kind); } int fir::RealType::getFKind() const { return getImpl()->getFKind(); } @@ -932,7 +930,7 @@ // Box BoxType fir::BoxType::get(mlir::Type elementType, mlir::AffineMapAttr map) { - return Base::get(elementType.getContext(), FIR_BOX, elementType, map); + return Base::get(elementType.getContext(), elementType, map); } mlir::Type fir::BoxType::getEleTy() const { @@ -953,7 +951,7 @@ // BoxChar BoxCharType fir::BoxCharType::get(mlir::MLIRContext *ctxt, KindTy kind) { - return Base::get(ctxt, FIR_BOXCHAR, kind); + return Base::get(ctxt, kind); } CharacterType fir::BoxCharType::getEleTy() const { @@ -963,7 +961,7 @@ // BoxProc BoxProcType fir::BoxProcType::get(mlir::Type elementType) { - return Base::get(elementType.getContext(), FIR_BOXPROC, elementType); + return Base::get(elementType.getContext(), elementType); } mlir::Type fir::BoxProcType::getEleTy() const { @@ -984,7 +982,7 @@ // Reference ReferenceType fir::ReferenceType::get(mlir::Type elementType) { - return Base::get(elementType.getContext(), FIR_REFERENCE, elementType); + return Base::get(elementType.getContext(), elementType); } mlir::Type fir::ReferenceType::getEleTy() const { @@ -1005,7 +1003,7 @@ PointerType fir::PointerType::get(mlir::Type elementType) { assert(singleIndirectionLevel(elementType) && "invalid element type"); - return Base::get(elementType.getContext(), FIR_POINTER, elementType); + return Base::get(elementType.getContext(), elementType); } mlir::Type fir::PointerType::getEleTy() const { @@ -1033,7 +1031,7 @@ HeapType fir::HeapType::get(mlir::Type elementType) { assert(singleIndirectionLevel(elementType) && "invalid element type"); - 
return Base::get(elementType.getContext(), FIR_HEAP, elementType); + return Base::get(elementType.getContext(), elementType); } mlir::Type fir::HeapType::getEleTy() const { @@ -1054,7 +1052,7 @@ SequenceType fir::SequenceType::get(const Shape &shape, mlir::Type elementType, mlir::AffineMapAttr map) { auto *ctxt = elementType.getContext(); - return Base::get(ctxt, FIR_SEQUENCE, shape, elementType, map); + return Base::get(ctxt, shape, elementType, map); } mlir::Type fir::SequenceType::getEleTy() const { @@ -1136,7 +1134,7 @@ /// This type captures a Fortran "derived type" RecordType fir::RecordType::get(mlir::MLIRContext *ctxt, llvm::StringRef name) { - return Base::get(ctxt, FIR_DERIVED, name); + return Base::get(ctxt, name); } void fir::RecordType::finalize(llvm::ArrayRef lenPList, @@ -1179,7 +1177,7 @@ TypeDescType fir::TypeDescType::get(mlir::Type ofType) { assert(!ofType.isa()); - return Base::get(ofType.getContext(), FIR_TYPEDESC, ofType); + return Base::get(ofType.getContext(), ofType); } mlir::Type fir::TypeDescType::getOfTy() const { return getImpl()->getOfType(); } @@ -1222,9 +1220,7 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p) { auto &os = p.getStream(); - switch (ty.getKind()) { - case fir::FIR_BOX: { - auto type = ty.cast(); + if (auto type = ty.dyn_cast()) { os << "box<"; p.printType(type.getEleTy()); if (auto map = type.getLayoutMap()) { @@ -1232,24 +1228,28 @@ p.printAttribute(map); } os << '>'; - } break; - case fir::FIR_BOXCHAR: { - auto type = ty.cast().getEleTy(); - os << "boxchar<" << type.cast().getFKind() << '>'; - } break; - case fir::FIR_BOXPROC: + return; + } + if (auto type = ty.dyn_cast()) { + os << "boxchar<" << type.getEleTy().cast().getFKind() + << '>'; + return; + } + if (auto type = ty.dyn_cast()) { os << "boxproc<"; - p.printType(ty.cast().getEleTy()); + p.printType(type.getEleTy()); os << '>'; - break; - case fir::FIR_CHARACTER: // intrinsic - os << "char<" << ty.cast().getFKind() << '>'; - 
break; - case fir::FIR_COMPLEX: // intrinsic - os << "complex<" << ty.cast().getFKind() << '>'; - break; - case fir::FIR_DERIVED: { // derived - auto type = ty.cast(); + return; + } + if (auto type = ty.dyn_cast()) { + os << "char<" << type.getFKind() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { + os << "complex<" << type.getFKind() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { os << "type<" << type.getName(); if (!recordTypeVisited.count(type.uniqueKey())) { recordTypeVisited.insert(type.uniqueKey()); @@ -1274,43 +1274,52 @@ recordTypeVisited.erase(type.uniqueKey()); } os << '>'; - } break; - case fir::FIR_DIMS: - os << "dims<" << ty.cast().getRank() << '>'; - break; - case fir::FIR_FIELD: + return; + } + if (auto type = ty.dyn_cast()) { + os << "dims<" << type.getRank() << '>'; + return; + } + if (ty.isa()) { os << "field"; - break; - case fir::FIR_HEAP: + return; + } + if (auto type = ty.dyn_cast()) { os << "heap<"; - p.printType(ty.cast().getEleTy()); + p.printType(type.getEleTy()); os << '>'; - break; - case fir::FIR_INT: // intrinsic - os << "int<" << ty.cast().getFKind() << '>'; - break; - case fir::FIR_LEN: + return; + } + if (auto type = ty.dyn_cast()) { + os << "int<" << type.getFKind() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { os << "len"; - break; - case fir::FIR_LOGICAL: // intrinsic - os << "logical<" << ty.cast().getFKind() << '>'; - break; - case fir::FIR_POINTER: + return; + } + if (auto type = ty.dyn_cast()) { + os << "logical<" << type.getFKind() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { os << "ptr<"; - p.printType(ty.cast().getEleTy()); + p.printType(type.getEleTy()); os << '>'; - break; - case fir::FIR_REAL: // intrinsic - os << "real<" << ty.cast().getFKind() << '>'; - break; - case fir::FIR_REFERENCE: + return; + } + if (auto type = ty.dyn_cast()) { + os << "real<" << type.getFKind() << '>'; + return; + } + if (auto type = ty.dyn_cast()) { os << "ref<"; - p.printType(ty.cast().getEleTy()); 
+ p.printType(type.getEleTy()); os << '>'; - break; - case fir::FIR_SEQUENCE: { + return; + } + if (auto type = ty.dyn_cast()) { os << "array"; - auto type = ty.cast(); auto shape = type.getShape(); if (shape.size()) { printBounds(os, shape); @@ -1323,11 +1332,12 @@ map.print(os); } os << '>'; - } break; - case fir::FIR_TYPEDESC: + return; + } + if (auto type = ty.dyn_cast()) { os << "tdesc<"; - p.printType(ty.cast().getOfTy()); + p.printType(type.getOfTy()); os << '>'; - break; + return; } } diff --git a/mlir/docs/Tutorials/Toy/Ch-7.md b/mlir/docs/Tutorials/Toy/Ch-7.md --- a/mlir/docs/Tutorials/Toy/Ch-7.md +++ b/mlir/docs/Tutorials/Toy/Ch-7.md @@ -190,11 +190,10 @@ assert(!elementTypes.empty() && "expected at least 1 element type"); // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance - // of this type. The first two parameters are the context to unique in and - // the kind of the type. The parameters after the type kind are forwarded to - // the storage instance. + // of this type. The first parameter is the context to unique in. The + // parameters after the context are forwarded to the storage instance. mlir::MLIRContext *ctx = elementTypes.front().getContext(); - return Base::get(ctx, ToyTypes::Struct, elementTypes); + return Base::get(ctx, elementTypes); } /// Returns the element types of this struct type. diff --git a/mlir/examples/toy/Ch7/include/toy/Dialect.h b/mlir/examples/toy/Ch7/include/toy/Dialect.h --- a/mlir/examples/toy/Ch7/include/toy/Dialect.h +++ b/mlir/examples/toy/Ch7/include/toy/Dialect.h @@ -63,13 +63,6 @@ // Toy Types //===----------------------------------------------------------------------===// -/// Create a local enumeration with all of the types that are defined by Toy. -namespace ToyTypes { -enum Types { - Struct = mlir::Type::FIRST_TOY_TYPE, -}; -} // end namespace ToyTypes - /// This class defines the Toy struct type. It represents a collection of /// element types.
All derived types in MLIR must inherit from the CRTP class /// 'Type::TypeBase'. It takes as template parameters the concrete type diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -474,11 +474,10 @@ assert(!elementTypes.empty() && "expected at least 1 element type"); // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance - // of this type. The first two parameters are the context to unique in and the - // kind of the type. The parameters after the type kind are forwarded to the - // storage instance. + // of this type. The first parameter is the context to unique in. The + // parameters after the context are forwarded to the storage instance. mlir::MLIRContext *ctx = elementTypes.front().getContext(); - return Base::get(ctx, ToyTypes::Struct, elementTypes); + return Base::get(ctx, elementTypes); } /// Returns the element types of this struct type. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -64,34 +64,6 @@ /// structs, the entire type is the identifier) and are thread-safe. class LLVMType : public Type { public: - enum Kind { - // Keep non-parametric types contiguous in the enum. - VoidType = FIRST_LLVM_TYPE + 1, - HalfType, - BFloatType, - FloatType, - DoubleType, - FP128Type, - X86FP80Type, - PPCFP128Type, - X86MMXType, - LabelType, - TokenType, - MetadataType, - // End of non-parametric types. - FunctionType, - IntegerType, - PointerType, - FixedVectorType, - ScalableVectorType, - ArrayType, - StructType, - FIRST_NEW_LLVM_TYPE = VoidType, - LAST_NEW_LLVM_TYPE = StructType, - FIRST_TRIVIAL_TYPE = VoidType, - LAST_TRIVIAL_TYPE = MetadataType - }; - /// Inherit base constructors.
using Type::Type; @@ -256,27 +228,24 @@ //===----------------------------------------------------------------------===// // Batch-define trivial types. -#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName, Kind) \ +#define DEFINE_TRIVIAL_LLVM_TYPE(ClassName) \ class ClassName : public Type::TypeBase { \ public: \ using Base::Base; \ - static ClassName get(MLIRContext *context) { \ - return Base::get(context, Kind); \ - } \ } -DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, LLVMType::VoidType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType, LLVMType::HalfType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType, LLVMType::BFloatType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType, LLVMType::FloatType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType, LLVMType::DoubleType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type, LLVMType::FP128Type); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type, LLVMType::X86FP80Type); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, LLVMType::PPCFP128Type); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType, LLVMType::X86MMXType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, LLVMType::TokenType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, LLVMType::LabelType); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, LLVMType::MetadataType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMHalfType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMBFloatType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFloatType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMDoubleType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMFP128Type); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86FP80Type); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMX86MMXType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType); +DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType); #undef DEFINE_TRIVIAL_LLVM_TYPE diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h @@ -16,11 +16,6 @@ class 
MLIRContext; namespace linalg { -enum LinalgTypes { - Range = Type::FIRST_LINALG_TYPE, - LAST_USED_LINALG_TYPE = Range, -}; - #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" /// A RangeType represents a minimal range abstraction (min, max, step). @@ -36,11 +31,6 @@ public: // Used for generic hooks in TypeBase. using Base::Base; - /// Construction hook. - static RangeType get(MLIRContext *context) { - /// Custom, uniq'ed construction in the MLIRContext. - return Base::get(context, LinalgTypes::Range); - } }; } // namespace linalg diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -31,15 +31,6 @@ } // namespace detail -namespace QuantizationTypes { -enum Kind { - Any = Type::FIRST_QUANTIZATION_TYPE, - UniformQuantized, - UniformQuantizedPerAxis, - LAST_USED_QUANTIZATION_TYPE = UniformQuantizedPerAxis, -}; -} // namespace QuantizationTypes - /// Enumeration of bit-mapped flags related to quantized types. namespace QuantizationFlags { enum FlagValue { diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVAttributes.h @@ -32,15 +32,6 @@ struct VerCapExtAttributeStorage; } // namespace detail -/// SPIR-V dialect-specific attribute kinds. -namespace AttrKind { -enum Kind { - InterfaceVarABI = Attribute::FIRST_SPIRV_ATTR, /// Interface var ABI - TargetEnv, /// Target environment - VerCapExt, /// (version, extension, capability) triple -}; -} // namespace AttrKind - /// An attribute that specifies the information regarding the interface /// variable: descriptor set, binding, storage class. 
class InterfaceVarABIAttr diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -65,19 +65,6 @@ } // namespace detail -namespace TypeKind { -enum Kind { - Array = Type::FIRST_SPIRV_TYPE, - CooperativeMatrix, - Image, - Matrix, - Pointer, - RuntimeArray, - Struct, - LAST_SPIRV_TYPE = Struct, -}; -} - // Base SPIR-V type for providing availability queries. class SPIRVType : public Type { public: diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -29,56 +29,28 @@ /// Alias type for extent tensors. RankedTensorType getExtentTensorType(MLIRContext *ctx); -namespace ShapeTypes { -enum Kind { - Component = Type::FIRST_SHAPE_TYPE, - Element, - Shape, - Size, - ValueShape, - Witness, - LAST_SHAPE_TYPE = Witness -}; -} // namespace ShapeTypes - /// The component type corresponding to shape, element type and attribute. class ComponentType : public Type::TypeBase { public: using Base::Base; - - static ComponentType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::Component); - } }; /// The element type of the shaped type. class ElementType : public Type::TypeBase { public: using Base::Base; - - static ElementType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::Element); - } }; /// The shape descriptor type represents rank and dimension sizes. class ShapeType : public Type::TypeBase { public: using Base::Base; - - static ShapeType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::Shape); - } }; /// The type of a single dimension. 
class SizeType : public Type::TypeBase { public: using Base::Base; - - static SizeType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::Size); - } }; /// The ValueShape represents a (potentially unknown) runtime value and shape. @@ -86,10 +58,6 @@ : public Type::TypeBase { public: using Base::Base; - - static ValueShapeType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::ValueShape); - } }; /// The Witness represents a runtime constraint, to be used as shape related @@ -97,10 +65,6 @@ class WitnessType : public Type::TypeBase { public: using Base::Base; - - static WitnessType get(MLIRContext *context) { - return Base::get(context, ShapeTypes::Kind::Witness); - } }; #define GET_OP_CLASSES diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -137,15 +137,23 @@ // MLIRContext. This class manages all creation and uniquing of attributes. class AttributeUniquer { public: - /// Get an uniqued instance of attribute T. + /// Get an uniqued instance of a parametric attribute T. template - static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { + static typename std::enable_if_t< + !std::is_same::value, T> + get(MLIRContext *ctx, Args &&...args) { return ctx->getAttributeUniquer().get( - T::getTypeID(), [ctx](AttributeStorage *storage) { initializeAttributeStorage(storage, ctx, T::getTypeID()); }, - kind, std::forward(args)...); + T::getTypeID(), std::forward(args)...); + } + /// Get an uniqued instance of a singleton attribute T. + template + static typename std::enable_if_t< + std::is_same::value, T> + get(MLIRContext *ctx) { + return ctx->getAttributeUniquer().get(T::getTypeID()); } template @@ -156,6 +164,26 @@ std::forward(args)...); } + /// Register a parametric attribute instance T with the uniquer. 
+ template + static typename std::enable_if_t< + !std::is_same::value> + registerAttribute(MLIRContext *ctx) { + ctx->getAttributeUniquer() + .registerParametricStorageType(T::getTypeID()); + } + /// Register a singleton attribute instance T with the uniquer. + template + static typename std::enable_if_t< + std::is_same::value> + registerAttribute(MLIRContext *ctx) { + ctx->getAttributeUniquer() + .registerSingletonStorageType( + T::getTypeID(), [ctx](AttributeStorage *storage) { + initializeAttributeStorage(storage, ctx, T::getTypeID()); + }); + } + private: /// Initialize the given attribute storage instance. static void initializeAttributeStorage(AttributeStorage *storage, diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -54,14 +54,6 @@ /// passed by value. class Attribute { public: - /// Integer identifier for all the concrete attribute kinds. - enum Kind { - // Reserve attribute kinds for dialect specific extensions. -#define DEFINE_SYM_KIND_RANGE(Dialect) \ - FIRST_##Dialect##_ATTR, LAST_##Dialect##_ATTR = FIRST_##Dialect##_ATTR + 0xff, -#include "DialectSymbolRegistry.def" - }; - /// Utility class for implementing attributes. template class... Traits> @@ -94,9 +86,6 @@ // Support dyn_cast'ing Attribute to itself. static bool classof(Attribute) { return true; } - /// Return the classification for this attribute. - unsigned getKind() const { return impl->getKind(); } - /// Return a unique identifier for the concrete attribute type. This is used /// to support dynamic type casting. 
TypeID getTypeID() { return impl->getAbstractAttribute().getTypeID(); } @@ -173,54 +162,6 @@ friend InterfaceBase; }; -//===----------------------------------------------------------------------===// -// StandardAttributes -//===----------------------------------------------------------------------===// - -namespace StandardAttributes { -enum Kind { - AffineMap = Attribute::FIRST_STANDARD_ATTR, - Array, - Dictionary, - Float, - Integer, - IntegerSet, - Opaque, - String, - SymbolRef, - Type, - Unit, - - /// Elements Attributes. - DenseIntOrFPElements, - DenseStringElements, - OpaqueElements, - SparseElements, - FIRST_ELEMENTS_ATTR = DenseIntOrFPElements, - LAST_ELEMENTS_ATTR = SparseElements, - - /// Locations. - CallSiteLocation, - FileLineColLocation, - FusedLocation, - NameLocation, - OpaqueLocation, - UnknownLocation, - - // Represents a location as a 'void*' pointer to a front-end's opaque - // location information, which must live longer than the MLIR objects that - // refer to it. OpaqueLocation's are never serialized. - // - // TODO: OpaqueLocation, - - // Represents a value inlined through a function call. - // TODO: InlinedLocation, - - FIRST_LOCATION_ATTR = CallSiteLocation, - LAST_LOCATION_ATTR = UnknownLocation, -}; -} // namespace StandardAttributes - //===----------------------------------------------------------------------===// // AffineMapAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -154,21 +154,15 @@ void addOperation(AbstractOperation opInfo); - /// This method is used by derived classes to add their types to the set. + /// Register a set of type classes with this dialect. 
template void addTypes() { - (void)std::initializer_list{ - 0, (addType(Args::getTypeID(), AbstractType::get(*this)), 0)...}; + (void)std::initializer_list{0, (addType(), 0)...}; } - void addType(TypeID typeID, AbstractType &&typeInfo); - /// This method is used by derived classes to add their attributes to the set. + /// Register a set of attribute classes with this dialect. template void addAttributes() { - (void)std::initializer_list{ - 0, - (addAttribute(Args::getTypeID(), AbstractAttribute::get(*this)), - 0)...}; + (void)std::initializer_list{0, (addAttribute(), 0)...}; } - void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); /// Enable support for unregistered operations. void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } @@ -189,6 +183,22 @@ Dialect(const Dialect &) = delete; void operator=(Dialect &) = delete; + /// Register an attribute instance with this dialect. + template void addAttribute() { + // Add this attribute to the dialect and register it with the uniquer. + addAttribute(T::getTypeID(), AbstractAttribute::get(*this)); + detail::AttributeUniquer::registerAttribute(context); + } + void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); + + /// Register a type instance with this dialect. + template void addType() { + // Add this type to the dialect and register it with the uniquer. + addType(T::getTypeID(), AbstractType::get(*this)); + detail::TypeUniquer::registerType(context); + } + void addType(TypeID typeID, AbstractType &&typeInfo); + /// The namespace of this dialect. StringRef name; diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def deleted file mode 100644 --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ /dev/null @@ -1,44 +0,0 @@ -//===- DialectSymbolRegistry.def - MLIR Dialect Symbol Registry -*- C++ -*-===// -// -// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file enumerates the different dialects that define custom classes -// within the attribute or type system. -// -//===----------------------------------------------------------------------===// - -DEFINE_SYM_KIND_RANGE(STANDARD) -DEFINE_SYM_KIND_RANGE(TENSORFLOW_CONTROL) -DEFINE_SYM_KIND_RANGE(TENSORFLOW_EXECUTOR) -DEFINE_SYM_KIND_RANGE(TENSORFLOW) -DEFINE_SYM_KIND_RANGE(LLVM) -DEFINE_SYM_KIND_RANGE(QUANTIZATION) -DEFINE_SYM_KIND_RANGE(IREE) // IREE stands for IR Execution Engine -DEFINE_SYM_KIND_RANGE(LINALG) // Linear Algebra Dialect -DEFINE_SYM_KIND_RANGE(FIR) // Flang Fortran IR Dialect -DEFINE_SYM_KIND_RANGE(OPENACC) // OpenACC IR Dialect -DEFINE_SYM_KIND_RANGE(OPENMP) // OpenMP IR Dialect -DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect -DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect -DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect -DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect -DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect - -// The following ranges are reserved for experimenting with MLIR dialects in a -// private context without having to register them here. 
-DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_0) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_1) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_2) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_3) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_4) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_5) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_6) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_7) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_8) -DEFINE_SYM_KIND_RANGE(PRIVATE_EXPERIMENTAL_9) - -#undef DEFINE_SYM_KIND_RANGE diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -756,7 +756,7 @@ /// all attributes of the given kind in the form : [0-9]+. These /// aliases must not contain `.`. virtual void getAttributeKindAliases( - SmallVectorImpl> &aliases) const {} + SmallVectorImpl> &aliases) const {} /// Hook for defining Attribute aliases. These aliases must not contain `.` or /// end with a numeric digit([0-9]+). virtual void getAttributeAliases( diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -38,33 +38,6 @@ } // namespace detail -namespace StandardTypes { -enum Kind { - // Floating point. - BF16 = Type::Kind::FIRST_STANDARD_TYPE, - F16, - F32, - F64, - FIRST_FLOATING_POINT_TYPE = BF16, - LAST_FLOATING_POINT_TYPE = F64, - - // Target pointer sized integer, used (e.g.) in affine mappings. - Index, - - // Derived types. 
- Integer, - Vector, - RankedTensor, - UnrankedTensor, - MemRef, - UnrankedMemRef, - Complex, - Tuple, - None, -}; - -} // namespace StandardTypes - //===----------------------------------------------------------------------===// // ComplexType //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -82,29 +82,29 @@ return detail::InterfaceMap::template get...>(); } -protected: /// Get or create a new ConcreteT instance within the ctx. This /// function is guaranteed to return a non null object and will assert if /// the arguments provided are invalid. template - static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) { + static ConcreteT get(MLIRContext *ctx, Args... args) { // Ensure that the invariants are correct for construction. assert(succeeded(ConcreteT::verifyConstructionInvariants( generateUnknownStorageLocation(ctx), args...))); - return UniquerT::template get(ctx, kind, args...); + return UniquerT::template get(ctx, args...); } /// Get or create a new ConcreteT instance within the ctx, defined at /// the given, potentially unknown, location. If the arguments provided are /// invalid then emit errors and return a null object. template - static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) { + static ConcreteT getChecked(LocationT loc, Args... args) { // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verifyConstructionInvariants(loc, args...))) return ConcreteT(); - return UniquerT::template get(loc.getContext(), kind, args...); + return UniquerT::template get(loc.getContext(), args...); } +protected: /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. 
template LogicalResult mutate(Args &&...args) { diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -121,15 +121,23 @@ /// A utility class to get, or create, unique instances of types within an /// MLIRContext. This class manages all creation and uniquing of types. struct TypeUniquer { - /// Get an uniqued instance of a type T. + /// Get an uniqued instance of a parametric type T. template - static T get(MLIRContext *ctx, unsigned kind, Args &&... args) { + static typename std::enable_if_t< + !std::is_same::value, T> + get(MLIRContext *ctx, Args &&...args) { return ctx->getTypeUniquer().get( - T::getTypeID(), [&](TypeStorage *storage) { storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); }, - kind, std::forward(args)...); + T::getTypeID(), std::forward(args)...); + } + /// Get an uniqued instance of a singleton type T. + template + static typename std::enable_if_t< + std::is_same::value, T> + get(MLIRContext *ctx) { + return ctx->getTypeUniquer().get(T::getTypeID()); } /// Change the mutable component of the given type instance in the provided @@ -141,6 +149,25 @@ return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, std::forward(args)...); } + + /// Register a parametric type instance T with the uniquer. + template + static typename std::enable_if_t< + !std::is_same::value> + registerType(MLIRContext *ctx) { + ctx->getTypeUniquer().registerParametricStorageType( + T::getTypeID()); + } + /// Register a singleton type instance T with the uniquer. 
+ template + static typename std::enable_if_t< + std::is_same::value> + registerType(MLIRContext *ctx) { + ctx->getTypeUniquer().registerSingletonStorageType( + T::getTypeID(), [&](TypeStorage *storage) { + storage->initialize(AbstractType::lookup(T::getTypeID(), ctx)); + }); + } }; } // namespace detail diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -34,11 +34,11 @@ /// /// Some types are "primitives" meaning they do not have any parameters, for /// example the Index type. Parametric types have additional information that -/// differentiates the types of the same kind between them, for example the -/// Integer type has bitwidth, making i8 and i16 belong to the same kind by be -/// different instances of the IntegerType. Type parameters are part of the -/// unique immutable key. The mutable component of the type can be modified -/// after the type is created, but cannot affect the identity of the type. +/// differentiates the types of the same class, for example the Integer type has +/// bitwidth, making i8 and i16 belong to the same class but be different +/// instances of the IntegerType. Type parameters are part of the unique +/// immutable key. The mutable component of the type can be modified after the +/// type is created, but cannot affect the identity of the type. /// /// Types are constructed and uniqued via the 'detail::TypeUniquer' class. /// @@ -53,20 +53,19 @@ /// * This method is expected to return failure if a type cannot be /// constructed with 'args', success otherwise. /// * 'args' must correspond with the arguments passed into the -/// 'TypeBase::get' call after the type kind. +/// 'TypeBase::get' call. /// /// /// Type storage objects inherit from TypeStorage and contain the following: -/// - The type kind (for LLVM-style RTTI). /// - The dialect that defined the type. /// - Any parameters of the type. /// - An optional mutable component.
/// For non-parametric types, a convenience DefaultTypeStorage is provided. /// Parametric storage types must derive TypeStorage and respect the following: /// - Define a type alias, KeyTy, to a type that uniquely identifies the -/// instance of the type within its kind. +/// instance of the type. /// * The key type must be constructible from the values passed into the -/// detail::TypeUniquer::get call after the type kind. +/// detail::TypeUniquer::get call. /// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the /// storage class must define a hashing method: /// 'static unsigned hashKey(const KeyTy &)' @@ -84,23 +83,6 @@ // the key. class Type { public: - /// Integer identifier for all the concrete type kinds. - /// Note: This is not an enum class as each dialect will likely define a - /// separate enumeration for the specific types that they define. Not being an - /// enum class also simplifies the handling of type kinds by not requiring - /// casts for each use. - enum Kind { - // Builtin types. - Function, - Opaque, - LAST_BUILTIN_TYPE = Opaque, - - // Reserve type kinds for dialect specific type system extensions. -#define DEFINE_SYM_KIND_RANGE(Dialect) \ - FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff, -#include "DialectSymbolRegistry.def" - }; - /// Utility class for implementing types. template class... Traits> @@ -136,9 +118,6 @@ /// dynamic type casting. TypeID getTypeID() { return impl->getAbstractType().getTypeID(); } - /// Return the classification for this type. - unsigned getKind() const; - /// Return the LLVMContext in which this type was uniqued. 
MLIRContext *getContext() const; diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -11,12 +11,11 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Allocator.h" namespace mlir { -class TypeID; - namespace detail { struct StorageUniquerImpl; @@ -29,22 +28,19 @@ using has_impltype_hash_t = decltype(ImplTy::hashKey(std::declval())); } // namespace detail -/// A utility class to get, or create instances of storage classes. These -/// storage classes must respect the following constraints: -/// - Derive from StorageUniquer::BaseStorage. -/// - Provide an unsigned 'kind' value to be used as part of the unique'ing -/// process. +/// A utility class to get or create instances of "storage classes". These +/// storage classes must derive from 'StorageUniquer::BaseStorage'. /// -/// For non-parametric storage classes, i.e. those that are solely uniqued by -/// their kind, nothing else is needed. Instances of these classes can be -/// created by calling `get` without trailing arguments. +/// For non-parametric storage classes, i.e. singleton classes, nothing else is +/// needed beyond registering them via `registerSingletonStorageType`. The +/// instance can then be retrieved by calling `get` without trailing arguments. /// /// Otherwise, the parametric storage classes may be created with `get`, /// and must respect the following: /// - Define a type alias, KeyTy, to a type that uniquely identifies the -/// instance of the storage class within its kind. +/// instance of the storage class. /// * The key type must be constructible from the values passed into the -/// getComplex call after the kind. +/// `get` call.
/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the /// storage class must define a hashing method: /// 'static unsigned hashKey(const KeyTy &)' @@ -83,32 +79,11 @@ /// class. class StorageUniquer { public: - StorageUniquer(); - ~StorageUniquer(); - - /// Set the flag specifying if multi-threading is disabled within the uniquer. - void disableMultithreading(bool disable = true); - - /// Register a new storage object with this uniquer using the given unique - /// type id. - void registerStorageType(TypeID id); - /// This class acts as the base storage that all storage classes must derived /// from. class BaseStorage { - public: - /// Get the kind classification of this storage. - unsigned getKind() const { return kind; } - protected: - BaseStorage() : kind(0) {} - - private: - /// Allow access to the kind field. - friend detail::StorageUniquerImpl; - - /// Classification of the subclass, used for type checking. - unsigned kind; + BaseStorage() = default; }; /// This is a utility allocator used to allocate memory for instances of @@ -145,19 +120,61 @@ llvm::BumpPtrAllocator allocator; }; - /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter - /// that can be used to initialize a newly inserted storage instance. This - /// function is used for derived types that have complex storage or uniquing + StorageUniquer(); + ~StorageUniquer(); + + /// Set the flag specifying if multi-threading is disabled within the uniquer. + void disableMultithreading(bool disable = true); + + /// Register a new parametric storage class, this is necessary to create + /// instances of this class type. `id` is the type identifier that will be + /// used to identify this type when creating instances of it via 'get'. + template void registerParametricStorageType(TypeID id) { + registerParametricStorageTypeImpl(id); + } + /// Utility override when the storage type represents the type id. 
+ template void registerParametricStorageType() { + registerParametricStorageType(TypeID::get()); + } + /// Register a new singleton storage class, this is necessary to get the + /// singleton instance. `id` is the type identifier that will be used to + /// access the singleton instance via 'get'. An optional initialization + /// function may also be provided to initialize the newly created storage + /// instance, and used when the singleton instance is created. + template + void registerSingletonStorageType(TypeID id, + function_ref initFn) { + auto ctorFn = [&](StorageAllocator &allocator) { + auto *storage = new (allocator.allocate()) Storage(); + if (initFn) + initFn(storage); + return storage; + }; + registerSingletonImpl(id, ctorFn); + } + template void registerSingletonStorageType(TypeID id) { + registerSingletonStorageType(id, llvm::None); + } + /// Utility override when the storage type represents the type id. + template + void registerSingletonStorageType( + function_ref initFn = llvm::None) { + registerSingletonStorageType(TypeID::get(), initFn); + } + + /// Gets a uniqued instance of 'Storage'. 'id' is the type id used when + /// registering the storage instance. 'initFn' is an optional parameter that + /// can be used to initialize a newly inserted storage instance. This function + /// is used for derived types that have complex storage or uniquing + /// constraints. - template - Storage *get(const TypeID &id, function_ref initFn, - unsigned kind, Arg &&arg, Args &&...args) { + template + Storage *get(function_ref initFn, TypeID id, + Args &&...args) { // Construct a value of the derived key type. - auto derivedKey = - getKey(std::forward(arg), std::forward(args)...); + auto derivedKey = getKey(std::forward(args)...); - // Create a hash of the kind and the derived key. - unsigned hashValue = getHash(kind, derivedKey); + // Create a hash of the derived key.
+ unsigned hashValue = getHash(derivedKey); // Generate an equality function for the derived storage. auto isEqual = [&derivedKey](const BaseStorage *existing) { @@ -174,29 +191,29 @@ // Get an instance for the derived storage. return static_cast( - getImpl(id, kind, hashValue, isEqual, ctorFn)); + getParametricStorageTypeImpl(id, hashValue, isEqual, ctorFn)); + } + /// Utility override when the storage type represents the type id. + template + Storage *get(function_ref initFn, Args &&...args) { + return get(initFn, TypeID::get(), + std::forward(args)...); } - /// Gets a uniqued instance of 'Storage'. 'initFn' is an optional parameter - /// that can be used to initialize a newly inserted storage instance. This - /// function is used for derived types that use no additional storage or - /// uniquing outside of the kind. - template - Storage *get(const TypeID &id, function_ref initFn, - unsigned kind) { - auto ctorFn = [&](StorageAllocator &allocator) { - auto *storage = new (allocator.allocate()) Storage(); - if (initFn) - initFn(storage); - return storage; - }; - return static_cast(getImpl(id, kind, ctorFn)); + /// Gets a uniqued instance of 'Storage' which is a singleton storage type. + /// 'id' is the type id used when registering the storage instance. + template Storage *get(TypeID id) { + return static_cast(getSingletonImpl(id)); + } + /// Utility override when the storage type represents the type id. + template Storage *get() { + return get(TypeID::get()); } /// Changes the mutable component of 'storage' by forwarding the trailing /// arguments to the 'mutate' function of the derived class. template - LogicalResult mutate(const TypeID &id, Storage *storage, Args &&...args) { + LogicalResult mutate(TypeID id, Storage *storage, Args &&...args) { auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { return static_cast(*storage).mutate( allocator, std::forward(args)...); @@ -207,13 +224,13 @@ /// Erases a uniqued instance of 'Storage'. 
This function is used for derived /// types that have complex storage or uniquing constraints. template - void erase(const TypeID &id, unsigned kind, Arg &&arg, Args &&...args) { + void erase(TypeID id, Arg &&arg, Args &&...args) { // Construct a value of the derived key type. auto derivedKey = getKey(std::forward(arg), std::forward(args)...); - // Create a hash of the kind and the derived key. - unsigned hashValue = getHash(kind, derivedKey); + // Create a hash of the derived key. + unsigned hashValue = getHash(derivedKey); // Generate an equality function for the derived storage. auto isEqual = [&derivedKey](const BaseStorage *existing) { @@ -221,32 +238,42 @@ }; // Attempt to erase the storage instance. - eraseImpl(id, kind, hashValue, isEqual, [](BaseStorage *storage) { + eraseImpl(id, hashValue, isEqual, [](BaseStorage *storage) { static_cast(storage)->cleanup(); }); } private: /// Implementation for getting/creating an instance of a derived type with - /// complex storage. - BaseStorage *getImpl(const TypeID &id, unsigned kind, unsigned hashValue, - function_ref isEqual, - function_ref ctorFn); + /// parametric storage. + BaseStorage *getParametricStorageTypeImpl( + TypeID id, unsigned hashValue, + function_ref isEqual, + function_ref ctorFn); - /// Implementation for getting/creating an instance of a derived type with - /// default storage. - BaseStorage *getImpl(const TypeID &id, unsigned kind, - function_ref ctorFn); + /// Implementation for registering an instance of a derived type with + /// parametric storage. + void registerParametricStorageTypeImpl(TypeID id); + + /// Implementation for getting an instance of a derived type with default + /// storage. + BaseStorage *getSingletonImpl(TypeID id); + + /// Implementation for registering an instance of a derived type with default + /// storage. + void + registerSingletonImpl(TypeID id, + function_ref ctorFn); /// Implementation for erasing an instance of a derived type with complex /// storage. 
- void eraseImpl(const TypeID &id, unsigned kind, unsigned hashValue, + void eraseImpl(TypeID id, unsigned hashValue, function_ref isEqual, function_ref cleanupFn); /// Implementation for mutating an instance of a derived storage. LogicalResult - mutateImpl(const TypeID &id, + mutateImpl(TypeID id, function_ref mutationFn); /// The internal implementation class. @@ -276,27 +303,26 @@ } //===--------------------------------------------------------------------===// - // Key and Kind Hashing + // Key Hashing //===--------------------------------------------------------------------===// - /// Used to generate a hash for the 'ImplTy::KeyTy' and kind of a storage - /// instance if there is an 'ImplTy::hashKey' overload for 'DerivedKey'. + /// Used to generate a hash for the 'ImplTy::KeyTy' of a storage instance if + /// there is an 'ImplTy::hashKey' overload for 'DerivedKey'. template static typename std::enable_if< llvm::is_detected::value, ::llvm::hash_code>::type - getHash(unsigned kind, const DerivedKey &derivedKey) { - return llvm::hash_combine(kind, ImplTy::hashKey(derivedKey)); + getHash(const DerivedKey &derivedKey) { + return ImplTy::hashKey(derivedKey); } - /// If there is no 'ImplTy::hashKey' default to using the - /// 'llvm::DenseMapInfo' definition for 'DerivedKey' for generating a hash. + /// If there is no 'ImplTy::hashKey' default to using the 'llvm::DenseMapInfo' + /// definition for 'DerivedKey' for generating a hash. 
template static typename std::enable_if::value, ::llvm::hash_code>::type - getHash(unsigned kind, const DerivedKey &derivedKey) { - return llvm::hash_combine( - kind, DenseMapInfo::getHashValue(derivedKey)); + getHash(const DerivedKey &derivedKey) { + return DenseMapInfo::getHashValue(derivedKey); } }; } // end namespace mlir diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -264,14 +264,13 @@ LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType, - numElements); + return Base::get(elementType.getContext(), elementType, numElements); } LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements); + return Base::getChecked(loc, elementType, numElements); } LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; } @@ -301,16 +300,14 @@ ArrayRef arguments, bool isVarArg) { assert(result && "expected non-null result"); - return Base::get(result.getContext(), LLVMType::FunctionType, result, - arguments, isVarArg); + return Base::get(result.getContext(), result, arguments, isVarArg); } LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result, ArrayRef arguments, bool isVarArg) { assert(result && "expected non-null result"); - return Base::getChecked(loc, LLVMType::FunctionType, result, arguments, - isVarArg); + return Base::getChecked(loc, result, arguments, isVarArg); } LLVMType LLVMFunctionType::getReturnType() { @@ -347,11 +344,11 @@ // Integer type. 
LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) { - return Base::get(ctx, LLVMType::IntegerType, bitwidth); + return Base::get(ctx, bitwidth); } LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) { - return Base::getChecked(loc, LLVMType::IntegerType, bitwidth); + return Base::getChecked(loc, bitwidth); } unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; } @@ -374,13 +371,12 @@ LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) { assert(pointee && "expected non-null subtype"); - return Base::get(pointee.getContext(), LLVMType::PointerType, pointee, - addressSpace); + return Base::get(pointee.getContext(), pointee, addressSpace); } LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee, unsigned addressSpace) { - return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace); + return Base::getChecked(loc, pointee, addressSpace); } LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; } @@ -405,32 +401,32 @@ LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, StringRef name) { - return Base::get(context, LLVMType::StructType, name, /*opaque=*/false); + return Base::get(context, name, /*opaque=*/false); } LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc, StringRef name) { - return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false); + return Base::getChecked(loc, name, /*opaque=*/false); } LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, ArrayRef types, bool isPacked) { - return Base::get(context, LLVMType::StructType, types, isPacked); + return Base::get(context, types, isPacked); } LLVMStructType LLVMStructType::getLiteralChecked(Location loc, ArrayRef types, bool isPacked) { - return Base::getChecked(loc, LLVMType::StructType, types, isPacked); + return Base::getChecked(loc, types, isPacked); } LLVMStructType LLVMStructType::getOpaque(StringRef name, 
MLIRContext *context) { - return Base::get(context, LLVMType::StructType, name, /*opaque=*/true); + return Base::get(context, name, /*opaque=*/true); } LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) { - return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true); + return Base::getChecked(loc, name, /*opaque=*/true); } LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { @@ -508,16 +504,14 @@ LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::get(elementType.getContext(), LLVMType::FixedVectorType, - elementType, numElements); + return Base::get(elementType.getContext(), elementType, numElements); } LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc, LLVMType elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, LLVMType::FixedVectorType, elementType, - numElements); + return Base::getChecked(loc, elementType, numElements); } unsigned LLVMFixedVectorType::getNumElements() { @@ -527,16 +521,14 @@ LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); - return Base::get(elementType.getContext(), LLVMType::ScalableVectorType, - elementType, minNumElements); + return Base::get(elementType.getContext(), elementType, minNumElements); } LLVMScalableVectorType LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType, - minNumElements); + return Base::getChecked(loc, elementType, minNumElements); } unsigned LLVMScalableVectorType::getMinNumElements() { diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- 
a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -204,8 +204,8 @@ Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { - return Base::get(storageType.getContext(), QuantizationTypes::Any, flags, - storageType, expressedType, storageTypeMin, storageTypeMax); + return Base::get(storageType.getContext(), flags, storageType, expressedType, + storageTypeMin, storageTypeMax); } AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType, @@ -213,8 +213,8 @@ int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, QuantizationTypes::Any, flags, storageType, - expressedType, storageTypeMin, storageTypeMax); + return Base::getChecked(location, flags, storageType, expressedType, + storageTypeMin, storageTypeMax); } LogicalResult AnyQuantizedType::verifyConstructionInvariants( @@ -240,10 +240,8 @@ int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) { - return Base::get(storageType.getContext(), - QuantizationTypes::UniformQuantized, flags, storageType, - expressedType, scale, zeroPoint, storageTypeMin, - storageTypeMax); + return Base::get(storageType.getContext(), flags, storageType, expressedType, + scale, zeroPoint, storageTypeMin, storageTypeMax); } UniformQuantizedType @@ -251,9 +249,8 @@ Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags, - storageType, expressedType, scale, zeroPoint, - storageTypeMin, storageTypeMax); + return Base::getChecked(location, flags, storageType, expressedType, scale, + zeroPoint, storageTypeMin, storageTypeMax); } LogicalResult UniformQuantizedType::verifyConstructionInvariants( @@ -295,10 +292,9 @@ ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { - return 
Base::get(storageType.getContext(), - QuantizationTypes::UniformQuantizedPerAxis, flags, - storageType, expressedType, scales, zeroPoints, - quantizedDimension, storageTypeMin, storageTypeMax); + return Base::get(storageType.getContext(), flags, storageType, expressedType, + scales, zeroPoints, quantizedDimension, storageTypeMin, + storageTypeMax); } UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( @@ -306,9 +302,9 @@ ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis, - flags, storageType, expressedType, scales, zeroPoints, - quantizedDimension, storageTypeMin, storageTypeMax); + return Base::getChecked(location, flags, storageType, expressedType, scales, + zeroPoints, quantizedDimension, storageTypeMin, + storageTypeMax); } LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants( diff --git a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp --- a/mlir/lib/Dialect/SDBM/SDBMDialect.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMDialect.cpp @@ -13,11 +13,11 @@ SDBMDialect::SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get()) { - uniquer.registerStorageType(TypeID::get()); - uniquer.registerStorageType(TypeID::get()); - uniquer.registerStorageType(TypeID::get()); - uniquer.registerStorageType(TypeID::get()); - uniquer.registerStorageType(TypeID::get()); + uniquer.registerParametricStorageType(); + uniquer.registerParametricStorageType(); + uniquer.registerParametricStorageType(); + uniquer.registerParametricStorageType(); + uniquer.registerParametricStorageType(); } SDBMDialect::~SDBMDialect() = default; diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp --- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp +++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp @@ -246,7 +246,6 @@ StorageUniquer &uniquer = 
lhs.getDialect()->getUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Add), lhs, rhs); } @@ -533,9 +532,7 @@ assert(rhs && "expected SDBM dimension"); StorageUniquer &uniquer = lhs.getDialect()->getUniquer(); - return uniquer.get( - TypeID::get(), - /*initFn=*/{}, static_cast(SDBMExprKind::Diff), lhs, rhs); + return uniquer.get(/*initFn=*/{}, lhs, rhs); } SDBMDirectExpr SDBMDiffExpr::getLHS() const { @@ -575,7 +572,6 @@ StorageUniquer &uniquer = var.getDialect()->getUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(SDBMExprKind::Stripe), var, stripeFactor); } @@ -611,8 +607,7 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - TypeID::get(), assignDialect, - static_cast(SDBMExprKind::DimId), position); + assignDialect, static_cast(SDBMExprKind::DimId), position); } //===----------------------------------------------------------------------===// @@ -628,8 +623,7 @@ StorageUniquer &uniquer = dialect->getUniquer(); return uniquer.get( - TypeID::get(), assignDialect, - static_cast(SDBMExprKind::SymbolId), position); + assignDialect, static_cast(SDBMExprKind::SymbolId), position); } //===----------------------------------------------------------------------===// @@ -644,9 +638,7 @@ }; StorageUniquer &uniquer = dialect->getUniquer(); - return uniquer.get( - TypeID::get(), assignCtx, - static_cast(SDBMExprKind::Constant), value); + return uniquer.get(assignCtx, value); } int64_t SDBMConstantExpr::getValue() const { @@ -661,9 +653,7 @@ assert(var && "expected non-null SDBM direct expression"); StorageUniquer &uniquer = var.getDialect()->getUniquer(); - return uniquer.get( - TypeID::get(), - /*initFn=*/{}, static_cast(SDBMExprKind::Neg), var); + return uniquer.get(/*initFn=*/{}, var); } SDBMDirectExpr SDBMNegExpr::getVar() const { diff --git a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h --- a/mlir/lib/Dialect/SDBM/SDBMExprDetail.h +++ 
b/mlir/lib/Dialect/SDBM/SDBMExprDetail.h @@ -25,27 +25,28 @@ // Base storage class for SDBMExpr. struct SDBMExprStorage : public StorageUniquer::BaseStorage { - SDBMExprKind getKind() { - return static_cast(BaseStorage::getKind()); - } + SDBMExprKind getKind() { return kind; } SDBMDialect *dialect; + SDBMExprKind kind; }; // Storage class for SDBM sum and stripe expressions. struct SDBMBinaryExprStorage : public SDBMExprStorage { - using KeyTy = std::pair; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { - return std::get<0>(key) == lhs && std::get<1>(key) == rhs; + return static_cast(std::get<0>(key)) == kind && + std::get<1>(key) == lhs && std::get<2>(key) == rhs; } static SDBMBinaryExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); - result->lhs = std::get<0>(key); - result->rhs = std::get<1>(key); + result->lhs = std::get<1>(key); + result->rhs = std::get<2>(key); result->dialect = result->lhs.getDialect(); + result->kind = static_cast(std::get<0>(key)); return result; } @@ -67,6 +68,7 @@ result->lhs = std::get<0>(key); result->rhs = std::get<1>(key); result->dialect = result->lhs.getDialect(); + result->kind = SDBMExprKind::Diff; return result; } @@ -84,6 +86,7 @@ construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); result->constant = key; + result->kind = SDBMExprKind::Constant; return result; } @@ -92,14 +95,18 @@ // Storage class for SDBM dimension and symbol expressions. 
struct SDBMTermExprStorage : public SDBMExprStorage { - using KeyTy = unsigned; + using KeyTy = std::pair; - bool operator==(const KeyTy &key) const { return position == key; } + bool operator==(const KeyTy &key) const { + return kind == static_cast(key.first) && + position == key.second; + } static SDBMTermExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); - result->position = key; + result->kind = static_cast(key.first); + result->position = key.second; return result; } @@ -117,6 +124,7 @@ auto *result = allocator.allocate(); result->expr = key; result->dialect = key.getDialect(); + result->kind = SDBMExprKind::Neg; return result; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVAttributes.cpp @@ -120,8 +120,7 @@ IntegerAttr storageClass) { assert(descriptorSet && binding); MLIRContext *context = descriptorSet.getContext(); - return Base::get(context, spirv::AttrKind::InterfaceVarABI, descriptorSet, - binding, storageClass); + return Base::get(context, descriptorSet, binding, storageClass); } StringRef spirv::InterfaceVarABIAttr::getKindName() { @@ -195,8 +194,7 @@ ArrayAttr extensions) { assert(version && capabilities && extensions); MLIRContext *context = version.getContext(); - return Base::get(context, spirv::AttrKind::VerCapExt, version, capabilities, - extensions); + return Base::get(context, version, capabilities, extensions); } StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; } @@ -272,7 +270,7 @@ DictionaryAttr limits) { assert(triple && limits && "expected valid triple and limits"); MLIRContext *context = triple.getContext(); - return Base::get(context, spirv::AttrKind::TargetEnv, triple, limits); + return Base::get(context, triple, limits); } StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } diff --git 
a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -124,15 +124,14 @@ ArrayType ArrayType::get(Type elementType, unsigned elementCount) { assert(elementCount && "ArrayType needs at least one element"); - return Base::get(elementType.getContext(), TypeKind::Array, elementType, - elementCount, /*stride=*/0); + return Base::get(elementType.getContext(), elementType, elementCount, + /*stride=*/0); } ArrayType ArrayType::get(Type elementType, unsigned elementCount, unsigned stride) { assert(elementCount && "ArrayType needs at least one element"); - return Base::get(elementType.getContext(), TypeKind::Array, elementType, - elementCount, stride); + return Base::get(elementType.getContext(), elementType, elementCount, stride); } unsigned ArrayType::getNumElements() const { return getImpl()->elementCount; } @@ -285,8 +284,7 @@ CooperativeMatrixNVType CooperativeMatrixNVType::get(Type elementType, Scope scope, unsigned rows, unsigned columns) { - return Base::get(elementType.getContext(), TypeKind::CooperativeMatrix, - elementType, scope, rows, columns); + return Base::get(elementType.getContext(), elementType, scope, rows, columns); } Type CooperativeMatrixNVType::getElementType() const { @@ -389,7 +387,7 @@ ImageType::get(std::tuple value) { - return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value); + return Base::get(std::get<0>(value).getContext(), value); } Type ImageType::getElementType() const { return getImpl()->elementType; } @@ -453,8 +451,7 @@ }; PointerType PointerType::get(Type pointeeType, StorageClass storageClass) { - return Base::get(pointeeType.getContext(), TypeKind::Pointer, pointeeType, - storageClass); + return Base::get(pointeeType.getContext(), pointeeType, storageClass); } Type PointerType::getPointeeType() const { return getImpl()->pointeeType; } @@ -511,13 +508,11 @@ }; RuntimeArrayType RuntimeArrayType::get(Type 
elementType) { - return Base::get(elementType.getContext(), TypeKind::RuntimeArray, - elementType, /*stride=*/0); + return Base::get(elementType.getContext(), elementType, /*stride=*/0); } RuntimeArrayType RuntimeArrayType::get(Type elementType, unsigned stride) { - return Base::get(elementType.getContext(), TypeKind::RuntimeArray, - elementType, stride); + return Base::get(elementType.getContext(), elementType, stride); } Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } @@ -846,12 +841,12 @@ SmallVector sortedDecorations( memberDecorations.begin(), memberDecorations.end()); llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); - return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, - memberTypes, offsetInfo, sortedDecorations); + return Base::get(memberTypes.vec().front().getContext(), memberTypes, + offsetInfo, sortedDecorations); } StructType StructType::getEmpty(MLIRContext *context) { - return Base::get(context, TypeKind::Struct, ArrayRef(), + return Base::get(context, ArrayRef(), ArrayRef(), ArrayRef()); } @@ -946,13 +941,12 @@ }; MatrixType MatrixType::get(Type columnType, uint32_t columnCount) { - return Base::get(columnType.getContext(), TypeKind::Matrix, columnType, - columnCount); + return Base::get(columnType.getContext(), columnType, columnCount); } MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount, Location location) { - return Base::getChecked(location, TypeKind::Matrix, columnType, columnCount); + return Base::getChecked(location, columnType, columnCount); } LogicalResult MatrixType::verifyConstructionInvariants(Location loc, diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -20,9 +20,7 @@ MLIRContext *AffineExpr::getContext() const { return expr->context; } -AffineExprKind AffineExpr::getKind() const { - return static_cast(expr->getKind()); -} +AffineExprKind 
AffineExpr::getKind() const { return expr->kind; } /// Walk all of the AffineExprs in this subgraph in postorder. void AffineExpr::walk(std::function callback) const { @@ -449,8 +447,7 @@ StorageUniquer &uniquer = context->getAffineUniquer(); return uniquer.get( - TypeID::get(), assignCtx, - static_cast(kind), position); + assignCtx, static_cast(kind), position); } AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) { @@ -484,9 +481,7 @@ }; StorageUniquer &uniquer = context->getAffineUniquer(); - return uniquer.get( - TypeID::get(), assignCtx, - static_cast(AffineExprKind::Constant), constant); + return uniquer.get(assignCtx, constant); } /// Simplify add expression. Return nullptr if it can't be simplified. @@ -594,7 +589,6 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Add), *this, other); } @@ -655,7 +649,6 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mul), *this, other); } @@ -722,7 +715,6 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::FloorDiv), *this, other); } @@ -766,7 +758,6 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::CeilDiv), *this, other); } @@ -814,7 +805,6 @@ StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - TypeID::get(), /*initFn=*/{}, static_cast(AffineExprKind::Mod), *this, other); } diff --git a/mlir/lib/IR/AffineExprDetail.h b/mlir/lib/IR/AffineExprDetail.h --- a/mlir/lib/IR/AffineExprDetail.h +++ b/mlir/lib/IR/AffineExprDetail.h @@ -27,21 +27,24 @@ /// Base storage class appearing in an affine expression. 
struct AffineExprStorage : public StorageUniquer::BaseStorage { MLIRContext *context; + AffineExprKind kind; }; /// A binary operation appearing in an affine expression. struct AffineBinaryOpExprStorage : public AffineExprStorage { - using KeyTy = std::pair; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { - return key.first == lhs && key.second == rhs; + return static_cast(std::get<0>(key)) == kind && + std::get<1>(key) == lhs && std::get<2>(key) == rhs; } static AffineBinaryOpExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); - result->lhs = key.first; - result->rhs = key.second; + result->kind = static_cast(std::get<0>(key)); + result->lhs = std::get<1>(key); + result->rhs = std::get<2>(key); result->context = result->lhs.getContext(); return result; } @@ -52,14 +55,18 @@ /// A dimensional or symbolic identifier appearing in an affine expression. struct AffineDimExprStorage : public AffineExprStorage { - using KeyTy = unsigned; + using KeyTy = std::pair; - bool operator==(const KeyTy &key) const { return position == key; } + bool operator==(const KeyTy &key) const { + return kind == static_cast(key.first) && + position == key.second; + } static AffineDimExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); - result->position = key; + result->kind = static_cast(key.first); + result->position = key.second; return result; } @@ -76,6 +83,7 @@ static AffineConstantExprStorage * construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) { auto *result = allocator.allocate(); + result->kind = AffineExprKind::Constant; result->constant = key; return result; } diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -271,7 +271,7 @@ /// Mapping between attribute kind and a pair comprised of a base alias name /// 
and a unique list of attributes belonging to this kind sorted by location /// seen in the module. - llvm::MapVector>> + llvm::MapVector>> attrKindToAlias; /// Set of types known to be used within the module. @@ -301,13 +301,13 @@ llvm::StringSet<> usedAliases; // Collect the set of aliases from each dialect. - SmallVector, 8> attributeKindAliases; + SmallVector, 8> attributeKindAliases; SmallVector, 8> attributeAliases; SmallVector, 16> typeAliases; // AffineMap/Integer set have specific kind aliases. - attributeKindAliases.emplace_back(StandardAttributes::AffineMap, "map"); - attributeKindAliases.emplace_back(StandardAttributes::IntegerSet, "set"); + attributeKindAliases.emplace_back(AffineMapAttr::getTypeID(), "map"); + attributeKindAliases.emplace_back(IntegerSetAttr::getTypeID(), "set"); for (auto &interface : interfaces) { interface.getAttributeKindAliases(attributeKindAliases); @@ -317,7 +317,7 @@ // Setup the attribute kind aliases. StringRef alias; - unsigned attrKind; + TypeID attrKind; for (auto &attrAliasPair : attributeKindAliases) { std::tie(attrKind, alias) = attrAliasPair; assert(!alias.empty() && "expected non-empty alias string"); @@ -420,7 +420,7 @@ return; // If this attribute kind has an alias, then record one for this attribute. 
- auto alias = attrKindToAlias.find(static_cast(attr.getKind())); + auto alias = attrKindToAlias.find(attr.getTypeID()); if (alias == attrKindToAlias.end()) return; std::pair attrAlias(alias->second.first, diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -57,7 +57,7 @@ //===----------------------------------------------------------------------===// AffineMapAttr AffineMapAttr::get(AffineMap value) { - return Base::get(value.getContext(), StandardAttributes::AffineMap, value); + return Base::get(value.getContext(), value); } AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } @@ -67,7 +67,7 @@ //===----------------------------------------------------------------------===// ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { - return Base::get(context, StandardAttributes::Array, value); + return Base::get(context, value); } ArrayRef ArrayAttr::getValue() const { return getImpl()->value; } @@ -156,7 +156,7 @@ if (dictionaryAttrSort(value, storage)) value = storage; - return Base::get(context, StandardAttributes::Dictionary, value); + return Base::get(context, value); } /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. 
@@ -175,7 +175,7 @@ return l.first == r.first; }) == value.end() && "DictionaryAttr element names must be unique"); - return Base::get(context, StandardAttributes::Dictionary, value); + return Base::get(context, value); } ArrayRef DictionaryAttr::getValue() const { @@ -219,19 +219,19 @@ //===----------------------------------------------------------------------===// FloatAttr FloatAttr::get(Type type, double value) { - return Base::get(type.getContext(), StandardAttributes::Float, type, value); + return Base::get(type.getContext(), type, value); } FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { - return Base::getChecked(loc, StandardAttributes::Float, type, value); + return Base::getChecked(loc, type, value); } FloatAttr FloatAttr::get(Type type, const APFloat &value) { - return Base::get(type.getContext(), StandardAttributes::Float, type, value); + return Base::get(type.getContext(), type, value); } FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { - return Base::getChecked(loc, StandardAttributes::Float, type, value); + return Base::getChecked(loc, type, value); } APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } @@ -279,14 +279,13 @@ //===----------------------------------------------------------------------===// FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::SymbolRef, value, llvm::None) - .cast(); + return Base::get(ctx, value, llvm::None).cast(); } SymbolRefAttr SymbolRefAttr::get(StringRef value, ArrayRef nestedReferences, MLIRContext *ctx) { - return Base::get(ctx, StandardAttributes::SymbolRef, value, nestedReferences); + return Base::get(ctx, value, nestedReferences); } StringRef SymbolRefAttr::getRootReference() const { return getImpl()->value; } @@ -307,7 +306,7 @@ IntegerAttr IntegerAttr::get(Type type, const APInt &value) { if (type.isSignlessInteger(1)) return BoolAttr::get(value.getBoolValue(), 
type.getContext()); - return Base::get(type.getContext(), StandardAttributes::Integer, type, value); + return Base::get(type.getContext(), type, value); } IntegerAttr IntegerAttr::get(Type type, int64_t value) { @@ -380,8 +379,7 @@ //===----------------------------------------------------------------------===// IntegerSetAttr IntegerSetAttr::get(IntegerSet value) { - return Base::get(value.getConstraint(0).getContext(), - StandardAttributes::IntegerSet, value); + return Base::get(value.getConstraint(0).getContext(), value); } IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } @@ -392,14 +390,12 @@ OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, MLIRContext *context) { - return Base::get(context, StandardAttributes::Opaque, dialect, attrData, - type); + return Base::get(context, dialect, attrData, type); } OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, Type type, Location location) { - return Base::getChecked(location, StandardAttributes::Opaque, dialect, - attrData, type); + return Base::getChecked(location, dialect, attrData, type); } /// Returns the dialect namespace of the opaque attribute. @@ -430,7 +426,7 @@ /// Get an instance of a StringAttr with the given string and Type. 
StringAttr StringAttr::get(StringRef bytes, Type type) { - return Base::get(type.getContext(), StandardAttributes::String, bytes, type); + return Base::get(type.getContext(), bytes, type); } StringRef StringAttr::getValue() const { return getImpl()->value; } @@ -440,7 +436,7 @@ //===----------------------------------------------------------------------===// TypeAttr TypeAttr::get(Type value) { - return Base::get(value.getContext(), StandardAttributes::Type, value); + return Base::get(value.getContext(), value); } Type TypeAttr::getValue() const { return getImpl()->value; } @@ -1036,8 +1032,7 @@ DenseStringElementsAttr DenseStringElementsAttr::get(ShapedType type, ArrayRef values) { - return Base::get(type.getContext(), StandardAttributes::DenseStringElements, - type, values, (values.size() == 1)); + return Base::get(type.getContext(), type, values, (values.size() == 1)); } //===----------------------------------------------------------------------===// @@ -1088,8 +1083,7 @@ assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::DenseIntOrFPElements, - type, data, isSplat); + return Base::get(type.getContext(), type, data, isSplat); } /// Overload of the raw 'get' method that asserts that the given type is of @@ -1210,8 +1204,7 @@ StringRef bytes) { assert(TensorType::isValidElementType(type.getElementType()) && "Input element type should be a valid tensor element type"); - return Base::get(type.getContext(), StandardAttributes::OpaqueElements, type, - dialect, bytes); + return Base::get(type.getContext(), type, dialect, bytes); } StringRef OpaqueElementsAttr::getValue() const { return getImpl()->bytes; } @@ -1248,7 +1241,7 @@ assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); - return Base::get(type.getContext(), StandardAttributes::SparseElements, type, + 
return Base::get(type.getContext(), type, indices.cast(), values); } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -28,8 +28,7 @@ //===----------------------------------------------------------------------===// Location CallSiteLoc::get(Location callee, Location caller) { - return Base::get(callee->getContext(), StandardAttributes::CallSiteLocation, - callee, caller); + return Base::get(callee->getContext(), callee, caller); } Location CallSiteLoc::get(Location name, ArrayRef frames) { @@ -50,8 +49,7 @@ Location FileLineColLoc::get(Identifier filename, unsigned line, unsigned column, MLIRContext *context) { - return Base::get(context, StandardAttributes::FileLineColLocation, filename, - line, column); + return Base::get(context, filename, line, column); } Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column, @@ -95,7 +93,7 @@ return UnknownLoc::get(context); if (locs.size() == 1) return locs.front(); - return Base::get(context, StandardAttributes::FusedLocation, locs, metadata); + return Base::get(context, locs, metadata); } ArrayRef FusedLoc::getLocations() const { @@ -111,8 +109,7 @@ Location NameLoc::get(Identifier name, Location child) { assert(!child.isa() && "a NameLoc cannot be used as a child of another NameLoc"); - return Base::get(child->getContext(), StandardAttributes::NameLocation, name, - child); + return Base::get(child->getContext(), name, child); } Location NameLoc::get(Identifier name, MLIRContext *context) { @@ -131,9 +128,8 @@ Location OpaqueLoc::get(uintptr_t underlyingLocation, TypeID typeID, Location fallbackLocation) { - return Base::get(fallbackLocation->getContext(), - StandardAttributes::OpaqueLocation, underlyingLocation, - typeID, fallbackLocation); + return Base::get(fallbackLocation->getContext(), underlyingLocation, typeID, + fallbackLocation); } uintptr_t OpaqueLoc::getUnderlyingLocation() const { diff --git 
a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -87,6 +87,10 @@ struct BuiltinDialect : public Dialect { BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context, TypeID::get()) { + addTypes(); addAttributes(); - addTypes(); - // TODO: These operations should be moved to a different dialect when they // have been fully decoupled from the core. addOperations(); @@ -363,56 +362,50 @@ //// Types. /// Floating-point Types. - impl->bf16Ty = TypeUniquer::get(this, StandardTypes::BF16); - impl->f16Ty = TypeUniquer::get(this, StandardTypes::F16); - impl->f32Ty = TypeUniquer::get(this, StandardTypes::F32); - impl->f64Ty = TypeUniquer::get(this, StandardTypes::F64); + impl->bf16Ty = TypeUniquer::get(this); + impl->f16Ty = TypeUniquer::get(this); + impl->f32Ty = TypeUniquer::get(this); + impl->f64Ty = TypeUniquer::get(this); /// Index Type. - impl->indexTy = TypeUniquer::get(this, StandardTypes::Index); + impl->indexTy = TypeUniquer::get(this); /// Integer Types. 
- impl->int1Ty = TypeUniquer::get(this, StandardTypes::Integer, 1, - IntegerType::Signless); - impl->int8Ty = TypeUniquer::get(this, StandardTypes::Integer, 8, - IntegerType::Signless); - impl->int16Ty = TypeUniquer::get(this, StandardTypes::Integer, - 16, IntegerType::Signless); - impl->int32Ty = TypeUniquer::get(this, StandardTypes::Integer, - 32, IntegerType::Signless); - impl->int64Ty = TypeUniquer::get(this, StandardTypes::Integer, - 64, IntegerType::Signless); - impl->int128Ty = TypeUniquer::get(this, StandardTypes::Integer, - 128, IntegerType::Signless); + impl->int1Ty = TypeUniquer::get(this, 1, IntegerType::Signless); + impl->int8Ty = TypeUniquer::get(this, 8, IntegerType::Signless); + impl->int16Ty = + TypeUniquer::get(this, 16, IntegerType::Signless); + impl->int32Ty = + TypeUniquer::get(this, 32, IntegerType::Signless); + impl->int64Ty = + TypeUniquer::get(this, 64, IntegerType::Signless); + impl->int128Ty = + TypeUniquer::get(this, 128, IntegerType::Signless); /// None Type. - impl->noneType = TypeUniquer::get(this, StandardTypes::None); + impl->noneType = TypeUniquer::get(this); //// Attributes. //// Note: These must be registered after the types as they may generate one //// of the above types internally. /// Bool Attributes. impl->falseAttr = AttributeUniquer::get( - this, StandardAttributes::Integer, impl->int1Ty, - APInt(/*numBits=*/1, false)) + this, impl->int1Ty, APInt(/*numBits=*/1, false)) .cast(); impl->trueAttr = AttributeUniquer::get( - this, StandardAttributes::Integer, impl->int1Ty, - APInt(/*numBits=*/1, true)) + this, impl->int1Ty, APInt(/*numBits=*/1, true)) .cast(); /// Unit Attribute. - impl->unitAttr = - AttributeUniquer::get(this, StandardAttributes::Unit); + impl->unitAttr = AttributeUniquer::get(this); /// Unknown Location Attribute. - impl->unknownLocAttr = AttributeUniquer::get( - this, StandardAttributes::UnknownLocation); + impl->unknownLocAttr = AttributeUniquer::get(this); /// The empty dictionary attribute. 
- impl->emptyDictionaryAttr = AttributeUniquer::get( - this, StandardAttributes::Dictionary, ArrayRef()); + impl->emptyDictionaryAttr = + AttributeUniquer::get(this, ArrayRef()); // Register the affine storage objects with the uniquer. - impl->affineUniquer.registerStorageType( - TypeID::get()); - impl->affineUniquer.registerStorageType( - TypeID::get()); - impl->affineUniquer.registerStorageType(TypeID::get()); + impl->affineUniquer + .registerParametricStorageType(); + impl->affineUniquer + .registerParametricStorageType(); + impl->affineUniquer.registerParametricStorageType(); } MLIRContext::~MLIRContext() {} @@ -582,7 +575,6 @@ AbstractType(std::move(typeInfo)); if (!impl.registeredTypes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Type already registered."); - impl.typeUniquer.registerStorageType(typeID); } void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) { @@ -592,7 +584,6 @@ AbstractAttribute(std::move(attrInfo)); if (!impl.registeredAttributes.insert({typeID, newInfo}).second) llvm::report_fatal_error("Dialect Attribute already registered."); - impl.attributeUniquer.registerStorageType(typeID); } /// Get the dialect that registered the attribute with the provided typeid. @@ -718,7 +709,7 @@ MLIRContext *context) { if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; - return Base::get(context, StandardTypes::Integer, width, signedness); + return Base::get(context, width, signedness); } IntegerType IntegerType::getChecked(unsigned width, Location location) { @@ -731,12 +722,16 @@ if (auto cached = getCachedIntegerType(width, signedness, location->getContext())) return cached; - return Base::getChecked(location, StandardTypes::Integer, width, signedness); + return Base::getChecked(location, width, signedness); } /// Get an instance of the NoneType. 
NoneType NoneType::get(MLIRContext *context) { - return context->getImpl().noneType; + if (NoneType cachedInst = context->getImpl().noneType) + return cachedInst; + // Note: May happen when initializing the singleton attributes of the builtin + // dialect. + return Base::get(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -102,12 +102,11 @@ //===----------------------------------------------------------------------===// ComplexType ComplexType::get(Type elementType) { - return Base::get(elementType.getContext(), StandardTypes::Complex, - elementType); + return Base::get(elementType.getContext(), elementType); } ComplexType ComplexType::getChecked(Type elementType, Location location) { - return Base::getChecked(location, StandardTypes::Complex, elementType); + return Base::getChecked(location, elementType); } /// Verify the construction of an integer type. 
@@ -265,13 +264,12 @@ //===----------------------------------------------------------------------===// VectorType VectorType::get(ArrayRef shape, Type elementType) { - return Base::get(elementType.getContext(), StandardTypes::Vector, shape, - elementType); + return Base::get(elementType.getContext(), shape, elementType); } VectorType VectorType::getChecked(ArrayRef shape, Type elementType, Location location) { - return Base::getChecked(location, StandardTypes::Vector, shape, elementType); + return Base::getChecked(location, shape, elementType); } LogicalResult VectorType::verifyConstructionInvariants(Location loc, @@ -320,15 +318,13 @@ RankedTensorType RankedTensorType::get(ArrayRef shape, Type elementType) { - return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape, - elementType); + return Base::get(elementType.getContext(), shape, elementType); } RankedTensorType RankedTensorType::getChecked(ArrayRef shape, Type elementType, Location location) { - return Base::getChecked(location, StandardTypes::RankedTensor, shape, - elementType); + return Base::getChecked(location, shape, elementType); } LogicalResult RankedTensorType::verifyConstructionInvariants( @@ -349,13 +345,12 @@ //===----------------------------------------------------------------------===// UnrankedTensorType UnrankedTensorType::get(Type elementType) { - return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor, - elementType); + return Base::get(elementType.getContext(), elementType); } UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, Location location) { - return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType); + return Base::getChecked(location, elementType); } LogicalResult @@ -444,8 +439,8 @@ cleanedAffineMapComposition.push_back(map); } - return Base::get(context, StandardTypes::MemRef, shape, elementType, - cleanedAffineMapComposition, memorySpace); + return Base::get(context, shape, elementType, 
cleanedAffineMapComposition, + memorySpace); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } @@ -462,15 +457,13 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType, unsigned memorySpace) { - return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef, - elementType, memorySpace); + return Base::get(elementType.getContext(), elementType, memorySpace); } UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, unsigned memorySpace, Location location) { - return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType, - memorySpace); + return Base::getChecked(location, elementType, memorySpace); } unsigned UnrankedMemRefType::getMemorySpace() const { @@ -642,7 +635,7 @@ /// Get or create a new TupleType with the provided element types. Assumes the /// arguments define a well-formed type. TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) { - return Base::get(context, StandardTypes::Tuple, elementTypes); + return Base::get(context, elementTypes); } /// Get or create an empty tuple type. 
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -19,8 +19,6 @@ // Type //===----------------------------------------------------------------------===// -unsigned Type::getKind() const { return impl->getKind(); } - Dialect &Type::getDialect() const { return impl->getAbstractType().getDialect(); } @@ -33,7 +31,7 @@ FunctionType FunctionType::get(TypeRange inputs, TypeRange results, MLIRContext *context) { - return Base::get(context, Type::Kind::Function, inputs, results); + return Base::get(context, inputs, results); } unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } @@ -54,12 +52,12 @@ OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData, MLIRContext *context) { - return Base::get(context, Type::Kind::Opaque, dialect, typeData); + return Base::get(context, dialect, typeData); } OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, MLIRContext *context, Location location) { - return Base::getChecked(location, Kind::Opaque, dialect, typeData); + return Base::getChecked(location, dialect, typeData); } /// Returns the dialect namespace of the opaque type. diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -16,19 +16,17 @@ using namespace mlir::detail; namespace { -/// This class represents a uniquer for storage instances of a specific type. It -/// contains all of the necessary data to unique storage instances in a thread -/// safe way. This allows for the main uniquer to bucket each of the individual -/// sub-types removing the need to lock the main uniquer itself. -struct InstSpecificUniquer { +/// This class represents a uniquer for storage instances of a specific type +/// that has parametric storage. It contains all of the necessary data to unique +/// storage instances in a thread safe way. 
This allows for the main uniquer to +/// bucket each of the individual sub-types removing the need to lock the main +/// uniquer itself. +struct ParametricStorageUniquer { using BaseStorage = StorageUniquer::BaseStorage; using StorageAllocator = StorageUniquer::StorageAllocator; /// A lookup key for derived instances of storage objects. struct LookupKey { - /// The known derived kind for the storage. - unsigned kind; - /// The known hash value of the key. unsigned hashValue; @@ -63,18 +61,14 @@ static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) { if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey())) return false; - // If the lookup kind matches the kind of the storage, then invoke the - // equality function on the lookup key. - return lhs.kind == rhs.storage->getKind() && lhs.isEqual(rhs.storage); + // Invoke the equality function on the lookup key. + return lhs.isEqual(rhs.storage); } }; - /// Unique types with specific hashing or storage constraints. + /// The set containing the allocated storage instances. using StorageTypeSet = DenseSet; - StorageTypeSet complexInstances; - - /// Instances of this storage object. - llvm::SmallDenseMap simpleInstances; + StorageTypeSet instances; /// Allocator to use when constructing derived instances. StorageAllocator allocator; @@ -91,107 +85,79 @@ using BaseStorage = StorageUniquer::BaseStorage; using StorageAllocator = StorageUniquer::StorageAllocator; - /// Get or create an instance of a complex derived type. + //===--------------------------------------------------------------------===// + // Parametric Storage + //===--------------------------------------------------------------------===// + + /// Get or create an instance of a parametric type. 
BaseStorage * - getOrCreate(TypeID id, unsigned kind, unsigned hashValue, + getOrCreate(TypeID id, unsigned hashValue, function_ref isEqual, function_ref ctorFn) { - assert(instUniquers.count(id) && "creating unregistered storage instance"); - InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; - InstSpecificUniquer &storageUniquer = *instUniquers[id]; + assert(parametricUniquers.count(id) && + "creating unregistered storage instance"); + ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; + ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; if (!threadingIsEnabled) - return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); + return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn); // Check for an existing instance in read-only mode. { llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); - auto it = storageUniquer.complexInstances.find_as(lookupKey); - if (it != storageUniquer.complexInstances.end()) + auto it = storageUniquer.instances.find_as(lookupKey); + if (it != storageUniquer.instances.end()) return it->storage; } // Acquire a writer-lock so that we can safely create the new type instance. llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); - return getOrCreateUnsafe(storageUniquer, kind, lookupKey, ctorFn); + return getOrCreateUnsafe(storageUniquer, lookupKey, ctorFn); } /// Get or create an instance of a complex derived type in an thread-unsafe /// fashion. 
BaseStorage * - getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind, - InstSpecificUniquer::LookupKey &lookupKey, + getOrCreateUnsafe(ParametricStorageUniquer &storageUniquer, + ParametricStorageUniquer::LookupKey &lookupKey, function_ref ctorFn) { - auto existing = storageUniquer.complexInstances.insert_as({}, lookupKey); + auto existing = storageUniquer.instances.insert_as({}, lookupKey); if (!existing.second) return existing.first->storage; // Otherwise, construct and initialize the derived storage for this type // instance. - BaseStorage *storage = - initializeStorage(kind, storageUniquer.allocator, ctorFn); + BaseStorage *storage = ctorFn(storageUniquer.allocator); *existing.first = - InstSpecificUniquer::HashedStorage{lookupKey.hashValue, storage}; + ParametricStorageUniquer::HashedStorage{lookupKey.hashValue, storage}; return storage; } - /// Get or create an instance of a simple derived type. - BaseStorage * - getOrCreate(TypeID id, unsigned kind, - function_ref ctorFn) { - assert(instUniquers.count(id) && "creating unregistered storage instance"); - InstSpecificUniquer &storageUniquer = *instUniquers[id]; - if (!threadingIsEnabled) - return getOrCreateUnsafe(storageUniquer, kind, ctorFn); - - // Check for an existing instance in read-only mode. - { - llvm::sys::SmartScopedReader typeLock(storageUniquer.mutex); - auto it = storageUniquer.simpleInstances.find(kind); - if (it != storageUniquer.simpleInstances.end()) - return it->second; - } - - // Acquire a writer-lock so that we can safely create the new type instance. - llvm::sys::SmartScopedWriter typeLock(storageUniquer.mutex); - return getOrCreateUnsafe(storageUniquer, kind, ctorFn); - } - /// Get or create an instance of a simple derived type in an thread-unsafe - /// fashion. 
- BaseStorage * - getOrCreateUnsafe(InstSpecificUniquer &storageUniquer, unsigned kind, - function_ref ctorFn) { - auto &result = storageUniquer.simpleInstances[kind]; - if (result) - return result; - - // Otherwise, create and return a new storage instance. - return result = initializeStorage(kind, storageUniquer.allocator, ctorFn); - } - - /// Erase an instance of a complex derived type. - void erase(TypeID id, unsigned kind, unsigned hashValue, + /// Erase an instance of a parametric derived type. + void erase(TypeID id, unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - assert(instUniquers.count(id) && "erasing unregistered storage instance"); - InstSpecificUniquer &storageUniquer = *instUniquers[id]; - InstSpecificUniquer::LookupKey lookupKey{kind, hashValue, isEqual}; + assert(parametricUniquers.count(id) && + "erasing unregistered storage instance"); + ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; + ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual}; // Acquire a writer-lock so that we can safely erase the type instance. llvm::sys::SmartScopedWriter lock(storageUniquer.mutex); - auto existing = storageUniquer.complexInstances.find_as(lookupKey); - if (existing == storageUniquer.complexInstances.end()) + auto existing = storageUniquer.instances.find_as(lookupKey); + if (existing == storageUniquer.instances.end()) return; // Cleanup the storage and remove it from the map. cleanupFn(existing->storage); - storageUniquer.complexInstances.erase(existing); + storageUniquer.instances.erase(existing); } /// Mutates an instance of a derived storage in a thread-safe way. 
LogicalResult mutate(TypeID id, function_ref mutationFn) { - assert(instUniquers.count(id) && "mutating unregistered storage instance"); - InstSpecificUniquer &storageUniquer = *instUniquers[id]; + assert(parametricUniquers.count(id) && + "mutating unregistered storage instance"); + ParametricStorageUniquer &storageUniquer = *parametricUniquers[id]; if (!threadingIsEnabled) return mutationFn(storageUniquer.allocator); @@ -200,20 +166,30 @@ } //===--------------------------------------------------------------------===// - // Instance Storage + // Singleton Storage //===--------------------------------------------------------------------===// - /// Utility to create and initialize a storage instance. - BaseStorage * - initializeStorage(unsigned kind, StorageAllocator &allocator, - function_ref ctorFn) { - BaseStorage *storage = ctorFn(allocator); - storage->kind = kind; - return storage; + /// Get or create an instance of a singleton storage class. + BaseStorage *getSingleton(TypeID id) { + BaseStorage *singletonInstance = singletonInstances[id]; + assert(singletonInstance && "expected singleton instance to exist"); + return singletonInstance; } + //===--------------------------------------------------------------------===// + // Instance Storage + //===--------------------------------------------------------------------===// + /// Map of type ids to the storage uniquer to use for registered objects. - DenseMap> instUniquers; + DenseMap> + parametricUniquers; + + /// Map of type ids to a singleton instance when the storage class is a + /// singleton. + DenseMap singletonInstances; + + /// Allocator used for uniquing singleton instances. + StorageAllocator singletonAllocator; /// Flag specifying if multi-threading is enabled within the uniquer. bool threadingIsEnabled = true; @@ -229,41 +205,47 @@ impl->threadingIsEnabled = !disable; } -/// Register a new storage object with this uniquer using the given unique type -/// id. 
-void StorageUniquer::registerStorageType(TypeID id) { - impl->instUniquers.try_emplace(id, std::make_unique()); -} - /// Implementation for getting/creating an instance of a derived type with -/// complex storage. -auto StorageUniquer::getImpl( - const TypeID &id, unsigned kind, unsigned hashValue, +/// parametric storage. +auto StorageUniquer::getParametricStorageTypeImpl( + TypeID id, unsigned hashValue, function_ref isEqual, function_ref ctorFn) -> BaseStorage * { - return impl->getOrCreate(id, kind, hashValue, isEqual, ctorFn); + return impl->getOrCreate(id, hashValue, isEqual, ctorFn); } -/// Implementation for getting/creating an instance of a derived type with -/// default storage. -auto StorageUniquer::getImpl( - const TypeID &id, unsigned kind, - function_ref ctorFn) -> BaseStorage * { - return impl->getOrCreate(id, kind, ctorFn); +/// Implementation for registering an instance of a derived type with +/// parametric storage. +void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) { + impl->parametricUniquers.try_emplace( + id, std::make_unique()); +} + +/// Implementation for getting an instance of a derived type with default +/// storage. +auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * { + return impl->getSingleton(id); +} + +/// Implementation for registering an instance of a derived type with default +/// storage. +void StorageUniquer::registerSingletonImpl( + TypeID id, function_ref ctorFn) { + assert(!impl->singletonInstances.count(id) && + "storage class already registered"); + impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator)); } -/// Implementation for erasing an instance of a derived type with complex +/// Implementation for erasing an instance of a derived type with parametric /// storage. 
-void StorageUniquer::eraseImpl(const TypeID &id, unsigned kind, - unsigned hashValue, +void StorageUniquer::eraseImpl(TypeID id, unsigned hashValue, function_ref isEqual, function_ref cleanupFn) { - impl->erase(id, kind, hashValue, isEqual, cleanupFn); + impl->erase(id, hashValue, isEqual, cleanupFn); } /// Implementation for mutating an instance of a derived storage. LogicalResult StorageUniquer::mutateImpl( - const TypeID &id, - function_ref mutationFn) { + TypeID id, function_ref mutationFn) { return impl->mutate(id, mutationFn); } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -156,7 +156,7 @@ StringRef name; if (parser.parseLess() || parser.parseKeyword(&name)) return Type(); - auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name); + auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); // If this type already has been parsed above in the stack, expect just the // name. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -26,10 +26,6 @@ TestTypeInterface::Trait> { using Base::Base; - static TestType get(MLIRContext *context) { - return Base::get(context, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE); - } - /// Provide a definition for the necessary interface methods. void printTypeC(Location loc) const { emitRemark(loc) << *this << " - TestC"; @@ -72,9 +68,8 @@ public: using Base::Base; - static TestRecursiveType create(MLIRContext *ctx, StringRef name) { - return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1, - name); + static TestRecursiveType get(MLIRContext *ctx, StringRef name) { + return Base::get(ctx, name); } /// Body getter and setter. 
diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp --- a/mlir/test/lib/IR/TestTypes.cpp +++ b/mlir/test/lib/IR/TestTypes.cpp @@ -41,7 +41,7 @@ LogicalResult TestRecursiveTypesPass::createIRWithTypes() { MLIRContext *ctx = &getContext(); FuncOp func = getFunction(); - auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name"); + auto type = TestRecursiveType::get(ctx, "some_long_and_unique_name"); if (failed(type.setBody(type))) return func.emitError("expected to be able to set the type body"); @@ -56,7 +56,7 @@ "not expected to be able to change function body more than once"); // Expecting to get the same type for the same name. - auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name"); + auto other = TestRecursiveType::get(ctx, "some_long_and_unique_name"); if (type != other) return func.emitError("expected type name to be the uniquing key");