diff --git a/llvm/docs/BitCodeFormat.rst b/llvm/docs/BitCodeFormat.rst --- a/llvm/docs/BitCodeFormat.rst +++ b/llvm/docs/BitCodeFormat.rst @@ -1338,6 +1338,21 @@ The ``X86_AMX`` record (code 24) adds an ``x86_amx`` type to the type table. +TYPE_CODE_TARGET_TYPE Record +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``[TARGET_TYPE, num_tys, ...ty_params..., ...int_params... ]`` + +The ``TARGET_TYPE`` record (code 26) adds a target extension type to the type +table, with a name defined by a previously encountered ``STRUCT_NAME`` record. +The operand fields are: + +* *num_tys*: The number of parameters that are types (as opposed to integers). + +* *ty_params*: Type indices that represent the type parameters. + +* *int_params*: Numbers that correspond to the integer parameters. + .. _CONSTANTS_BLOCK: CONSTANTS_BLOCK Contents diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -3631,6 +3631,53 @@ pointers" are still supported under non-default options. See the `opaque pointers document <OpaquePointers.html>`__ for more information. +.. _t_target_type: + +Target Extension Type +""""""""""""""""""""" + +:Overview: + +Target extension types represent types that must be preserved through +optimization, but are otherwise generally opaque to the compiler. They may be +used as function parameters or arguments, and in :ref:`phi <i_phi>` or +:ref:`select <i_select>` instructions. Some types may also be used in +:ref:`alloca <i_alloca>` instructions or as global values, and correspondingly +it is legal to use :ref:`load <i_load>` and :ref:`store <i_store>` instructions +on them. Full semantics for these types are defined by the target. + +The only constants that target extension types may have are ``zeroinitializer``, +``undef``, and ``poison``. Other possible values for target extension types may +arise from target-specific intrinsics and functions. + +These types cannot be converted to other types.
As such, it is not legal to use +them in :ref:`bitcast <i_bitcast>` instructions (as a source or target type), +nor is it legal to use them in :ref:`ptrtoint <i_ptrtoint>` or +:ref:`inttoptr <i_inttoptr>` instructions. Similarly, they are not legal to use +in an :ref:`icmp <i_icmp>` instruction. + +Target extension types have a name and optional type or integer parameters. The +meanings of the name and parameters are defined by the target. When defined in +LLVM IR, all of the type parameters must precede all of the integer parameters. + +Specific target extension types are registered with LLVM as having specific +properties. These properties can be used to restrict the type from appearing in +certain contexts, such as being the type of a global variable or having a +``zeroinitializer`` constant be valid. A complete list of type properties may be +found in the documentation for ``llvm::TargetExtType::Property`` (`doxygen +<https://llvm.org/doxygen/classllvm_1_1TargetExtType.html>`_). + +:Syntax: + +.. code-block:: llvm + + target("label") + target("label", void) + target("label", void, i32) + target("label", 0, 1, 2) + target("label", void, i32, 0, 1, 2) + + .. _t_vector: Vector Type diff --git a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst --- a/llvm/docs/ReleaseNotes.rst +++ b/llvm/docs/ReleaseNotes.rst @@ -104,6 +104,10 @@ * ``fneg`` +* Target extension types have been added, which allow targets to have + types that need to be preserved through the optimizer, but otherwise are not + introspectable by target-independent optimizations.
+ Changes to building LLVM ------------------------ diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h --- a/llvm/include/llvm-c/Core.h +++ b/llvm/include/llvm-c/Core.h @@ -165,7 +165,8 @@ LLVMTokenTypeKind, /**< Tokens */ LLVMScalableVectorTypeKind, /**< Scalable SIMD vector type */ LLVMBFloatTypeKind, /**< 16 bit brain floating point type */ - LLVMX86_AMXTypeKind /**< X86 AMX */ + LLVMX86_AMXTypeKind, /**< X86 AMX */ + LLVMTargetExtTypeKind, /**< Target extension type */ } LLVMTypeKind; typedef enum { @@ -284,7 +285,8 @@ LLVMInlineAsmValueKind, LLVMInstructionValueKind, - LLVMPoisonValueValueKind + LLVMPoisonValueValueKind, + LLVMConstantTargetNoneValueKind, } LLVMValueKind; typedef enum { @@ -1571,6 +1573,15 @@ LLVMTypeRef LLVMX86MMXType(void); LLVMTypeRef LLVMX86AMXType(void); +/** + * Create a target extension type in LLVM context. + */ +LLVMTypeRef LLVMTargetExtTypeInContext(LLVMContextRef C, const char *Name, + LLVMTypeRef *TypeParams, + unsigned TypeParamCount, + unsigned *IntParams, + unsigned IntParamCount); + /** * @} */ diff --git a/llvm/include/llvm/AsmParser/LLParser.h b/llvm/include/llvm/AsmParser/LLParser.h --- a/llvm/include/llvm/AsmParser/LLParser.h +++ b/llvm/include/llvm/AsmParser/LLParser.h @@ -434,6 +434,7 @@ bool parseArrayVectorType(Type *&Result, bool IsVector); bool parseFunctionType(Type *&Result); + bool parseTargetExtType(Type *&Result); // Function Semantic Analysis. 
class PerFunctionState { diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -175,6 +175,8 @@ TYPE_CODE_X86_AMX = 24, // X86 AMX TYPE_CODE_OPAQUE_POINTER = 25, // OPAQUE_POINTER: [addrspace] + + TYPE_CODE_TARGET_TYPE = 26, // TARGET_TYPE }; enum OperandBundleTagCode { diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -843,6 +843,33 @@ } }; +/// A constant target extension type default initializer +class ConstantTargetNone final : public ConstantData { + friend class Constant; + + explicit ConstantTargetNone(TargetExtType *T) + : ConstantData(T, Value::ConstantTargetNoneVal) {} + + void destroyConstantImpl(); + +public: + ConstantTargetNone(const ConstantTargetNone &) = delete; + + /// Static factory methods - Return objects of the specified value. + static ConstantTargetNone *get(TargetExtType *T); + + /// Specialize the getType() method to always return an TargetExtType, + /// which reduces the amount of casting needed in parts of the compiler. + inline TargetExtType *getType() const { + return cast(Value::getType()); + } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const Value *V) { + return V->getValueID() == ConstantTargetNoneVal; + } +}; + /// The address of a basic block. 
/// class BlockAddress final : public Constant { diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -713,6 +713,10 @@ getTypeSizeInBits(VTy->getElementType()).getFixedSize(); return TypeSize(MinBits, EltCnt.isScalable()); } + case Type::TargetExtTyID: { + Type *LayoutTy = cast(Ty)->getLayoutType(); + return getTypeSizeInBits(LayoutTy); + } default: llvm_unreachable("DataLayout::getTypeSizeInBits(): Unsupported type"); } diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -730,6 +730,82 @@ return cast(getScalarType())->getAddressSpace(); } +/// Class to represent target extensions types, which are generally +/// unintrospectable from target-independent optimizations. +/// +/// Target extension types have a string name, and optionally have type and/or +/// integer parameters. The exact meaning of any parameters is dependent on the +/// target. +class TargetExtType : public Type { + TargetExtType(LLVMContext &C, StringRef Name, ArrayRef Types, + ArrayRef Ints); + + std::string Name; + unsigned *IntParams; + +public: + TargetExtType(const TargetExtType &) = delete; + TargetExtType &operator=(const TargetExtType &) = delete; + + /// Return a target extension type having the specified name and optional + /// type and integer parameters. + static TargetExtType *get(LLVMContext &Context, StringRef Name, + ArrayRef Types = std::nullopt, + ArrayRef Ints = std::nullopt); + + /// Return the name for this target extension type. Two distinct target + /// extension types may have the same name if their type or integer parameters + /// differ. + StringRef getName() const { return Name; } + + /// Return the type parameters for this particular target extension type. If + /// there are no parameters, an empty array is returned. 
+ ArrayRef type_params() const { + return makeArrayRef(type_param_begin(), type_param_end()); + } + + using type_param_iterator = Type::subtype_iterator; + type_param_iterator type_param_begin() const { return ContainedTys; } + type_param_iterator type_param_end() const { + return &ContainedTys[NumContainedTys]; + } + + Type *getTypeParameter(unsigned i) const { return getContainedType(i); } + unsigned getNumTypeParameters() const { return getNumContainedTypes(); } + + /// Return the integer parameters for this particular target extension type. + /// If there are no parameters, an empty array is returned. + ArrayRef int_params() const { + return makeArrayRef(IntParams, getNumIntParameters()); + } + + unsigned getIntParameter(unsigned i) const { return IntParams[i]; } + unsigned getNumIntParameters() const { return getSubclassData(); } + + enum Property { + /// zeroinitializer is valid for this target extension type. + HasZeroInit = 1U << 0, + /// This type may be used as the value type of a global variable. + CanBeGlobal = 1U << 1, + }; + + /// Returns true if the target extension type contains the given property. + bool hasProperty(Property Prop) const; + + /// Returns an underlying layout type for the target extension type. This + /// type can be used to query size and alignment information, if it is + /// appropriate (although note that the layout type may also be void). It is + /// not legal to bitcast between this type and the layout type, however. + Type *getLayoutType() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. 
+ static bool classof(const Type *T) { return T->getTypeID() == TargetExtTyID; } +}; + +StringRef Type::getTargetExtName() const { + return cast(this)->getName(); +} + } // end namespace llvm #endif // LLVM_IR_DERIVEDTYPES_H diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -76,6 +76,7 @@ FixedVectorTyID, ///< Fixed width SIMD vector type ScalableVectorTyID, ///< Scalable SIMD vector type TypedPointerTyID, ///< Typed pointer used by some GPU targets + TargetExtTyID, ///< Target extension type }; private: @@ -194,6 +195,9 @@ /// Return true if this is X86 AMX. bool isX86_AMXTy() const { return getTypeID() == X86_AMXTyID; } + /// Return true if this is a target extension type. + bool isTargetExtTy() const { return getTypeID() == TargetExtTyID; } + /// Return true if this is a FP type or a vector of FP. bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); } @@ -267,7 +271,7 @@ /// includes all first-class types except struct and array types. bool isSingleValueType() const { return isFloatingPointTy() || isX86_MMXTy() || isIntegerTy() || - isPointerTy() || isVectorTy() || isX86_AMXTy(); + isPointerTy() || isVectorTy() || isX86_AMXTy() || isTargetExtTy(); } /// Return true if the type is an aggregate type. This means it is valid as @@ -288,7 +292,8 @@ return true; // If it is not something that can have a size (e.g. a function or label), // it doesn't have a size. - if (getTypeID() != StructTyID && getTypeID() != ArrayTyID && !isVectorTy()) + if (getTypeID() != StructTyID && getTypeID() != ArrayTyID && + !isVectorTy() && getTypeID() != TargetExtTyID) return false; // Otherwise we have to try harder to decide. return isSizedDerivedType(Visited); @@ -386,6 +391,8 @@ return ContainedTys[0]; } + inline StringRef getTargetExtName() const; + /// This method is deprecated without replacement. Pointer element types are /// not available with opaque pointers. 
[[deprecated("Deprecated without replacement, see " diff --git a/llvm/include/llvm/IR/Value.def b/llvm/include/llvm/IR/Value.def --- a/llvm/include/llvm/IR/Value.def +++ b/llvm/include/llvm/IR/Value.def @@ -95,6 +95,7 @@ HANDLE_CONSTANT(ConstantDataVector) HANDLE_CONSTANT(ConstantInt) HANDLE_CONSTANT(ConstantFP) +HANDLE_CONSTANT(ConstantTargetNone) HANDLE_CONSTANT(ConstantPointerNull) HANDLE_CONSTANT(ConstantTokenNone) diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -2598,6 +2598,12 @@ return false; } break; + case lltok::kw_target: { + // Type ::= TargetExtType + if (parseTargetExtType(Result)) + return true; + break; + } case lltok::lbrace: // Type ::= StructType if (parseAnonStructType(Result, false)) @@ -3105,6 +3111,60 @@ return false; } +/// parseTargetExtType - handle target extension type syntax +/// TargetExtType +/// ::= 'target' '(' STRINGCONSTANT TargetExtTypeParams TargetExtIntParams ')' +/// +/// TargetExtTypeParams +/// ::= /*empty*/ +/// ::= ',' Type TargetExtTypeParams +/// +/// TargetExtIntParams +/// ::= /*empty*/ +/// ::= ',' uint32 TargetExtIntParams +bool LLParser::parseTargetExtType(Type *&Result) { + Lex.Lex(); // Eat the 'target' keyword. + + // Get the mandatory type name. + std::string TypeName; + if (parseToken(lltok::lparen, "expected '(' in target extension type") || + parseStringConstant(TypeName)) + return true; + + // Parse all of the integer and type parameters at the same time; the use of + // SeenInt will allow us to catch cases where type parameters follow integer + // parameters. + SmallVector TypeParams; + SmallVector IntParams; + bool SeenInt = false; + while (Lex.getKind() == lltok::comma) { + Lex.Lex(); // Eat the comma. 
+ + if (Lex.getKind() == lltok::APSInt) { + SeenInt = true; + unsigned IntVal; + if (parseUInt32(IntVal)) + return true; + IntParams.push_back(IntVal); + } else if (SeenInt) { + // The only other kind of parameter we support is type parameters, which + // must precede the integer parameters. This is therefore an error. + return tokError("expected uint32 param"); + } else { + Type *TypeParam; + if (parseType(TypeParam, /*AllowVoid=*/true)) + return true; + TypeParams.push_back(TypeParam); + } + } + + if (parseToken(lltok::rparen, "expected ')' in target extension type")) + return true; + + Result = TargetExtType::get(Context, TypeName, TypeParams, IntParams); + return false; +} + //===----------------------------------------------------------------------===// // Function Semantic Analysis. //===----------------------------------------------------------------------===// @@ -5694,6 +5754,9 @@ // FIXME: LabelTy should not be a first-class type. if (!Ty->isFirstClassType() || Ty->isLabelTy()) return error(ID.Loc, "invalid type for null constant"); + if (auto *TETy = dyn_cast(Ty)) + if (!TETy->hasProperty(TargetExtType::HasZeroInit)) + return error(ID.Loc, "invalid type for null constant"); V = Constant::getNullValue(Ty); return false; case ValID::t_None: diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -2486,6 +2486,35 @@ ResultTy = Res; break; } + case bitc::TYPE_CODE_TARGET_TYPE: { // TARGET_TYPE: [NumTy, Tys..., Ints...] 
+ if (Record.size() < 1) + return error("Invalid target extension type record"); + + if (NumRecords >= TypeList.size()) + return error("Invalid TYPE table"); + + if (Record[0] >= Record.size()) + return error("Too many type parameters"); + + unsigned NumTys = Record[0]; + SmallVector TypeParams; + SmallVector IntParams; + for (unsigned i = 0; i < NumTys; i++) { + if (Type *T = getTypeByID(Record[i + 1])) + TypeParams.push_back(T); + else + return error("Invalid type"); + } + + for (unsigned i = NumTys + 1, e = Record.size(); i < e; i++) { + if (Record[i] > UINT_MAX) + return error("Integer parameter too large"); + IntParams.push_back(Record[i]); + } + ResultTy = TargetExtType::get(Context, TypeName, TypeParams, IntParams); + TypeName.clear(); + break; + } case bitc::TYPE_CODE_ARRAY: // ARRAY: [numelts, eltty] if (Record.size() < 2) return error("Invalid array type record"); @@ -2989,6 +3018,9 @@ case bitc::CST_CODE_NULL: // NULL if (CurTy->isVoidTy() || CurTy->isFunctionTy() || CurTy->isLabelTy()) return error("Invalid type for a constant null value"); + if (auto *TETy = dyn_cast(CurTy)) + if (!TETy->hasProperty(TargetExtType::HasZeroInit)) + return error("Invalid type for a constant null value"); V = Constant::getNullValue(CurTy); break; case bitc::CST_CODE_INTEGER: // INTEGER: [intval] diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -1052,6 +1052,18 @@ TypeVals.push_back(true); break; } + case Type::TargetExtTyID: { + TargetExtType *TET = cast(T); + Code = bitc::TYPE_CODE_TARGET_TYPE; + writeStringRecord(Stream, bitc::TYPE_CODE_STRUCT_NAME, TET->getName(), + StructNameAbbrev); + TypeVals.push_back(TET->getNumTypeParameters()); + for (Type *InnerTy : TET->type_params()) + TypeVals.push_back(VE.getTypeID(InnerTy)); + for (unsigned IntParam : TET->int_params()) + TypeVals.push_back(IntParam); + break; + } case 
Type::TypedPointerTyID: llvm_unreachable("Typed pointers cannot be added to IR modules"); } diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -622,6 +622,17 @@ << TPTy->getAddressSpace() << ")"; return; } + case Type::TargetExtTyID: + TargetExtType *TETy = cast(Ty); + OS << "target(\""; + printEscapedString(Ty->getTargetExtName(), OS); + OS << "\""; + for (Type *Inner : TETy->type_params()) + OS << ", " << *Inner; + for (unsigned IntParam : TETy->int_params()) + OS << ", " << IntParam; + OS << ")"; + return; } llvm_unreachable("Invalid TypeID"); } @@ -1438,7 +1449,7 @@ return; } - if (isa(CV)) { + if (isa(CV) || isa(CV)) { Out << "zeroinitializer"; return; } diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -87,7 +87,7 @@ // constant zero is zero for aggregates, cpnull is null for pointers, none for // tokens. return isa(this) || isa(this) || - isa(this); + isa(this) || isa(this); } bool Constant::isAllOnesValue() const { @@ -369,6 +369,8 @@ return ConstantAggregateZero::get(Ty); case Type::TokenTyID: return ConstantTokenNone::get(Ty->getContext()); + case Type::TargetExtTyID: + return ConstantTargetNone::get(cast(Ty)); default: // Function, Label, or Opaque type? llvm_unreachable("Cannot create a null constant of that type!"); @@ -1710,6 +1712,25 @@ getContext().pImpl->CPNConstants.erase(getType()); } +//---- ConstantTargetNone::get() implementation. +// + +ConstantTargetNone *ConstantTargetNone::get(TargetExtType *Ty) { + assert(Ty->hasProperty(TargetExtType::HasZeroInit) && + "Target extension type not allowed to have a zeroinitializer"); + std::unique_ptr &Entry = + Ty->getContext().pImpl->CTNConstants[Ty]; + if (!Entry) + Entry.reset(new ConstantTargetNone(Ty)); + + return Entry.get(); +} + +/// Remove the constant from the constant table. 
+void ConstantTargetNone::destroyConstantImpl() { + getContext().pImpl->CTNConstants.erase(getType()); +} + UndefValue *UndefValue::get(Type *Ty) { std::unique_ptr &Entry = Ty->getContext().pImpl->UVConstants[Ty]; if (!Entry) diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -540,6 +540,8 @@ return LLVMTokenTypeKind; case Type::ScalableVectorTyID: return LLVMScalableVectorTypeKind; + case Type::TargetExtTyID: + return LLVMTargetExtTypeKind; case Type::TypedPointerTyID: llvm_unreachable("Typed pointers are unsupported via the C API"); } @@ -858,6 +860,17 @@ return LLVMLabelTypeInContext(LLVMGetGlobalContext()); } +LLVMTypeRef LLVMTargetExtTypeInContext(LLVMContextRef C, const char *Name, + LLVMTypeRef *TypeParams, + unsigned TypeParamCount, + unsigned *IntParams, + unsigned IntParamCount) { + ArrayRef TypeParamArray(unwrap(TypeParams), TypeParamCount); + ArrayRef IntParamArray(IntParams, IntParamCount); + return wrap( + TargetExtType::get(*unwrap(C), Name, TypeParamArray, IntParamArray)); +} + /*===-- Operations on values ----------------------------------------------===*/ /*--.. 
Operations on all values ............................................--*/ diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -816,6 +816,10 @@ } case Type::X86_AMXTyID: return Align(64); + case Type::TargetExtTyID: { + Type *LayoutTy = cast(Ty)->getLayoutType(); + return getAlignment(LayoutTy, abi_or_pref); + } default: llvm_unreachable("Bad type for getAlignment!!!"); } diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -941,6 +941,15 @@ Result += "nx"; Result += "v" + utostr(EC.getKnownMinValue()) + getMangledTypeStr(VTy->getElementType(), HasUnnamedType); + } else if (TargetExtType *TETy = dyn_cast(Ty)) { + Result += "t"; + Result += TETy->getName(); + for (Type *ParamTy : TETy->type_params()) + Result += "_" + getMangledTypeStr(ParamTy, HasUnnamedType); + for (unsigned IntParam : TETy->int_params()) + Result += "_" + utostr(IntParam); + // Ensure nested target extension types are distinguishable. 
+ Result += "t"; } else if (Ty) { switch (Ty->getTypeID()) { default: llvm_unreachable("Unhandled type"); diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -191,6 +191,55 @@ } }; +struct TargetExtTypeKeyInfo { + struct KeyTy { + StringRef Name; + ArrayRef TypeParams; + ArrayRef IntParams; + + KeyTy(StringRef N, const ArrayRef &TP, const ArrayRef &IP) + : Name(N), TypeParams(TP), IntParams(IP) {} + KeyTy(const TargetExtType *TT) + : Name(TT->getName()), TypeParams(TT->type_params()), + IntParams(TT->int_params()) {} + + bool operator==(const KeyTy &that) const { + return Name == that.Name && TypeParams == that.TypeParams && + IntParams == that.IntParams; + } + bool operator!=(const KeyTy &that) const { return !this->operator==(that); } + }; + + static inline TargetExtType *getEmptyKey() { + return DenseMapInfo::getEmptyKey(); + } + + static inline TargetExtType *getTombstoneKey() { + return DenseMapInfo::getTombstoneKey(); + } + + static unsigned getHashValue(const KeyTy &Key) { + return hash_combine( + Key.Name, + hash_combine_range(Key.TypeParams.begin(), Key.TypeParams.end()), + hash_combine_range(Key.IntParams.begin(), Key.IntParams.end())); + } + + static unsigned getHashValue(const TargetExtType *FT) { + return getHashValue(KeyTy(FT)); + } + + static bool isEqual(const KeyTy &LHS, const TargetExtType *RHS) { + if (RHS == getEmptyKey() || RHS == getTombstoneKey()) + return false; + return LHS == KeyTy(RHS); + } + + static bool isEqual(const TargetExtType *LHS, const TargetExtType *RHS) { + return LHS == RHS; + } +}; + /// Structure for hashing arbitrary MDNode operands. 
class MDNodeOpsKey { ArrayRef RawOps; @@ -1440,6 +1489,8 @@ DenseMap> CPNConstants; + DenseMap> CTNConstants; + DenseMap> UVConstants; DenseMap> PVConstants; @@ -1480,6 +1531,9 @@ StringMap NamedStructTypes; unsigned NamedStructTypesUniqueID = 0; + using TargetExtTypeSet = DenseSet; + TargetExtTypeSet TargetExtTypes; + DenseMap, ArrayType *> ArrayTypes; DenseMap, VectorType *> VectorTypes; DenseMap PointerTypes; // Pointers in AddrSpace = 0 diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -113,6 +113,7 @@ CAZConstants.clear(); CPNConstants.clear(); + CTNConstants.clear(); UVConstants.clear(); PVConstants.clear(); IntConstants.clear(); diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -211,6 +211,9 @@ if (auto *VTy = dyn_cast(this)) return VTy->getElementType()->isSized(Visited); + if (auto *TTy = dyn_cast(this)) + return TTy->getLayoutType()->isSized(Visited); + return cast(this)->isSized(Visited); } @@ -783,3 +786,82 @@ bool PointerType::isLoadableOrStorableType(Type *ElemTy) { return isValidElementType(ElemTy) && !ElemTy->isFunctionTy(); } + +//===----------------------------------------------------------------------===// +// TargetExtType Implementation +//===----------------------------------------------------------------------===// + +TargetExtType::TargetExtType(LLVMContext &C, StringRef Name, + ArrayRef Types, ArrayRef Ints) + : Type(C, TargetExtTyID), Name(Name) { + NumContainedTys = Types.size(); + + // Parameter storage immediately follows the class in allocation. 
+ Type **Params = reinterpret_cast(this + 1); + ContainedTys = Params; + for (Type *T : Types) + *Params++ = T; + + setSubclassData(Ints.size()); + unsigned *IntParamSpace = reinterpret_cast(Params); + IntParams = IntParamSpace; + for (unsigned IntParam : Ints) + *IntParamSpace++ = IntParam; +} + +TargetExtType *TargetExtType::get(LLVMContext &C, StringRef Name, + ArrayRef Types, + ArrayRef Ints) { + const TargetExtTypeKeyInfo::KeyTy Key(Name, Types, Ints); + TargetExtType *TT; + // Since we only want to allocate a fresh target type in case none is found + // and we don't want to perform two lookups (one for checking if existent and + // one for inserting the newly allocated one), here we instead lookup based on + // Key and update the reference to the target type in-place to a newly + // allocated one if not found. + auto Insertion = C.pImpl->TargetExtTypes.insert_as(nullptr, Key); + if (Insertion.second) { + // The target type was not found. Allocate one and update TargetExtTypes + // in-place. + TT = (TargetExtType *)C.pImpl->Alloc.Allocate( + sizeof(TargetExtType) + sizeof(Type *) * Types.size() + + sizeof(unsigned) * Ints.size(), + alignof(TargetExtType)); + new (TT) TargetExtType(C, Name, Types, Ints); + *Insertion.first = TT; + } else { + // The target type was found. Just return it. + TT = *Insertion.first; + } + return TT; +} + +namespace { +struct TargetTypeInfo { + Type *LayoutType; + uint64_t Properties; + + template + TargetTypeInfo(Type *LayoutType, ArgTys... Properties) + : LayoutType(LayoutType), Properties((0 | ... 
| Properties)) {} +}; +} // anonymous namespace + +static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) { + LLVMContext &C = Ty->getContext(); + StringRef Name = Ty->getName(); + if (Name.startswith("spirv.")) { + return TargetTypeInfo(Type::getInt8PtrTy(C, 0), TargetExtType::HasZeroInit, + TargetExtType::CanBeGlobal); + } + return TargetTypeInfo(Type::getVoidTy(C)); +} + +Type *TargetExtType::getLayoutType() const { + return getTargetTypeInfo(this).LayoutType; +} + +bool TargetExtType::hasProperty(Property Prop) const { + uint64_t Properties = getTargetTypeInfo(this).Properties; + return (Properties & Prop) == Prop; +} diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -798,6 +798,13 @@ Check(!STy->containsScalableVectorType(), "Globals cannot contain scalable vectors", &GV); + // Check if it's a target extension type that disallows being used as a + // global. + if (auto *TTy = dyn_cast(GV.getValueType())) + Check(TTy->hasProperty(TargetExtType::CanBeGlobal), + "Global @" + GV.getName() + " has illegal target extension type", + TTy); + if (!GV.hasInitializer()) { visitGlobalValue(GV); return; diff --git a/llvm/lib/Transforms/Scalar/SROA.cpp b/llvm/lib/Transforms/Scalar/SROA.cpp --- a/llvm/lib/Transforms/Scalar/SROA.cpp +++ b/llvm/lib/Transforms/Scalar/SROA.cpp @@ -1942,6 +1942,9 @@ return false; } + if (OldTy->isTargetExtTy() || NewTy->isTargetExtTy()) + return false; + return true; } diff --git a/llvm/lib/Transforms/Utils/VNCoercion.cpp b/llvm/lib/Transforms/Utils/VNCoercion.cpp --- a/llvm/lib/Transforms/Utils/VNCoercion.cpp +++ b/llvm/lib/Transforms/Utils/VNCoercion.cpp @@ -57,10 +57,13 @@ // The implementation below uses inttoptr for vectors of unequal size; we // can't allow this for non integral pointers. We could teach it to extract - // exact subvectors if desired. + // exact subvectors if desired. 
if (StoredNI && StoreSize != DL.getTypeSizeInBits(LoadTy).getFixedSize()) return false; + if (StoredTy->isTargetExtTy() || LoadTy->isTargetExtTy()) + return false; + return true; } diff --git a/llvm/test/Assembler/invalid-target-type-mixed.ll b/llvm/test/Assembler/invalid-target-type-mixed.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/invalid-target-type-mixed.ll @@ -0,0 +1,6 @@ +; RUN: not llvm-as < %s -disable-output 2>&1 | FileCheck %s + +; CHECK: expected uint32 param +define void @f(target("type", i32, 0, void) %a) { + ret void +} diff --git a/llvm/test/Assembler/target-type-mangled.ll b/llvm/test/Assembler/target-type-mangled.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/target-type-mangled.ll @@ -0,0 +1,11 @@ +; RUN: llvm-as < %s | llvm-dis | FileCheck %s +; Check support for mangling of target extension types in intrinsics + +declare target("a", target("b")) @llvm.ssa.copy.ta_tbtt(target("a", target("b")) returned) +declare target("a", void, i8, 5, 3) @llvm.ssa.copy.ta_isVoid_i8_5_3t(target("a", void, i8, 5, 3) returned) +declare target("b") @llvm.ssa.copy.tbt(target("b") returned) + +; CHECK: declare target("a", target("b")) @llvm.ssa.copy.ta_tbtt(target("a", target("b")) returned) +; CHECK: declare target("a", void, i8, 5, 3) @llvm.ssa.copy.ta_isVoid_i8_5_3t(target("a", void, i8, 5, 3) returned) +; CHECK: declare target("b") @llvm.ssa.copy.tbt(target("b") returned) + diff --git a/llvm/test/Assembler/target-type-params.ll b/llvm/test/Assembler/target-type-params.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/target-type-params.ll @@ -0,0 +1,16 @@ +; RUN: llvm-as < %s | llvm-dis | FileCheck %s +; Check support for basic target extension type properties + +declare void @g1(target("atype", void)) +declare void @g2(target("atype", i32)) +declare void @g3(target("atype", i32, 0)) +declare void @g4(target("atype", 0)) +declare void @g5(target("atype", 0, 1, 2)) +declare void @g6(target("atype", void, i32, 
float, {i32, bfloat}, 0, 1, 2)) + +;CHECK: declare void @g1(target("atype", void)) +;CHECK: declare void @g2(target("atype", i32)) +;CHECK: declare void @g3(target("atype", i32, 0)) +;CHECK: declare void @g4(target("atype", 0)) +;CHECK: declare void @g5(target("atype", 0, 1, 2)) +;CHECK: declare void @g6(target("atype", void, i32, float, { i32, bfloat }, 0, 1, 2)) diff --git a/llvm/test/Assembler/target-type-properties.ll b/llvm/test/Assembler/target-type-properties.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/target-type-properties.ll @@ -0,0 +1,16 @@ +; RUN: split-file %s %t +; RUN: not llvm-as < %t/zeroinit-error.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-ZEROINIT %s +; RUN: not llvm-as < %t/global-var.ll -o /dev/null 2>&1 | FileCheck --check-prefix=CHECK-GLOBALVAR %s +; Check target extension type properties are verified in the assembler. + +;--- zeroinit-error.ll +define void @foo() { + %val = freeze target("spirv.DeviceEvent") zeroinitializer + %val2 = freeze target("unknown_target_type") zeroinitializer +; CHECK-ZEROINIT: error: invalid type for null constant + ret void +} + +;--- global-var.ll +@global = external global target("unknown_target_type") +; CHECK-GLOBALVAR: Global @global has illegal target extension type diff --git a/llvm/test/Assembler/target-types.ll b/llvm/test/Assembler/target-types.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Assembler/target-types.ll @@ -0,0 +1,24 @@ +; RUN: llvm-as < %s | llvm-dis | FileCheck %s +; Check support for basic target extension type usage + +@global = global target("spirv.DeviceEvent") zeroinitializer + +define target("spirv.Sampler") @foo(target("spirv.Sampler") %a) { + ret target("spirv.Sampler") %a +} + +define target("spirv.Event") @func2() { + %mem = alloca target("spirv.Event") + %val = load target("spirv.Event"), ptr %mem + ret target("spirv.Event") poison +} + +; CHECK: @global = global target("spirv.DeviceEvent") zeroinitializer +; CHECK: define 
target("spirv.Sampler") @foo(target("spirv.Sampler") %a) { +; CHECK: ret target("spirv.Sampler") %a +; CHECK: } +; CHECK: define target("spirv.Event") @func2() { +; CHECK: %mem = alloca target("spirv.Event") +; CHECK: %val = load target("spirv.Event"), ptr %mem +; CHECK: ret target("spirv.Event") poison +; CHECK: } diff --git a/llvm/test/Transforms/GVN/target-type.ll b/llvm/test/Transforms/GVN/target-type.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/GVN/target-type.ll @@ -0,0 +1,52 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -passes=gvn < %s | FileCheck %s + +; Check that GVN can work with target extension types correctly. + +target datalayout = "e-p:64:64:64-p1:16:16:16-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-n8:16:32:64" + +define target("spirv.DeviceEvent") @basic_alloc(target("spirv.DeviceEvent") %arg) { +; CHECK-LABEL: @basic_alloc( +; CHECK-NEXT: [[VAL:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: store target("spirv.DeviceEvent") [[ARG:%.*]], ptr [[VAL]], align 8 +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[ARG]] +; + %val = alloca target("spirv.DeviceEvent") + store target("spirv.DeviceEvent") %arg, ptr %val + %ret = load target("spirv.DeviceEvent"), ptr %val + ret target("spirv.DeviceEvent") %ret +} + +define target("spirv.DeviceEvent") @nobitcast(ptr %arg) { +; CHECK-LABEL: @nobitcast( +; CHECK-NEXT: [[VAL:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: store ptr [[ARG:%.*]], ptr [[VAL]], align 8 +; CHECK-NEXT: [[RET:%.*]] = load target("spirv.DeviceEvent"), ptr [[VAL]], align 8 +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[RET]] +; + %val = alloca target("spirv.DeviceEvent") + store ptr %arg, ptr %val + %ret = load target("spirv.DeviceEvent"), ptr %val + ret target("spirv.DeviceEvent") %ret +} + +define target("spirv.DeviceEvent") @viai64(target("spirv.DeviceEvent") %arg) { +; 
CHECK-LABEL: @viai64( +; CHECK-NEXT: [[VAL:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: [[BAR:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: store target("spirv.DeviceEvent") [[ARG:%.*]], ptr [[VAL]], align 8 +; CHECK-NEXT: [[IMEMCPY:%.*]] = load i64, ptr [[VAL]], align 4 +; CHECK-NEXT: store i64 [[IMEMCPY]], ptr [[BAR]], align 4 +; CHECK-NEXT: [[RET:%.*]] = load target("spirv.DeviceEvent"), ptr [[BAR]], align 8 +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[RET]] +; + %val = alloca target("spirv.DeviceEvent") + %bar = alloca target("spirv.DeviceEvent") + store target("spirv.DeviceEvent") %arg, ptr %val + %imemcpy = load i64, ptr %val + store i64 %imemcpy, ptr %bar + %ret = load target("spirv.DeviceEvent"), ptr %bar + ret target("spirv.DeviceEvent") %ret +} + +declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1) diff --git a/llvm/test/Transforms/SROA/sroa-target.ll b/llvm/test/Transforms/SROA/sroa-target.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/SROA/sroa-target.ll @@ -0,0 +1,62 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -passes=sroa < %s | FileCheck %s + +; Check that SROA can work with target extension types correctly. 
+ +target datalayout = "e-p:64:64:64-p1:16:16:16-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-n8:16:32:64" + +define target("spirv.DeviceEvent") @basic_alloc(target("spirv.DeviceEvent") %arg) { +; CHECK-LABEL: @basic_alloc( +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[ARG:%.*]] +; + %val = alloca target("spirv.DeviceEvent") + store target("spirv.DeviceEvent") %arg, ptr %val + %ret = load target("spirv.DeviceEvent"), ptr %val + ret target("spirv.DeviceEvent") %ret +} + +define target("spirv.DeviceEvent") @via_memcpy(target("spirv.DeviceEvent") %arg) { +; CHECK-LABEL: @via_memcpy( +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[ARG:%.*]] +; + %val = alloca target("spirv.DeviceEvent") + %bar = alloca target("spirv.DeviceEvent") + store target("spirv.DeviceEvent") %arg, ptr %val + call void @llvm.memcpy.p0.p0.i64(ptr %bar, ptr %val, i64 8, i1 false) + %ret = load target("spirv.DeviceEvent"), ptr %bar + ret target("spirv.DeviceEvent") %ret +} + +define target("spirv.DeviceEvent") @nobitcast(ptr %arg) { +; CHECK-LABEL: @nobitcast( +; CHECK-NEXT: [[VAL:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: store ptr [[ARG:%.*]], ptr [[VAL]], align 8 +; CHECK-NEXT: [[VAL_0_RET:%.*]] = load target("spirv.DeviceEvent"), ptr [[VAL]], align 8 +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[VAL_0_RET]] +; + %val = alloca target("spirv.DeviceEvent") + store ptr %arg, ptr %val + %ret = load target("spirv.DeviceEvent"), ptr %val + ret target("spirv.DeviceEvent") %ret +} + +define target("spirv.DeviceEvent") @viai64(target("spirv.DeviceEvent") %arg) { +; CHECK-LABEL: @viai64( +; CHECK-NEXT: [[VAL:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: [[BAR:%.*]] = alloca target("spirv.DeviceEvent"), align 8 +; CHECK-NEXT: store target("spirv.DeviceEvent") [[ARG:%.*]], ptr [[VAL]], align 8 +; CHECK-NEXT: [[VAL_0_IMEMCPY:%.*]] = load i64, ptr [[VAL]], align 8 +; CHECK-NEXT: store i64 [[VAL_0_IMEMCPY]], 
ptr [[BAR]], align 8 +; CHECK-NEXT: [[BAR_0_RET:%.*]] = load target("spirv.DeviceEvent"), ptr [[BAR]], align 8 +; CHECK-NEXT: ret target("spirv.DeviceEvent") [[BAR_0_RET]] +; + %val = alloca target("spirv.DeviceEvent") + %bar = alloca target("spirv.DeviceEvent") + store target("spirv.DeviceEvent") %arg, ptr %val + %imemcpy = load i64, ptr %val + store i64 %imemcpy, ptr %bar + %ret = load target("spirv.DeviceEvent"), ptr %bar + ret target("spirv.DeviceEvent") %ret +} + +declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1) diff --git a/llvm/tools/llvm-c-test/echo.cpp b/llvm/tools/llvm-c-test/echo.cpp --- a/llvm/tools/llvm-c-test/echo.cpp +++ b/llvm/tools/llvm-c-test/echo.cpp @@ -159,6 +159,8 @@ return LLVMX86MMXTypeInContext(Ctx); case LLVMTokenTypeKind: return LLVMTokenTypeInContext(Ctx); + case LLVMTargetExtTypeKind: + assert(false && "Implement me"); } fprintf(stderr, "%d is not a supported typekind\n", Kind); diff --git a/llvm/unittests/IR/TypesTest.cpp b/llvm/unittests/IR/TypesTest.cpp --- a/llvm/unittests/IR/TypesTest.cpp +++ b/llvm/unittests/IR/TypesTest.cpp @@ -61,6 +61,17 @@ EXPECT_FALSE(P2C0->isOpaque()); } +TEST(TypesTest, TargetExtType) { + LLVMContext Context; + Type *A = TargetExtType::get(Context, "typea"); + Type *Aparam = TargetExtType::get(Context, "typea", {}, {0, 1}); + Type *Aparam2 = TargetExtType::get(Context, "typea", {}, {0, 1}); + // Opaque types with same parameters are identical... + EXPECT_EQ(Aparam, Aparam2); + // ... but just having the same name is not enough. + EXPECT_NE(A, Aparam); +} + TEST(TypedPointerType, PrintTest) { std::string Buffer; LLVMContext Context;