diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -32,6 +32,7 @@
 class Value;
 class APInt;
 class LLVMContext;
+class TargetExtTypeClass;
 
 /// Class to represent integer types. Note that this class is also used to
 /// represent the built-in integer types: Int1Ty, Int8Ty, Int16Ty, Int32Ty and
@@ -737,9 +738,11 @@
 /// integer parameters. The exact meaning of any parameters is dependent on the
 /// target.
 class TargetExtType : public Type {
-  TargetExtType(LLVMContext &C, StringRef Name, ArrayRef<Type *> Types,
-                ArrayRef<unsigned> Ints);
+  TargetExtType(LLVMContext &C, const TargetExtTypeClass *Class, StringRef Name,
+                ArrayRef<Type *> Types, ArrayRef<unsigned> Ints);
 
+  /// Type class this type was created with; null if it has none.
+  const TargetExtTypeClass *Class;
   // These strings are ultimately owned by the context.
   StringRef Name;
   unsigned *IntParams;
@@ -754,6 +757,17 @@
                             ArrayRef<Type *> Types = std::nullopt,
                             ArrayRef<unsigned> Ints = std::nullopt);
 
+  /// Return a target extension type of a known type class, having the specified
+  /// name and optional type and integer parameters.
+  static TargetExtType *get(LLVMContext &Context,
+                            const TargetExtTypeClass *Class, StringRef Name,
+                            ArrayRef<Type *> Types = std::nullopt,
+                            ArrayRef<unsigned> Ints = std::nullopt);
+
+  /// Return the type class of this target extension type, or null if none has
+  /// been registered.
+  const TargetExtTypeClass *getClass() const { return Class; }
+
   /// Return the name for this target extension type. Two distinct target
   /// extension types may have the same name if their type or integer parameters
   /// differ.
diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h
--- a/llvm/include/llvm/IR/LLVMContext.h
+++ b/llvm/include/llvm/IR/LLVMContext.h
@@ -34,6 +34,7 @@
 template <typename T> class SmallVectorImpl;
 template <typename ValueTy> class StringMapEntry;
 class StringRef;
+class TargetExtTypeClass;
 class Twine;
 class LLVMRemarkStreamer;
 
@@ -320,6 +321,9 @@
   /// Whether typed pointers are supported.
If false, all pointers are opaque.
   bool supportsTypedPointers() const;
 
+  /// Register a custom extension type class. Must precede the first lookup.
+  void registerTargetExtTypeClass(const TargetExtTypeClass *TypeClass);
+
 private:
   // Module needs access to the add/removeModule methods.
   friend class Module;
diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp
--- a/llvm/lib/IR/LLVMContext.cpp
+++ b/llvm/lib/IR/LLVMContext.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/DiagnosticPrinter.h"
 #include "llvm/IR/LLVMRemarkStreamer.h"
+#include "llvm/IR/TargetExtType.h"
 #include "llvm/Remarks/RemarkStreamer.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -103,6 +104,23 @@
   assert(SystemSSID == SyncScope::System &&
          "system synchronization scope ID drifted!");
   (void)SystemSSID;
+
+  // Register the built-in target extension type classes.
+  static const TargetExtTypeClass SpirvTypes(
+      "spirv.", true, [](const TargetExtType *T) {
+        return TargetTypeInfo(Type::getInt8PtrTy(T->getContext(), 0),
+                              TargetExtType::HasZeroInit,
+                              TargetExtType::CanBeGlobal);
+      });
+  registerTargetExtTypeClass(&SpirvTypes);
+
+  // Opaque types in the AArch64 name space.
+  static const TargetExtTypeClass Aarch64SVCount(
+      "aarch64.svcount", false, [](const TargetExtType *T) {
+        return TargetTypeInfo(
+            ScalableVectorType::get(Type::getInt1Ty(T->getContext()), 16));
+      });
+  registerTargetExtTypeClass(&Aarch64SVCount);
 }
 
 LLVMContext::~LLVMContext() { delete pImpl; }
@@ -375,3 +393,12 @@
 bool LLVMContext::supportsTypedPointers() const {
   return !pImpl->getOpaquePointers();
 }
+
+void LLVMContext::registerTargetExtTypeClass(
+    const TargetExtTypeClass *TypeClass) {
+  assert(TypeClass && "cannot register a null target extension type class");
+  assert(!pImpl->TargetExtTypeClassesFrozen &&
+         "target extension type classes must be registered before the first "
+         "lookup");
+  pImpl->TargetExtTypeClasses.push_back(TypeClass);
+}
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -1611,6 +1611,9 @@
   /// clients which do use GC.
DenseMap<const Function *, std::string> GCNames;
 
+  SmallVector<const TargetExtTypeClass *> TargetExtTypeClasses;
+  bool TargetExtTypeClassesFrozen = false;
+
   /// Flag to indicate if Value (other than GlobalValue) retains their name or
   /// not.
   bool DiscardValueNames = false;
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -20,6 +20,7 @@
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/TargetExtType.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/TypeSize.h"
@@ -806,9 +807,31 @@
 // TargetExtType Implementation
 //===----------------------------------------------------------------------===//
 
-TargetExtType::TargetExtType(LLVMContext &C, StringRef Name,
-                             ArrayRef<Type *> Types, ArrayRef<unsigned> Ints)
-    : Type(C, TargetExtTyID), Name(C.pImpl->Saver.save(Name)) {
+TargetExtTypeClass::TargetExtTypeClass(StringRef Name, bool NameIsPrefix,
+                                       GetTypeInfoFn *GetTypeInfo)
+    : Name(Name.str()), NameIsPrefix(NameIsPrefix), GetTypeInfo(GetTypeInfo) {
+  assert(!Name.empty() && "type class name must not be empty");
+  assert(NameIsPrefix == Name.ends_with(".") && "'.' suffix iff prefix class");
+}
+
+const TargetExtTypeClass *TargetExtTypeClass::find(LLVMContext &Ctx,
+                                                   StringRef Name) {
+  // A lookup freezes the registry; registering afterwards asserts.
+  Ctx.pImpl->TargetExtTypeClassesFrozen = true;
+
+  for (const TargetExtTypeClass *Class : Ctx.pImpl->TargetExtTypeClasses) {
+    if ((Class->NameIsPrefix && Name.starts_with(Class->Name)) ||
+        (!Class->NameIsPrefix && Name == Class->Name))
+      return Class;
+  }
+
+  return nullptr;
+}
+
+TargetExtType::TargetExtType(LLVMContext &C, const TargetExtTypeClass *Class,
+                             StringRef Name, ArrayRef<Type *> Types,
+                             ArrayRef<unsigned> Ints)
+    : Type(C, TargetExtTyID), Class(Class), Name(C.pImpl->Saver.save(Name)) {
   NumContainedTys = Types.size();
 
   // Parameter storage immediately follows the class in allocation.
@@ -827,6 +850,17 @@
 TargetExtType *TargetExtType::get(LLVMContext &C, StringRef Name,
                                   ArrayRef<Type *> Types,
                                   ArrayRef<unsigned> Ints) {
+  // Resolve the registered type class (if any) from the type name.
+  const auto *Class = TargetExtTypeClass::find(C, Name);
+  return get(C, Class, Name, Types, Ints);
+}
+
+// NOTE(review): Class is not part of the uniquing key below; an existing type
+// appears to be returned with whatever class it was first created with.
+TargetExtType *TargetExtType::get(LLVMContext &C,
+                                  const TargetExtTypeClass *Class,
+                                  StringRef Name, ArrayRef<Type *> Types,
+                                  ArrayRef<unsigned> Ints) {
   const TargetExtTypeKeyInfo::KeyTy Key(Name, Types, Ints);
   TargetExtType *TT;
   // Since we only want to allocate a fresh target type in case none is found
@@ -842,7 +876,7 @@
         sizeof(TargetExtType) + sizeof(Type *) * Types.size() +
             sizeof(unsigned) * Ints.size(),
         alignof(TargetExtType));
-    new (TT) TargetExtType(C, Name, Types, Ints);
+    new (TT) TargetExtType(C, Class, Name, Types, Ints);
     *Insertion.first = TT;
   } else {
     // The target type was found. Just return it.
@@ -851,29 +885,11 @@
   return TT;
 }
 
-namespace {
-struct TargetTypeInfo {
-  Type *LayoutType;
-  uint64_t Properties;
-
-  template <typename... ArgTys>
-  TargetTypeInfo(Type *LayoutType, ArgTys... Properties)
-      : LayoutType(LayoutType), Properties((0 | ... | Properties)) {}
-};
-} // anonymous namespace
-
 static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
-  LLVMContext &C = Ty->getContext();
-  StringRef Name = Ty->getName();
-  if (Name.startswith("spirv."))
-    return TargetTypeInfo(Type::getInt8PtrTy(C, 0), TargetExtType::HasZeroInit,
-                          TargetExtType::CanBeGlobal);
-
-  // Opaque types in the AArch64 name space.
-  if (Name == "aarch64.svcount")
-    return TargetTypeInfo(ScalableVectorType::get(Type::getInt1Ty(C), 16));
-
-  return TargetTypeInfo(Type::getVoidTy(C));
+  if (const auto *Class = Ty->getClass())
+    return Class->getTypeInfo(Ty);
+  // Types without a registered class fall back to a void layout type.
+  return TargetTypeInfo(Type::getVoidTy(Ty->getContext()));
 }
 
 Type *TargetExtType::getLayoutType() const {