diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -194,6 +194,8 @@ DependentAddressSpaceTypes; mutable llvm::FoldingSet VectorTypes; mutable llvm::FoldingSet DependentVectorTypes; + mutable llvm::FoldingSet MatrixTypes; + mutable llvm::FoldingSet DependentSizedMatrixTypes; mutable llvm::FoldingSet FunctionNoProtoTypes; mutable llvm::ContextualFoldingSet FunctionProtoTypes; @@ -1326,6 +1328,20 @@ Expr *SizeExpr, SourceLocation AttrLoc) const; + /// Return the unique reference to the matrix type of the specified element + /// type and size + /// + /// \pre \p ElementType must be a valid matrix element type (see + /// MatrixType::isValidElementType). + QualType getConstantMatrixType(QualType ElementType, unsigned NumRows, + unsigned NumColumns) const; + + /// Return the matrix type of the specified element type whose number of + /// rows and columns is given by the expressions \p RowExpr and \p ColumnExpr. + QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttrLoc) const; + QualType getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1006,6 +1006,17 @@ DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); }) +DEF_TRAVERSE_TYPE(ConstantMatrixType, + { TRY_TO(TraverseType(T->getElementType())); }) + +DEF_TRAVERSE_TYPE(DependentSizedMatrixType, { + if (T->getRowExpr()) + TRY_TO(TraverseStmt(T->getRowExpr())); + if (T->getColumnExpr()) + TRY_TO(TraverseStmt(T->getColumnExpr())); + TRY_TO(TraverseType(T->getElementType())); +}) + DEF_TRAVERSE_TYPE(FunctionNoProtoType, { TRY_TO(TraverseType(T->getReturnType())); }) @@ -1258,6 +1269,18 @@ 
TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); }) +DEF_TRAVERSE_TYPELOC(ConstantMatrixType, { + TRY_TO(TraverseStmt(TL.getAttrRowOperand())); + TRY_TO(TraverseStmt(TL.getAttrColumnOperand())); + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + +DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, { + TRY_TO(TraverseStmt(TL.getAttrRowOperand())); + TRY_TO(TraverseStmt(TL.getAttrColumnOperand())); + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + DEF_TRAVERSE_TYPELOC(FunctionNoProtoType, { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); }) diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -1654,6 +1654,19 @@ uint32_t NumElements; }; + class ConstantMatrixTypeBitfields { + friend class ConstantMatrixType; + + unsigned : NumTypeBits; + + /// Number of rows and columns. Using 20 bits allows supporting very large + /// matrices, while keeping 24 bits to accommodate NumTypeBits. + unsigned NumRows : 20; + unsigned NumColumns : 20; + + static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1; + }; + class AttributedTypeBitfields { friend class AttributedType; @@ -1763,6 +1776,7 @@ TypeWithKeywordBitfields TypeWithKeywordBits; ElaboratedTypeBitfields ElaboratedTypeBits; VectorTypeBitfields VectorTypeBits; + ConstantMatrixTypeBitfields ConstantMatrixTypeBits; SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits; TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits; DependentTemplateSpecializationTypeBitfields @@ -2021,6 +2035,7 @@ bool isComplexIntegerType() const; // GCC _Complex integer type. bool isVectorType() const; // GCC vector type. bool isExtVectorType() const; // Extended vector type. + bool isConstantMatrixType() const; // Matrix type. 
bool isDependentAddressSpaceType() const; // value-dependent address space qualifier bool isObjCObjectPointerType() const; // pointer to ObjC object bool isObjCRetainableType() const; // ObjC object or block pointer @@ -3390,6 +3405,128 @@ } }; +/// Represents a matrix type, as defined in the Matrix Types clang extensions. +/// __attribute__((matrix_type(rows, columns))), where "rows" specifies +/// number of rows and "columns" specifies the number of columns. +class MatrixType : public Type, public llvm::FoldingSetNode { +protected: + friend class ASTContext; + + /// The element type of the matrix. + QualType ElementType; + + MatrixType(QualType ElementTy, QualType CanonElementTy); + + MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy, + const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr); + +public: + /// Returns type of the elements being stored in the matrix + QualType getElementType() const { return ElementType; } + + /// Valid elements types are the following: + /// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types + /// and _Bool + /// * the standard floating types float or double + /// * a half-precision floating point type, if one is supported on the target + static bool isValidElementType(QualType T) { + return T->isDependentType() || + (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType()); + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == ConstantMatrix || + T->getTypeClass() == DependentSizedMatrix; + } +}; + +/// Represents a concrete matrix type with constant number of rows and columns +class ConstantMatrixType final : public MatrixType { +protected: + friend class ASTContext; + 
+ ConstantMatrixType(QualType MatrixElementType, unsigned NRows, + unsigned NColumns, QualType CanonElementType); + + ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows, + unsigned NColumns, QualType CanonElementType); + +public: + /// Returns the number of rows in the matrix. + unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; } + + /// Returns the number of columns in the matrix. + unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; } + + /// Returns the number of elements required to embed the matrix into a vector. + unsigned getNumElementsFlattened() const { + return ConstantMatrixTypeBits.NumRows * ConstantMatrixTypeBits.NumColumns; + } + + /// Returns true if \p NumElements is a valid matrix dimension. + static bool isDimensionValid(uint64_t NumElements) { + return NumElements > 0 && + NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, getElementType(), getNumRows(), getNumColumns(), + getTypeClass()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType, + unsigned NumRows, unsigned NumColumns, + TypeClass TypeClass) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + ID.AddInteger(NumRows); + ID.AddInteger(NumColumns); + ID.AddInteger(TypeClass); + } + + static bool classof(const Type *T) { + return T->getTypeClass() == ConstantMatrix; + } +}; + +/// Represents a matrix type where the type and the number of rows and columns +/// is dependent on a template. 
+class DependentSizedMatrixType final : public MatrixType { + friend class ASTContext; + + const ASTContext &Context; + Expr *RowExpr; + Expr *ColumnExpr; + + SourceLocation loc; + + DependentSizedMatrixType(const ASTContext &Context, QualType ElementType, + QualType CanonicalType, Expr *RowExpr, + Expr *ColumnExpr, SourceLocation loc); + +public: + QualType getElementType() const { return ElementType; } + Expr *getRowExpr() const { return RowExpr; } + Expr *getColumnExpr() const { return ColumnExpr; } + SourceLocation getAttributeLoc() const { return loc; } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == DependentSizedMatrix; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context, + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr); +}; + /// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base /// class of FunctionNoProtoType and FunctionProtoType. 
class FunctionType : public Type { @@ -6605,6 +6742,10 @@ return isa(CanonicalType); } +inline bool Type::isConstantMatrixType() const { + return isa(CanonicalType); +} + inline bool Type::isDependentAddressSpaceType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -1774,6 +1774,68 @@ DependentSizedExtVectorType> { }; +struct MatrixTypeLocInfo { + SourceLocation AttrLoc; + SourceRange OperandParens; + Expr *RowOperand; + Expr *ColumnOperand; +}; + +class MatrixTypeLoc : public ConcreteTypeLoc { +public: + /// The location of the attribute name, i.e. + /// float __attribute__((matrix_type(4, 2))) + /// ^~~~~~~~~~~~~~~~~ + SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; } + void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; } + + /// The attribute's row operand, if it has one. + /// float __attribute__((matrix_type(4, 2))) + /// ^ + Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; } + void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; } + + /// The attribute's column operand, if it has one. + /// float __attribute__((matrix_type(4, 2))) + /// ^ + Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; } + void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; } + + /// The location of the parentheses around the operand, if there is + /// an operand. 
+ /// float __attribute__((matrix_type(4, 2))) + /// ^ ^ + SourceRange getAttrOperandParensRange() const { + return getLocalData()->OperandParens; + } + void setAttrOperandParensRange(SourceRange range) { + getLocalData()->OperandParens = range; + } + + SourceRange getLocalSourceRange() const { + SourceRange range(getAttrNameLoc()); + range.setEnd(getAttrOperandParensRange().getEnd()); + return range; + } + + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setAttrNameLoc(loc); + setAttrOperandParensRange(loc); + setAttrRowOperand(nullptr); + setAttrColumnOperand(nullptr); + } +}; + +class ConstantMatrixTypeLoc + : public InheritingConcreteTypeLoc {}; + +class DependentSizedMatrixTypeLoc + : public InheritingConcreteTypeLoc {}; + // FIXME: location of the '_Complex' keyword. class ComplexTypeLoc : public InheritingConcreteTypeLoc; } +let Class = MatrixType in { + def : Property<"elementType", QualType> { + let Read = [{ node->getElementType() }]; + } +} + +let Class = ConstantMatrixType in { + def : Property<"numRows", UInt32> { + let Read = [{ node->getNumRows() }]; + } + def : Property<"numColumns", UInt32> { + let Read = [{ node->getNumColumns() }]; + } + + def : Creator<[{ + return ctx.getConstantMatrixType(elementType, numRows, numColumns); + }]>; +} + +let Class = DependentSizedMatrixType in { + def : Property<"rows", ExprRef> { + let Read = [{ node->getRowExpr() }]; + } + def : Property<"columns", ExprRef> { + let Read = [{ node->getColumnExpr() }]; + } + def : Property<"attributeLoc", SourceLocation> { + let Read = [{ node->getAttributeLoc() }]; + } + + def : Creator<[{ + return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc); + }]>; +} + let Class = FunctionType in { def : Property<"returnType", QualType> { let Read = [{ node->getReturnType() }]; diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ 
-2460,6 +2460,15 @@ let Documentation = [Undocumented]; } +def MatrixType : TypeAttr { + let Spellings = [Clang<"matrix_type">]; + let Subjects = SubjectList<[TypedefName], ErrorDiag>; + let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">]; + let Documentation = [Undocumented]; + let ASTNode = 0; + let PragmaAttributeSupport = 0; +} + def Visibility : InheritableAttr { let Clone = 0; let Spellings = [GCC<"visibility">]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -2774,6 +2774,7 @@ def err_attribute_too_few_arguments : Error< "%0 attribute takes at least %1 argument%s1">; def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">; +def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">; def err_attribute_bad_neon_vector_size : Error< "Neon vector size must be 64 or 128 bits">; def err_attribute_requires_positive_integer : Error< @@ -2877,8 +2878,8 @@ "init methods must return an object pointer type, not %0">; def err_attribute_invalid_size : Error< "vector size not an integral multiple of component size">; -def err_attribute_zero_size : Error<"zero vector size">; -def err_attribute_size_too_large : Error<"vector size too large">; +def err_attribute_zero_size : Error<"zero %0 size">; +def err_attribute_size_too_large : Error<"%0 size too large">; def err_typecheck_vector_not_convertable_implict_truncation : Error< "cannot convert between %select{scalar|vector}0 type %1 and vector type" " %2 as implicit conversion would cause truncation">; @@ -10741,6 +10742,9 @@ "%select{non-pointer|function pointer|void pointer}0 argument to " "'__builtin_launder' is not allowed">; +def err_builtin_matrix_disabled: Error< + "matrix types extension is disabled. 
Pass -fenable-matrix to enable it">; + def err_preserve_field_info_not_field : Error< "__builtin_preserve_field_info argument %0 not a field access">; def err_preserve_field_info_not_const: Error< diff --git a/clang/include/clang/Basic/Features.def b/clang/include/clang/Basic/Features.def --- a/clang/include/clang/Basic/Features.def +++ b/clang/include/clang/Basic/Features.def @@ -253,6 +253,7 @@ EXTENSION(pragma_clang_attribute_external_declaration, true) EXTENSION(gnu_asm, LangOpts.GNUAsm) EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm) +EXTENSION(matrix_types, LangOpts.MatrixTypes) #undef EXTENSION #undef FEATURE diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def --- a/clang/include/clang/Basic/LangOptions.def +++ b/clang/include/clang/Basic/LangOptions.def @@ -358,6 +358,8 @@ LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors") +LANGOPT(MatrixTypes, 1, 0, "Enable or disable the builtin matrix type") + COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0") ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None, diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td --- a/clang/include/clang/Basic/TypeNodes.td +++ b/clang/include/clang/Basic/TypeNodes.td @@ -69,6 +69,9 @@ def VectorType : TypeNode; def DependentVectorType : TypeNode, AlwaysDependent; def ExtVectorType : TypeNode; +def MatrixType : TypeNode; +def ConstantMatrixType : TypeNode; +def DependentSizedMatrixType : TypeNode, AlwaysDependent; def FunctionType : TypeNode; def FunctionProtoType : TypeNode; def FunctionNoProtoType : TypeNode; diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -2014,6 +2014,10 @@ def fno_strict_return : Flag<["-"], "fno-strict-return">, Group, Flags<[CC1Option]>; 
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group, + Flags<[CC1Option]>, + HelpText<"Enable matrix data type and related builtin functions">; + def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">, Group, Flags<[CC1Option]>, HelpText<"Treat editor placeholders as valid source code">; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -1627,6 +1627,9 @@ QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc); QualType BuildExtVectorType(QualType T, Expr *ArraySize, SourceLocation AttrLoc); + QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns, + SourceLocation AttrLoc); + QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace, SourceLocation AttrLoc); diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def --- a/clang/include/clang/Serialization/TypeBitCodes.def +++ b/clang/include/clang/Serialization/TypeBitCodes.def @@ -60,5 +60,7 @@ TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49) TYPE_BIT_CODE(ExtInt, EXT_INT, 50) TYPE_BIT_CODE(DependentExtInt, DEPENDENT_EXT_INT, 51) +TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52) +TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53) #undef TYPE_BIT_CODE diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -1932,6 +1932,17 @@ break; } + case Type::ConstantMatrix: { + const auto *MT = cast(T); + TypeInfo ElementInfo = getTypeInfo(MT->getElementType()); + // The internal layout of a matrix value is implementation defined. + // Initially be ABI compatible with arrays with respect to alignment and + // size. 
+ Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns(); + Align = ElementInfo.Align; + break; + } + case Type::Builtin: switch (cast(T)->getKind()) { default: llvm_unreachable("Unknown builtin type!"); @@ -3362,6 +3373,8 @@ case Type::DependentVector: case Type::ExtVector: case Type::DependentSizedExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::ObjCObject: case Type::ObjCInterface: @@ -3775,6 +3788,78 @@ return QualType(New, 0); } +QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows, + unsigned NumColumns) const { + llvm::FoldingSetNodeID ID; + ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns, + Type::ConstantMatrix); + + assert(MatrixType::isValidElementType(ElementTy) && + "need a valid element type"); + assert(ConstantMatrixType::isDimensionValid(NumRows) && + ConstantMatrixType::isDimensionValid(NumColumns) && + "need valid matrix dimensions"); + void *InsertPos = nullptr; + if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) + return QualType(MTP, 0); + + QualType Canonical; + if (!ElementTy.isCanonical()) { + Canonical = + getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns); + + ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!NewIP && "Matrix type shouldn't already exist in the map"); + (void)NewIP; + } + + auto *New = new (*this, TypeAlignment) + ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical); + MatrixTypes.InsertNode(New, InsertPos); + Types.push_back(New); + return QualType(New, 0); +} + +QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy, + Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttrLoc) const { + QualType CanonElementTy = getCanonicalType(ElementTy); + llvm::FoldingSetNodeID ID; + DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr, + ColumnExpr); + + void *InsertPos = nullptr; + 
DependentSizedMatrixType *Canon = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + + if (!Canon) { + Canon = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, CanonElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc); +#ifndef NDEBUG + DependentSizedMatrixType *CanonCheck = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!CanonCheck && "Dependent-sized matrix canonical type broken"); +#endif + DependentSizedMatrixTypes.InsertNode(Canon, InsertPos); + Types.push_back(Canon); + } + + // Already have a canonical version of the matrix type. If it exactly + // matches the requested type (same element type, row expression and + // column expression), use it directly. + if (Canon->getElementType() == ElementTy && Canon->getRowExpr() == RowExpr && + Canon->getColumnExpr() == ColumnExpr) + return QualType(Canon, 0); + + // Use Canon as the canonical type for newly-built type. + DependentSizedMatrixType *New = new (*this, TypeAlignment) + DependentSizedMatrixType(*this, ElementTy, QualType(Canon, 0), RowExpr, + ColumnExpr, AttrLoc); + Types.push_back(New); + return QualType(New, 0); +} + QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const { @@ -7338,6 +7423,11 @@ *NotEncodedT = T; return; + case Type::ConstantMatrix: + if (NotEncodedT) + *NotEncodedT = T; + return; + // We could see an undeduced auto type here during error recovery. // Just ignore it. case Type::Auto: @@ -8217,6 +8307,16 @@ LHS->getNumElements() == RHS->getNumElements(); } +/// areCompatMatrixTypes - Return true if the two specified matrix types are +/// compatible. 
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS, + const ConstantMatrixType *RHS) { + assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified()); + return LHS->getElementType() == RHS->getElementType() && + LHS->getNumRows() == RHS->getNumRows() && + LHS->getNumColumns() == RHS->getNumColumns(); +} + bool ASTContext::areCompatibleVectorTypes(QualType FirstVec, QualType SecondVec) { assert(FirstVec->isVectorType() && "FirstVec should be a vector type"); @@ -9414,6 +9514,11 @@ RHSCan->castAs())) return LHS; return {}; + case Type::ConstantMatrix: + if (areCompatMatrixTypes(LHSCan->castAs(), + RHSCan->castAs())) + return LHS; + return {}; case Type::ObjCObject: { // Check if the types are assignment compatible. // FIXME: This should be type compatibility, e.g. whether diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -617,6 +617,34 @@ break; } + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *Mat1 = cast(T1); + const DependentSizedMatrixType *Mat2 = cast(T2); + // The element types, row and column expressions must be structurally + // equivalent. + if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(), + Mat2->getRowExpr()) || + !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(), + Mat2->getColumnExpr()) || + !IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType())) + return false; + break; + } + + case Type::ConstantMatrix: { + const ConstantMatrixType *Mat1 = cast(T1); + const ConstantMatrixType *Mat2 = cast(T2); + // The element types must be structurally equivalent and the number of rows + // and columns must match. 
+ if (!IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType()) || + Mat1->getNumRows() != Mat2->getNumRows() || + Mat1->getNumColumns() != Mat2->getNumColumns()) + return false; + break; + } + case Type::FunctionProto: { const auto *Proto1 = cast(T1); const auto *Proto2 = cast(T2); diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -10350,6 +10350,7 @@ case Type::BlockPointer: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::ObjCObject: case Type::ObjCInterface: case Type::ObjCObjectPointer: diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2079,6 +2079,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -3343,6 +3345,31 @@ mangleType(T->getElementType()); } +void CXXNameMangler::mangleType(const ConstantMatrixType *T) { + // Mangle matrix types using a vendor extended type qualifier: + // Umatrix_type + StringRef VendorQualifier = "matrix_type"; + Out << "U" << VendorQualifier.size() << VendorQualifier; + auto &ASTCtx = getASTContext(); + unsigned BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType()); + llvm::APSInt Rows(BitWidth); + Rows = T->getNumRows(); + mangleIntegerLiteral(ASTCtx.getSizeType(), Rows); + llvm::APSInt Columns(BitWidth); + Columns = T->getNumColumns(); + mangleIntegerLiteral(ASTCtx.getSizeType(), Columns); + mangleType(T->getElementType()); +} + +void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) { + // Umatrix_type + StringRef VendorQualifier = "matrix_type"; + Out << "U" << VendorQualifier.size() << VendorQualifier; + mangleTemplateArg(T->getRowExpr()); + mangleTemplateArg(T->getColumnExpr()); + 
mangleType(T->getElementType()); +} + void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) { SplitQualType split = T->getPointeeType().split(); mangleQualifiers(split.Quals, T); diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -2730,6 +2730,23 @@ << Range; } +void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T, + Qualifiers quals, SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, + "Cannot mangle this matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + +void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T, + Qualifiers quals, SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "Cannot mangle this dependent-sized matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T, Qualifiers, SourceRange Range) { DiagnosticsEngine &Diags = Context.getDiags(); diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -282,6 +282,53 @@ AddrSpaceExpr->Profile(ID, Context, true); } +MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType, + const Expr *RowExpr, const Expr *ColumnExpr) + : Type(tc, canonType, + (RowExpr ? (matrixType->getDependence() | TypeDependence::Dependent | + TypeDependence::Instantiation | + (matrixType->isVariablyModifiedType() + ? TypeDependence::VariablyModified + : TypeDependence::None) | + (matrixType->containsUnexpandedParameterPack() || + (RowExpr && + RowExpr->containsUnexpandedParameterPack()) || + (ColumnExpr && + ColumnExpr->containsUnexpandedParameterPack()) + ? 
TypeDependence::UnexpandedPack + : TypeDependence::None)) + : matrixType->getDependence())), + ElementType(matrixType) {} + +ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows, + unsigned nColumns, QualType canonType) + : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns, + canonType) {} + +ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType, + unsigned nRows, unsigned nColumns, + QualType canonType) + : MatrixType(tc, matrixType, canonType) { + ConstantMatrixTypeBits.NumRows = nRows; + ConstantMatrixTypeBits.NumColumns = nColumns; +} + +DependentSizedMatrixType::DependentSizedMatrixType( + const ASTContext &CTX, QualType ElementType, QualType CanonicalType, + Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc) + : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr, + ColumnExpr), + Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {} + +void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID, + const ASTContext &CTX, + QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + RowExpr->Profile(ID, CTX, true); + ColumnExpr->Profile(ID, CTX, true); +} + VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType, VectorKind vecKind) : VectorType(Vector, vecType, nElements, canonType, vecKind) {} @@ -971,6 +1018,17 @@ return Ctx.getExtVectorType(elementType, T->getNumElements()); } + QualType VisitConstantMatrixType(const ConstantMatrixType *T) { + QualType elementType = recurse(T->getElementType()); + if (elementType.isNull()) + return {}; + if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr()) + return QualType(T, 0); + + return Ctx.getConstantMatrixType(elementType, T->getNumRows(), + T->getNumColumns()); + } + QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) { QualType returnType = recurse(T->getReturnType()); if (returnType.isNull()) @@ -1790,6 +1848,14 @@ 
return Visit(T->getElementType()); } + Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); + } + + Type *VisitConstantMatrixType(const ConstantMatrixType *T) { + return Visit(T->getElementType()); + } + Type *VisitFunctionProtoType(const FunctionProtoType *T) { if (Syntactic && T->hasTrailingReturn()) return const_cast(T); @@ -3744,6 +3810,8 @@ case Type::Vector: case Type::ExtVector: return Cache::get(cast(T)->getElementType()); + case Type::ConstantMatrix: + return Cache::get(cast(T)->getElementType()); case Type::FunctionNoProto: return Cache::get(cast(T)->getReturnType()); case Type::FunctionProto: { @@ -3830,6 +3898,9 @@ case Type::Vector: case Type::ExtVector: return computeTypeLinkageInfo(cast(T)->getElementType()); + case Type::ConstantMatrix: + return computeTypeLinkageInfo( + cast(T)->getElementType()); case Type::FunctionNoProto: return computeTypeLinkageInfo(cast(T)->getReturnType()); case Type::FunctionProto: { @@ -3993,6 +4064,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::FunctionProto: case Type::FunctionNoProto: diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -256,6 +256,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -720,6 +722,38 @@ OS << ")))"; } +void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T, + raw_ostream &OS) { + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + OS << T->getNumRows() << ", " << T->getNumColumns(); + OS << ")))"; +} + +void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T, + raw_ostream 
&OS) { + printAfter(T->getElementType(), OS); +} + +void TypePrinter::printDependentSizedMatrixBefore( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + if (T->getRowExpr()) { + T->getRowExpr()->printPretty(OS, nullptr, Policy); + } + OS << ", "; + if (T->getColumnExpr()) { + T->getColumnExpr()->printPretty(OS, nullptr, Policy); + } + OS << ")))"; +} + +void TypePrinter::printDependentSizedMatrixAfter( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printAfter(T->getElementType(), OS); +} + void FunctionProtoType::printExceptionSpecification(raw_ostream &OS, const PrintingPolicy &Policy) diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h --- a/clang/lib/CodeGen/CGDebugInfo.h +++ b/clang/lib/CodeGen/CGDebugInfo.h @@ -192,6 +192,7 @@ llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit); llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -2736,6 +2736,23 @@ return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray); } +llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty, + llvm::DIFile *Unit) { + // FIXME: Create another debug type for matrices + // For the time being, it treats it like a nested ArrayType. 
+ + llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit); + uint64_t Size = CGM.getContext().getTypeSize(Ty); + uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext()); + + // Create ranges for both dimensions. + llvm::SmallVector Subscripts; + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns())); + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows())); + llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts); + return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray); +} + llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) { uint64_t Size; uint32_t Align; @@ -3129,6 +3146,8 @@ case Type::ExtVector: case Type::Vector: return CreateType(cast(Ty), Unit); + case Type::ConstantMatrix: + return CreateType(cast(Ty), Unit); case Type::ObjCObjectPointer: return CreateType(cast(Ty), Unit); case Type::ObjCObject: diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -145,8 +145,19 @@ Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align, const Twine &Name, Address *Alloca) { - return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, - /*ArraySize=*/nullptr, Alloca); + Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, + /*ArraySize=*/nullptr, Alloca); + + if (Ty->isConstantMatrixType()) { + auto *ArrayTy = cast(Result.getType()->getElementType()); + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + Result = Address( + Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()), + Result.getAlignment()); + } + return Result; } Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align, @@ -1732,6 +1743,42 @@ return Value; } +// Convert the pointer of \p Addr to a pointer to a vector (the value type of +// MatrixType), if it points to a array (the 
memory type of MatrixType). +static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF, + bool IsVector = true) { + auto *ArrayTy = dyn_cast( + cast(Addr.getPointer()->getType())->getElementType()); + if (ArrayTy && IsVector) { + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy)); + } + auto *VectorTy = dyn_cast( + cast(Addr.getPointer()->getType())->getElementType()); + if (VectorTy && !IsVector) { + auto *ArrayTy = llvm::ArrayType::get(VectorTy->getElementType(), + VectorTy->getNumElements()); + + return Address(CGF.Builder.CreateElementBitCast(Addr, ArrayTy)); + } + + return Addr; +} + +// Emit a store of a matrix LValue. This may require casting the original +// pointer to memory address (ArrayType) to a pointer to the value type +// (VectorType). +static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue, + bool isInit, CodeGenFunction &CGF) { + Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF, + value->getType()->isVectorTy()); + CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(), + lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit, + lvalue.isNontemporal()); +} + void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr, bool Volatile, QualType Ty, LValueBaseInfo BaseInfo, @@ -1779,11 +1826,26 @@ void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue, bool isInit) { + if (lvalue.getType()->isConstantMatrixType()) { + EmitStoreOfMatrixScalar(value, lvalue, isInit, *this); + return; + } + EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(), lvalue.getType(), lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal()); } +// Emit a load of a LValue of matrix type. This may require casting the pointer +// to memory address (ArrayType) to a pointer to the value type (VectorType). 
+static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc, + CodeGenFunction &CGF) { + assert(LV.getType()->isConstantMatrixType()); + Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF); + LV.setAddress(Addr); + return RValue::get(CGF.EmitLoadOfScalar(LV, Loc)); +} + /// EmitLoadOfLValue - Given an expression that represents a value lvalue, this /// method emits the address of the lvalue, then loads the result as an rvalue, /// returning the rvalue. @@ -1809,6 +1871,9 @@ if (LV.isSimple()) { assert(!LV.getType()->isFunctionType()); + if (LV.getType()->isConstantMatrixType()) + return EmitLoadOfMatrixLValue(LV, Loc, *this); + // Everything needs a load. return RValue::get(EmitLoadOfScalar(LV, Loc)); } diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -247,6 +247,7 @@ case Type::MemberPointer: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Enum: @@ -2000,6 +2001,7 @@ case Type::Complex: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Record: case Type::Enum: case Type::Elaborated: diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -82,6 +82,13 @@ /// a type. For example, the scalar representation for _Bool is i1, but the /// memory representation is usually i8 or i32, depending on the target. 
llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) { + if (T->isConstantMatrixType()) { + const Type *Ty = Context.getCanonicalType(T).getTypePtr(); + const ConstantMatrixType *MT = cast(Ty); + return llvm::ArrayType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + } + llvm::Type *R = ConvertType(T); // If this is a bool type, or an ExtIntType in a bitfield representation, @@ -646,6 +653,12 @@ VT->getNumElements()); break; } + case Type::ConstantMatrix: { + const ConstantMatrixType *MT = cast(Ty); + ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + break; + } case Type::FunctionNoProto: case Type::FunctionProto: ResultType = ConvertFunctionTypeInternal(T); diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3223,6 +3223,7 @@ // GCC treats vector and complex types as fundamental types. case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::Atomic: // FIXME: GCC treats block pointers as fundamental types?! @@ -3458,6 +3459,7 @@ case Type::Builtin: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::BlockPointer: // Itanium C++ ABI 2.9.5p4: diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -4566,6 +4566,13 @@ if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false)) CmdArgs.push_back("-fdefault-calling-conv=stdcall"); + if (Args.hasArg(options::OPT_fenable_matrix)) { + // enable-matrix is needed by both the LangOpts and by LLVM. 
+ CmdArgs.push_back("-fenable-matrix"); + CmdArgs.push_back("-mllvm"); + CmdArgs.push_back("-enable-matrix"); + } + CodeGenOptions::FramePointerKind FPKeepKind = getFramePointerKind(Args, RawTriple); const char *FPKeepKindStr = nullptr; diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -3337,6 +3337,8 @@ Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers); Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj); + Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix); + Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags); if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) { diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4257,6 +4257,7 @@ case Type::Complex: case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Record: case Type::Enum: case Type::Elaborated: diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -2966,6 +2966,7 @@ // These are fundamental types. 
case Type::Vector: case Type::ExtVector: + case Type::ConstantMatrix: case Type::Complex: case Type::ExtInt: break; diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp old mode 100755 new mode 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -5867,6 +5867,11 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType( + const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType( const DependentAddressSpaceType *T) { return Visit(T->getPointeeType()); @@ -5885,6 +5890,11 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType( + const ConstantMatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType( const FunctionProtoType* T) { for (const auto &A : T->param_types()) { diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp --- a/clang/lib/Sema/SemaTemplateDeduction.cpp +++ b/clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2057,6 +2057,101 @@ // (clang extension) // + // T __attribute__((matrix_type(, + // ))) + case Type::ConstantMatrix: { + const ConstantMatrixType *MatrixArg = dyn_cast(Arg); + if (!MatrixArg) + return Sema::TDK_NonDeducedMismatch; + + const ConstantMatrixType *MatrixParam = cast(Param); + // Check that the dimensions are the same + if (MatrixParam->getNumRows() != MatrixArg->getNumRows() || + MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) { + return Sema::TDK_NonDeducedMismatch; + } + // Perform deduction on element types. 
+ return DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF); + } + + case Type::DependentSizedMatrix: { + const MatrixType *MatrixArg = dyn_cast(Arg); + if (!MatrixArg) + return Sema::TDK_NonDeducedMismatch; + + // Check the element type of the matrixes. + const DependentSizedMatrixType *MatrixParam = + cast(Param); + if (Sema::TemplateDeductionResult Result = + DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF)) + return Result; + + // Try to deduce a matrix dimension. + auto DeduceMatrixArg = + [&S, &Info, &Deduced, &TemplateParams]( + Expr *ParamExpr, const MatrixType *Arg, + unsigned (ConstantMatrixType::*GetArgDimension)() const, + Expr *(DependentSizedMatrixType::*GetArgDimensionExpr)() const) { + const auto *ArgConstMatrix = dyn_cast(Arg); + const auto *ArgDepMatrix = dyn_cast(Arg); + if (!ParamExpr->isValueDependent()) { + llvm::APSInt ParamConst( + S.Context.getTypeSize(S.Context.getSizeType())); + if (!ParamExpr->isIntegerConstantExpr(ParamConst, S.Context)) + return Sema::TDK_NonDeducedMismatch; + + if (ArgConstMatrix) { + if ((ArgConstMatrix->*GetArgDimension)() == ParamConst) + return Sema::TDK_Success; + return Sema::TDK_NonDeducedMismatch; + } + + Expr *ArgExpr = (ArgDepMatrix->*GetArgDimensionExpr)(); + llvm::APSInt ArgConst( + S.Context.getTypeSize(S.Context.getSizeType())); + if (!ArgExpr->isValueDependent() && + ArgExpr->isIntegerConstantExpr(ArgConst, S.Context) && + ArgConst == ParamConst) + return Sema::TDK_Success; + return Sema::TDK_NonDeducedMismatch; + } + + NonTypeTemplateParmDecl *NTTP = + getDeducedParameterFromExpr(Info, ParamExpr); + if (!NTTP) + return Sema::TDK_Success; + + if (ArgConstMatrix) { + llvm::APSInt ArgConst( + S.Context.getTypeSize(S.Context.getSizeType())); + ArgConst = (ArgConstMatrix->*GetArgDimension)(); + return DeduceNonTypeTemplateArgument( 
+ S, TemplateParams, NTTP, ArgConst, S.Context.getSizeType(), + /*ArrayBound=*/true, Info, Deduced); + } + + return DeduceNonTypeTemplateArgument( + S, TemplateParams, NTTP, (ArgDepMatrix->*GetArgDimensionExpr)(), + Info, Deduced); + }; + + auto Result = DeduceMatrixArg(MatrixParam->getRowExpr(), MatrixArg, + &ConstantMatrixType::getNumRows, + &DependentSizedMatrixType::getRowExpr); + if (Result) + return Result; + + return DeduceMatrixArg(MatrixParam->getColumnExpr(), MatrixArg, + &ConstantMatrixType::getNumColumns, + &DependentSizedMatrixType::getColumnExpr); + } + + // (clang extension) + // // T __attribute__(((address_space(N)))) case Type::DependentAddressSpace: { const DependentAddressSpaceType *AddressSpaceParam = @@ -5723,6 +5818,24 @@ break; } + case Type::ConstantMatrix: { + const ConstantMatrixType *MatType = cast(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth, Used); + break; + } + + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *MatType = cast(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth, Used); + MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth, + Used); + MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced, + Depth, Used); + break; + } + case Type::FunctionProto: { const FunctionProtoType *Proto = cast(T); MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth, diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -2492,14 +2492,15 @@ if (!VecSize.isIntN(61)) { // Bit size will overflow uint64. 
Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
   uint64_t VectorSizeBits = VecSize.getZExtValue() * 8;
   unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
 
   if (VectorSizeBits == 0) {
-    Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2511,7 +2512,7 @@
 
   if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) {
     Diag(AttrLoc, diag::err_attribute_size_too_large)
-        << SizeExpr->getSourceRange();
+        << SizeExpr->getSourceRange() << "vector";
     return QualType();
   }
 
@@ -2549,7 +2550,7 @@
 
     if (!vecSize.isIntN(32)) {
       Diag(AttrLoc, diag::err_attribute_size_too_large)
-          << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
     // Unlike gcc's vector_size attribute, the size is specified as the
@@ -2558,7 +2559,7 @@
 
     if (vectorSize == 0) {
       Diag(AttrLoc, diag::err_attribute_zero_size)
-          << ArraySize->getSourceRange();
+          << ArraySize->getSourceRange() << "vector";
       return QualType();
     }
 
@@ -2568,6 +2569,84 @@
   return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
 }
 
+QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
+                               SourceLocation AttrLoc) {
+  assert(Context.getLangOpts().MatrixTypes &&
+         "Should never build a matrix type when it is disabled");
+
+  // Check a non-dependent element type first, even for dependent dimensions.
+  if (!ElementTy->isDependentType() &&
+      !MatrixType::isValidElementType(ElementTy)) {
+    Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
+    return QualType();
+  }
+
+  if (NumRows->isTypeDependent() || NumCols->isTypeDependent() ||
+      NumRows->isValueDependent() || NumCols->isValueDependent())
+    return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
+                                               AttrLoc);
+
+  // Both row and column values can only be 20 bit wide currently.
+  llvm::APSInt ValueRows(32), ValueColumns(32);
+
+  bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+  bool const ColumnsIsInteger =
+      NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+  auto const RowRange = NumRows->getSourceRange();
+  auto const ColRange = NumCols->getSourceRange();
+
+  // Both the row and column expressions are invalid.
+  if (!RowsIsInteger && !ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+        << ColRange;
+    return QualType();
+  }
+
+  // Only the row expression is invalid.
+  if (!RowsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+    return QualType();
+  }
+
+  // Only the column expression is invalid.
+  if (!ColumnsIsInteger) {
+    Diag(AttrLoc, diag::err_attribute_argument_type)
+        << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+    return QualType();
+  }
+
+  // Check the matrix dimensions. Saturate so that wide values cannot wrap.
+  unsigned MatrixRows = ValueRows.getLimitedValue(~0U);
+  unsigned MatrixColumns = ValueColumns.getLimitedValue(~0U);
+  if (MatrixRows == 0 && MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size)
+        << "matrix" << RowRange << ColRange;
+    return QualType();
+  }
+  if (MatrixRows == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+    return QualType();
+  }
+  if (MatrixColumns == 0) {
+    Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << RowRange << "matrix row";
+    return QualType();
+  }
+  if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
+    Diag(AttrLoc, diag::err_attribute_size_too_large)
+        << ColRange << "matrix column";
+    return QualType();
+  }
+  return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns);
+}
+
 bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
   if (T->isArrayType() || T->isFunctionType()) {
     Diag(Loc, diag::err_func_returning_array_function)
@@ -6013,6 +6092,21 @@
       "no address_space attribute found at the expected location!");
 }
 
+static void fillMatrixTypeLoc(MatrixTypeLoc MTL,
+                              const ParsedAttributesView &Attrs) {
+  for (const ParsedAttr &AL : Attrs) {
+    if (AL.getKind() == ParsedAttr::AT_MatrixType) {
+      MTL.setAttrNameLoc(AL.getLoc());
+      MTL.setAttrRowOperand(AL.getArgAsExpr(0));
+      MTL.setAttrColumnOperand(AL.getArgAsExpr(1));
+      MTL.setAttrOperandParensRange(SourceRange());
+      return;
+    }
+  }
+
+  llvm_unreachable("no matrix_type attribute found at the expected location!");
+}
+
 /// Create and instantiate a TypeSourceInfo with type source information.
 ///
 /// \param T QualType referring to the type as written in source code.
@@ -6061,6 +6155,9 @@
       CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc();
     }
 
+    if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>())
+      fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs());
+
     // FIXME: Ordering here?
     while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>())
       CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc();
@@ -7706,6 +7803,74 @@
   }
 }
 
+/// HandleMatrixTypeAttr - "matrix_type" attribute, like ext_vector_type
+///
+/// Diagnoses use of the attribute when matrix types are disabled or when it
+/// does not have exactly two arguments. Each argument may be written either
+/// as an expression or as a bare identifier; identifiers are resolved via
+/// ActOnIdExpression. On success \p CurType is replaced by the matrix type
+/// built by Sema::BuildMatrixType; on any failure it is left unchanged.
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+                                 Sema &S) {
+  if (!S.getLangOpts().MatrixTypes) {
+    S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+    return;
+  }
+
+  if (Attr.getNumArgs() != 2) {
+    S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+        << Attr << 2;
+    return;
+  }
+
+  Expr *RowsExpr = nullptr;
+  Expr *ColsExpr = nullptr;
+
+  // TODO: Refactor parameter extraction into separate function
+  // Get the number of rows
+  if (Attr.isArgIdent(0)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+    ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+                                          TemplateKeywordLoc, id, false, false);
+
+    if (Rows.isInvalid())
+      // TODO: maybe a good error message would be nice here
+      return;
+    RowsExpr = Rows.get();
+  } else {
+    assert(Attr.isArgExpr(0) &&
+           "Argument to should either be an identity or expression");
+    RowsExpr = Attr.getArgAsExpr(0);
+  }
+
+  // Get the number of columns
+  if (Attr.isArgIdent(1)) {
+    CXXScopeSpec SS;
+    SourceLocation TemplateKeywordLoc;
+    UnqualifiedId id;
+    id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+    ExprResult Columns = S.ActOnIdExpression(
+        S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+
+    if (Columns.isInvalid())
+      // TODO: a good error message would be nice here
+      return;
+    ColsExpr = Columns.get();
+  } else {
+    assert(Attr.isArgExpr(1) &&
+           "Argument to should either be an identity or expression");
+    ColsExpr = Attr.getArgAsExpr(1);
+  }
+
+  // Create the matrix type.
+  QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+  if (!T.isNull())
+    CurType = T;
+}
+
 static void HandleLifetimeBoundAttr(TypeProcessingState &State,
                                     QualType &CurType,
                                     ParsedAttr &Attr) {
@@ -7857,6 +8022,11 @@
       break;
     }
 
+    case ParsedAttr::AT_MatrixType:
+      HandleMatrixTypeAttr(type, attr, state.getSema());
+      attr.setUsedAsTypeAttr();
+      break;
+
     MS_TYPE_ATTRS_CASELIST:
       if (!handleMSPointerTypeQualifierAttr(state, attr, type))
         attr.setUsedAsTypeAttr();
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
                                         Expr *SizeExpr,
                                         SourceLocation AttributeLoc);
 
+  /// Build a new matrix type given the element type and dimensions.
+  QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows,
+                                     unsigned NumColumns);
+
+  /// Build a new matrix type given the type and dependently-defined
+  /// dimensions.
+  QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+                                           Expr *ColumnExpr,
+                                           SourceLocation AttributeLoc);
+
   /// Build a new DependentAddressSpaceType or return the pointee
   /// type variable with the correct address space (retrieved from
   /// AddrSpaceExpr) applied to it.
The former will be returned in cases @@ -5180,6 +5190,86 @@ } template +QualType +TreeTransform::TransformConstantMatrixType(TypeLocBuilder &TLB, + ConstantMatrixTypeLoc TL) { + const ConstantMatrixType *T = TL.getTypePtr(); + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) + return QualType(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) { + Result = getDerived().RebuildConstantMatrixType( + ElementType, T->getNumRows(), T->getNumColumns()); + if (Result.isNull()) + return QualType(); + } + + ConstantMatrixTypeLoc NewTL = TLB.push(Result); + NewTL.setAttrNameLoc(TL.getAttrNameLoc()); + NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange()); + NewTL.setAttrRowOperand(TL.getAttrRowOperand()); + NewTL.setAttrColumnOperand(TL.getAttrColumnOperand()); + + return Result; +} + +template +QualType TreeTransform::TransformDependentSizedMatrixType( + TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) { + const DependentSizedMatrixType *T = TL.getTypePtr(); + + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) { + return QualType(); + } + + // Matrix dimensions are constant expressions. 
+ EnterExpressionEvaluationContext Unevaluated( + SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated); + + Expr *origRows = TL.getAttrRowOperand(); + if (!origRows) + origRows = T->getRowExpr(); + Expr *origColumns = TL.getAttrColumnOperand(); + if (!origColumns) + origColumns = T->getColumnExpr(); + + ExprResult rowResult = getDerived().TransformExpr(origRows); + rowResult = SemaRef.ActOnConstantExpression(rowResult); + if (rowResult.isInvalid()) + return QualType(); + + ExprResult columnResult = getDerived().TransformExpr(origColumns); + columnResult = SemaRef.ActOnConstantExpression(columnResult); + if (columnResult.isInvalid()) + return QualType(); + + Expr *rows = rowResult.get(); + Expr *columns = columnResult.get(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() || + rows != origRows || columns != origColumns) { + Result = getDerived().RebuildDependentSizedMatrixType( + ElementType, rows, columns, T->getAttributeLoc()); + + if (Result.isNull()) + return QualType(); + } + + // We might have any sort of matrix type now, but fortunately they + // all have the same location layout. 
+ MatrixTypeLoc NewTL = TLB.push(Result); + NewTL.setAttrNameLoc(TL.getAttrNameLoc()); + NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange()); + NewTL.setAttrRowOperand(rows); + NewTL.setAttrColumnOperand(columns); + return Result; +} + +template QualType TreeTransform::TransformDependentAddressSpaceType( TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) { const DependentAddressSpaceType *T = TL.getTypePtr(); @@ -13750,6 +13840,21 @@ return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc); } +template +QualType TreeTransform::RebuildConstantMatrixType( + QualType ElementType, unsigned NumRows, unsigned NumColumns) { + return SemaRef.Context.getConstantMatrixType(ElementType, NumRows, + NumColumns); +} + +template +QualType TreeTransform::RebuildDependentSizedMatrixType( + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr, + SourceLocation AttributeLoc) { + return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr, + AttributeLoc); +} + template QualType TreeTransform::RebuildFunctionProtoType( QualType T, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -6554,6 +6554,21 @@ TL.setNameLoc(readSourceLocation()); } +void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) { + TL.setAttrNameLoc(readSourceLocation()); + TL.setAttrOperandParensRange(Reader.readSourceRange()); + TL.setAttrRowOperand(Reader.readExpr()); + TL.setAttrColumnOperand(Reader.readExpr()); +} + +void TypeLocReader::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + TL.setAttrNameLoc(readSourceLocation()); + TL.setAttrOperandParensRange(Reader.readSourceRange()); + TL.setAttrRowOperand(Reader.readExpr()); + TL.setAttrColumnOperand(Reader.readExpr()); +} + void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) { TL.setLocalRangeBegin(readSourceLocation()); 
TL.setLParenLoc(readSourceLocation()); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -288,6 +288,25 @@ Record.AddSourceLocation(TL.getNameLoc()); } +void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getAttrNameLoc()); + SourceRange range = TL.getAttrOperandParensRange(); + Record.AddSourceLocation(range.getBegin()); + Record.AddSourceLocation(range.getEnd()); + Record.AddStmt(TL.getAttrRowOperand()); + Record.AddStmt(TL.getAttrColumnOperand()); +} + +void TypeLocWriter::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getAttrNameLoc()); + SourceRange range = TL.getAttrOperandParensRange(); + Record.AddSourceLocation(range.getBegin()); + Record.AddSourceLocation(range.getEnd()); + Record.AddStmt(TL.getAttrRowOperand()); + Record.AddStmt(TL.getAttrColumnOperand()); +} + void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) { Record.AddSourceLocation(TL.getLocalRangeBegin()); Record.AddSourceLocation(TL.getLParenLoc()); diff --git a/clang/test/CodeGen/debug-info-matrix-types.c b/clang/test/CodeGen/debug-info-matrix-types.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/debug-info-matrix-types.c @@ -0,0 +1,19 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -debug-info-kind=limited -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx2x3_t __attribute__((matrix_type(2, 3))); + +void load_store_double(dx2x3_t *a, dx2x3_t *b) { + // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %a.addr, metadata [[EXPR_A:![0-9]+]] + // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %b.addr, metadata [[EXPR_B:![0-9]+]] + // CHECK: [[PTR_TY:![0-9]+]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: [[TYPEDEF:![0-9]+]], size: 64) + // CHECK: [[TYPEDEF]] = !DIDerivedType(tag: 
DW_TAG_typedef, name: "dx2x3_t", {{.+}} baseType: [[MATRIX_TY:![0-9]+]]) + // CHECK: [[MATRIX_TY]] = !DICompositeType(tag: DW_TAG_array_type, baseType: [[ELT_TY:![0-9]+]], size: 384, elements: [[ELEMENTS:![0-9]+]]) + // CHECK: [[ELT_TY]] = !DIBasicType(name: "double", size: 64, encoding: DW_ATE_float) + // CHECK: [[ELEMENTS]] = !{[[COLS:![0-9]+]], [[ROWS:![0-9]+]]} + // CHECK: [[COLS]] = !DISubrange(count: 3) + // CHECK: [[ROWS]] = !DISubrange(count: 2) + // CHECK: [[EXPR_A]] = !DILocalVariable(name: "a", arg: 1, {{.+}} type: [[PTR_TY]]) + // CHECK: [[EXPR_B]] = !DILocalVariable(name: "b", arg: 2, {{.+}} type: [[PTR_TY]]) + + *a = *b; +} diff --git a/clang/test/CodeGen/matrix-type.c b/clang/test/CodeGen/matrix-type.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/matrix-type.c @@ -0,0 +1,158 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +#if !__has_extension(matrix_types) +#error Expected extension 'matrix_types' to be enabled +#endif + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); + +// CHECK: %struct.Matrix = type { i8, [12 x float], float } + +void load_store_double(dx5x5_t *a, dx5x5_t *b) { + // CHECK-LABEL: define void @load_store_double( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8 + // CHECK-NEXT: ret 
void + + *a = *b; +} + +typedef float fx3x4_t __attribute__((matrix_type(3, 4))); +void load_store_float(fx3x4_t *a, fx3x4_t *b) { + // CHECK-LABEL: define void @load_store_float( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x float]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x float]*, align 8 + // CHECK-NEXT: store [12 x float]* %a, [12 x float]** %a.addr, align 8 + // CHECK-NEXT: store [12 x float]* %b, [12 x float]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x float]*, [12 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x float]* %0 to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load [12 x float]*, [12 x float]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x float]* %3 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef int ix3x4_t __attribute__((matrix_type(4, 3))); +void load_store_int(ix3x4_t *a, ix3x4_t *b) { + // CHECK-LABEL: define void @load_store_int( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x i32]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x i32]*, align 8 + // CHECK-NEXT: store [12 x i32]* %a, [12 x i32]** %a.addr, align 8 + // CHECK-NEXT: store [12 x i32]* %b, [12 x i32]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x i32]*, [12 x i32]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x i32]* %0 to <12 x i32>* + // CHECK-NEXT: %2 = load <12 x i32>, <12 x i32>* %1, align 4 + // CHECK-NEXT: %3 = load [12 x i32]*, [12 x i32]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x i32]* %3 to <12 x i32>* + // CHECK-NEXT: store <12 x i32> %2, <12 x i32>* %4, align 4 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef unsigned long long ullx3x4_t __attribute__((matrix_type(4, 3))); +void load_store_ull(ullx3x4_t *a, ullx3x4_t *b) { + // CHECK-LABEL: define void @load_store_ull( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca 
[12 x i64]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x i64]*, align 8 + // CHECK-NEXT: store [12 x i64]* %a, [12 x i64]** %a.addr, align 8 + // CHECK-NEXT: store [12 x i64]* %b, [12 x i64]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x i64]*, [12 x i64]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x i64]* %0 to <12 x i64>* + // CHECK-NEXT: %2 = load <12 x i64>, <12 x i64>* %1, align 8 + // CHECK-NEXT: %3 = load [12 x i64]*, [12 x i64]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x i64]* %3 to <12 x i64>* + // CHECK-NEXT: store <12 x i64> %2, <12 x i64>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef __fp16 fp16x3x4_t __attribute__((matrix_type(4, 3))); +void load_store_fp16(fp16x3x4_t *a, fp16x3x4_t *b) { + // CHECK-LABEL: define void @load_store_fp16( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [12 x half]*, align 8 + // CHECK-NEXT: %b.addr = alloca [12 x half]*, align 8 + // CHECK-NEXT: store [12 x half]* %a, [12 x half]** %a.addr, align 8 + // CHECK-NEXT: store [12 x half]* %b, [12 x half]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [12 x half]*, [12 x half]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [12 x half]* %0 to <12 x half>* + // CHECK-NEXT: %2 = load <12 x half>, <12 x half>* %1, align 2 + // CHECK-NEXT: %3 = load [12 x half]*, [12 x half]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [12 x half]* %3 to <12 x half>* + // CHECK-NEXT: store <12 x half> %2, <12 x half>* %4, align 2 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); + +void parameter_passing(fx3x3_t a, fx3x3_t *b) { + // CHECK-LABEL: define void @parameter_passing( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4 + // CHECK-NEXT: store [9 x float]* %b, [9 x 
float]** %b.addr, align 8 + // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4 + // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>* + // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4 + // CHECK-NEXT: ret void + *b = a; +} + +fx3x3_t return_matrix(fx3x3_t *a) { + // CHECK-LABEL: define <9 x float> @return_matrix + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>* + // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: ret <9 x float> %2 + return *a; +} + +typedef struct { + char Tmp1; + fx3x4_t Data; + float Tmp2; +} Matrix; + +void matrix_struct(Matrix *a, Matrix *b) { + // CHECK-LABEL: define void @matrix_struct( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b->Data = a->Data; +} diff --git 
a/clang/test/CodeGenCXX/matrix-type.cpp b/clang/test/CodeGenCXX/matrix-type.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCXX/matrix-type.cpp @@ -0,0 +1,388 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++17 | FileCheck %s + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +typedef float fx3x4_t __attribute__((matrix_type(3, 4))); + +// CHECK: %struct.Matrix = type { i8, [12 x float], float } + +void load_store(dx5x5_t *a, dx5x5_t *b) { + // CHECK-LABEL: define void @_Z10load_storePU11matrix_typeLm5ELm5EdS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); + +void parameter_passing(fx3x3_t a, fx3x3_t *b) { + // CHECK-LABEL: define void @_Z17parameter_passingU11matrix_typeLm3ELm3EfPS_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4 + // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4 + // CHECK-NEXT: %2 = load [9 x float]*, 
[9 x float]** %b.addr, align 8 + // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>* + // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4 + // CHECK-NEXT: ret void + *b = a; +} + +fx3x3_t return_matrix(fx3x3_t *a) { + // CHECK-LABEL: define <9 x float> @_Z13return_matrixPU11matrix_typeLm3ELm3Ef( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>* + // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: ret <9 x float> %2 + return *a; +} + +struct Matrix { + char Tmp1; + fx3x4_t Data; + float Tmp2; +}; + +void matrix_struct_pointers(Matrix *a, Matrix *b) { + // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b->Data = a->Data; +} + +void matrix_struct_reference(Matrix &a, Matrix &b) { + // CHECK-LABEL: define void 
@_Z23matrix_struct_referenceR6MatrixS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b.Data = a.Data; +} + +class MatrixClass { +public: + int Tmp1; + fx3x4_t Data; + long Tmp2; +}; + +void matrix_class_reference(MatrixClass &a, MatrixClass &b) { + // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %class.MatrixClass*, align 8 + // CHECK-NEXT: %b.addr = alloca %class.MatrixClass*, align 8 + // CHECK-NEXT: store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8 + // CHECK-NEXT: store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8 + // CHECK-NEXT: %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr 
inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b.Data = a.Data; +} + +template <typename Ty, unsigned Rows, unsigned Cols> +class MatrixClassTemplate { +public: + using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols))); + int Tmp1; + MatrixTy Data; + long Tmp2; +}; + +template <typename Ty, unsigned Rows, unsigned Cols> +void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) { + b.Data = a.Data; +} + +MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) { + // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data + // CHECK-NEXT: entry: + // CHECK-NEXT: %Data.addr = alloca float*, align 8 + // CHECK-NEXT: %Arg = alloca %class.MatrixClassTemplate, align 8 + // CHECK-NEXT: store float* %Data, float** %Data.addr, align 8 + // CHECK-NEXT: %0 = load float*, float** %Data.addr, align 8 + // CHECK-NEXT: %1 = bitcast float* %0 to [150 x float]* + // CHECK-NEXT: %2 = bitcast [150 x float]* %1 to <150 x float>* + // CHECK-NEXT: %3 = load <150 x float>, <150 x float>* %2, align 4 + // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>* + // CHECK-NEXT: store <150 x float> %3, <150 x float>* %4, align 4 + // CHECK-NEXT: call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b) + // CHECK-NEXT: entry: + // CHECK-NEXT: 
%a.addr = alloca %class.MatrixClassTemplate*, align 8 + // CHECK-NEXT: %b.addr = alloca %class.MatrixClassTemplate*, align 8 + // CHECK-NEXT: store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8 + // CHECK-NEXT: store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8 + // CHECK-NEXT: %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [150 x float]* %Data to <150 x float>* + // CHECK-NEXT: %2 = load <150 x float>, <150 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>* + // CHECK-NEXT: store <150 x float> %2, <150 x float>* %4, align 4 + // CHECK-NEXT: ret void + + MatrixClassTemplate<float, 10, 15> Result, Arg; + Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data); + matrix_template_reference(Arg, Result); + return Result; +} + +template <class T, unsigned long R, unsigned long C> +using matrix = T __attribute__((matrix_type(R, C))); + +template <int N> +struct selector {}; + +template <class T, unsigned long R, unsigned long C> +selector<0> use_matrix(matrix<T, R, C> &m) {} + +template <class T, unsigned long R> +selector<1> use_matrix(matrix<T, R, 10> &m) {} + +template <class T> +selector<2> use_matrix(matrix<T, 10, 10> &m) {} + +template <class T, unsigned long C> +selector<3> use_matrix(matrix<T, 10, C> &m) {} + +template <unsigned long R, unsigned long C> +selector<4> use_matrix(matrix<float, R, C> &m) {} + +void test_template_deduction() { + + // CHECK-LABEL: define void @_Z23test_template_deductionv() + // CHECK-NEXT: entry: + // CHECK-NEXT: %m0 = alloca [120 x i32], align 4 + // CHECK-NEXT: %w = alloca %struct.selector, align 1 + // CHECK-NEXT: %undef.agg.tmp = alloca %struct.selector, align 1 + // CHECK-NEXT: %m1 = alloca [100 x i32], align 4 + // CHECK-NEXT: %x = alloca %struct.selector.0, align 1 + // 
CHECK-NEXT: %undef.agg.tmp1 = alloca %struct.selector.0, align 1 + // CHECK-NEXT: %m2 = alloca [120 x i32], align 4 + // CHECK-NEXT: %y = alloca %struct.selector.1, align 1 + // CHECK-NEXT: %undef.agg.tmp2 = alloca %struct.selector.1, align 1 + // CHECK-NEXT: %m3 = alloca [144 x i32], align 4 + // CHECK-NEXT: %z = alloca %struct.selector.2, align 1 + // CHECK-NEXT: %undef.agg.tmp3 = alloca %struct.selector.2, align 1 + // CHECK-NEXT: %m4 = alloca [144 x float], align 4 + // CHECK-NEXT: %v = alloca %struct.selector.3, align 1 + // CHECK-NEXT: %undef.agg.tmp4 = alloca %struct.selector.3, align 1 + // CHECK-NEXT: call void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m0) + // CHECK-NEXT: call void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m1) + // CHECK-NEXT: call void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m2) + // CHECK-NEXT: call void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m3) + // CHECK-NEXT: call void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576) %m4) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi3EERU11matrix_typeXLm10EEXT0_ET_([120 x i32]* dereferenceable(480) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8 + // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiE8selectorILi2EERU11matrix_typeLm10ELm10ET_([100 x i32]* dereferenceable(400) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [100 x i32]*, align 8 + // CHECK-NEXT: store [100 x i32]* %m, [100 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void 
@llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12EE8selectorILi1EERU11matrix_typeXT0_EXLm10EET_([120 x i32]* dereferenceable(480) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8 + // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixIiLm12ELm12EE8selectorILi0EERU11matrix_typeXT0_EXT1_ET_([144 x i32]* dereferenceable(576) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [144 x i32]*, align 8 + // CHECK-NEXT: store [144 x i32]* %m, [144 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z10use_matrixILm12ELm12EE8selectorILi4EERU11matrix_typeXT_EXT0_Ef([144 x float]* dereferenceable(576) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [144 x float]*, align 8 + // CHECK-NEXT: store [144 x float]* %m, [144 x float]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + matrix<int, 10, 12> m0; + selector<3> w = use_matrix(m0); + matrix<int, 10, 10> m1; + selector<2> x = use_matrix(m1); + matrix<int, 12, 10> m2; + selector<1> y = use_matrix(m2); + matrix<int, 12, 12> m3; + selector<0> z = use_matrix(m3); + matrix<float, 12, 12> m4; + selector<4> v = use_matrix(m4); +} + +template <auto R> +void foo(matrix<int, R, 10> &m) { +} + +void test_auto_t() { + // CHECK-LABEL: define void @_Z11test_auto_tv() + // CHECK-NEXT: entry: + // CHECK-NEXT: %m = alloca [130 x i32], align 4 + // CHECK-NEXT: call void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void @_Z3fooILm13EEvRU11matrix_typeXT_EXLm10EEi([130 x i32]* dereferenceable(520) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [130 x i32]*, align 8 + // CHECK-NEXT: store [130 x i32]* %m, [130 x i32]** %m.addr, align 8 
+ // CHECK-NEXT: ret void + + matrix<int, 13, 10> m; + foo(m); +} + +template <unsigned long R, unsigned long C> +matrix<float, R + 1, C + 2> use_matrix_2(matrix<int, R, C> &m) {} + +template <unsigned long R, unsigned long C> +selector<0> use_matrix_2(matrix<int, R + 2, C / 2> &m1, matrix<float, R, C> &m2) {} + +template <unsigned long R, unsigned long C> +selector<1> use_matrix_2(matrix<int, R + C, C> &m1, matrix<float, R, C - R> &m2) {} + +template <unsigned long R> +matrix<float, R + R, R - 3> use_matrix_2(matrix<int, R, 10> &m1) {} + +template <unsigned long R> +selector<2> use_matrix_3(matrix<int, R - 2, R> &m) {} + +void test_use_matrix_2() { + // CHECK-LABEL: define void @_Z17test_use_matrix_2v() + // CHECK-NEXT: entry: + // CHECK-NEXT: %m1 = alloca [24 x i32], align 4 + // CHECK-NEXT: %r1 = alloca [40 x float], align 4 + // CHECK-NEXT: %m2 = alloca [24 x float], align 4 + // CHECK-NEXT: %r2 = alloca %struct.selector.2, align 1 + // CHECK-NEXT: %undef.agg.tmp = alloca %struct.selector.2, align 1 + // CHECK-NEXT: %m3 = alloca [104 x i32], align 4 + // CHECK-NEXT: %m4 = alloca [15 x float], align 4 + // CHECK-NEXT: %r3 = alloca %struct.selector.1, align 1 + // CHECK-NEXT: %undef.agg.tmp1 = alloca %struct.selector.1, align 1 + // CHECK-NEXT: %m5 = alloca [50 x i32], align 4 + // CHECK-NEXT: %r4 = alloca [20 x float], align 4 + // CHECK-NEXT: %r5 = alloca %struct.selector.0, align 1 + // CHECK-NEXT: %undef.agg.tmp3 = alloca %struct.selector.0, align 1 + // CHECK-NEXT: %call = call <40 x float> @_Z12use_matrix_2ILm4ELm6EEU11matrix_typeXplT_Li1EEXplT0_Li2EEfRU11matrix_typeXT_EXT0_Ei([24 x i32]* dereferenceable(96) %m1) + // CHECK-NEXT: %0 = bitcast [40 x float]* %r1 to <40 x float>* + // CHECK-NEXT: store <40 x float> %call, <40 x float>* %0, align 4 + // CHECK-NEXT: call void @_Z12use_matrix_2ILm2ELm12EE8selectorILi0EERU11matrix_typeXplT_Li2EEXdvT0_Li2EEiRU11matrix_typeXT_EXT0_Ef([24 x i32]* dereferenceable(96) %m1, [24 x float]* dereferenceable(96) %m2) + // CHECK-NEXT: call void @_Z12use_matrix_2ILm5ELm8EE8selectorILi1EERU11matrix_typeXplT_T0_EXT0_EiRU11matrix_typeXT_EXmiT0_T_Ef([104 x i32]* dereferenceable(416) %m3, [15 x float]* dereferenceable(60) %m4) + // CHECK-NEXT: %call2 = call <20 x float> 
@_Z12use_matrix_2ILm5EEU11matrix_typeXplT_T_EXmiT_Li3EEfRU11matrix_typeXT_EXLm10EEi([50 x i32]* dereferenceable(200) %m5) + // CHECK-NEXT: %1 = bitcast [20 x float]* %r4 to <20 x float>* + // CHECK-NEXT: store <20 x float> %call2, <20 x float>* %1, align 4 + // CHECK-NEXT: call void @_Z12use_matrix_3ILm6EE8selectorILi2EERU11matrix_typeXmiT_Li2EEXT_Ei([24 x i32]* dereferenceable(96) %m1) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr <40 x float> @_Z12use_matrix_2ILm4ELm6EEU11matrix_typeXplT_Li1EEXplT0_Li2EEfRU11matrix_typeXT_EXT0_Ei([24 x i32]* dereferenceable(96) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [24 x i32]*, align 8 + // CHECK-NEXT: store [24 x i32]* %m, [24 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_2ILm2ELm12EE8selectorILi0EERU11matrix_typeXplT_Li2EEXdvT0_Li2EEiRU11matrix_typeXT_EXT0_Ef([24 x i32]* dereferenceable(96) %m1, [24 x float]* dereferenceable(96) %m2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m1.addr = alloca [24 x i32]*, align 8 + // CHECK-NEXT: %m2.addr = alloca [24 x float]*, align 8 + // CHECK-NEXT: store [24 x i32]* %m1, [24 x i32]** %m1.addr, align 8 + // CHECK-NEXT: store [24 x float]* %m2, [24 x float]** %m2.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_2ILm5ELm8EE8selectorILi1EERU11matrix_typeXplT_T0_EXT0_EiRU11matrix_typeXT_EXmiT0_T_Ef([104 x i32]* dereferenceable(416) %m1, [15 x float]* dereferenceable(60) %m2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m1.addr = alloca [104 x i32]*, align 8 + // CHECK-NEXT: %m2.addr = alloca [15 x float]*, align 8 + // CHECK-NEXT: store [104 x i32]* %m1, [104 x i32]** %m1.addr, align 8 + // CHECK-NEXT: store [15 x float]* %m2, [15 x float]** %m2.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define 
linkonce_odr <20 x float> @_Z12use_matrix_2ILm5EEU11matrix_typeXplT_T_EXmiT_Li3EEfRU11matrix_typeXT_EXLm10EEi([50 x i32]* dereferenceable(200) %m1) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m1.addr = alloca [50 x i32]*, align 8 + // CHECK-NEXT: store [50 x i32]* %m1, [50 x i32]** %m1.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + // CHECK-LABEL: define linkonce_odr void @_Z12use_matrix_3ILm6EE8selectorILi2EERU11matrix_typeXmiT_Li2EEXT_Ei([24 x i32]* dereferenceable(96) %m) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca [24 x i32]*, align 8 + // CHECK-NEXT: store [24 x i32]* %m, [24 x i32]** %m.addr, align 8 + // CHECK-NEXT: call void @llvm.trap() + // CHECK-NEXT: unreachable + + matrix<int, 4, 6> m1; + matrix<float, 5, 8> r1 = use_matrix_2(m1); + + matrix<float, 2, 12> m2; + selector<0> r2 = use_matrix_2(m1, m2); + + matrix<int, 13, 8> m3; + matrix<float, 5, 3> m4; + selector<1> r3 = use_matrix_2(m3, m4); + + matrix<int, 5, 10> m5; + matrix<float, 10, 2> r4 = use_matrix_2(m5); + + selector<2> r5 = use_matrix_3(m1); +} diff --git a/clang/test/Parser/matrix-type-disabled.c b/clang/test/Parser/matrix-type-disabled.c new file mode 100644 --- /dev/null +++ b/clang/test/Parser/matrix-type-disabled.c @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 %s -triple i686-apple-darwin -verify -fsyntax-only + +// Matrix types are disabled by default. + +#if __has_extension(matrix_types) +#error Expected extension 'matrix_types' to be enabled +#endif + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +// expected-error@-1 {{matrix types extension is disabled. 
Pass -fenable-matrix to enable it}} + +void load_store_double(dx5x5_t *a, dx5x5_t *b) { + *a = *b; +} diff --git a/clang/test/SemaCXX/matrix-type.cpp b/clang/test/SemaCXX/matrix-type.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/matrix-type.cpp @@ -0,0 +1,129 @@ +// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s + +using matrix_double_t = double __attribute__((matrix_type(6, 6))); +using matrix_float_t = float __attribute__((matrix_type(6, 6))); +using matrix_int_t = int __attribute__((matrix_type(6, 6))); + +void matrix_var_dimensions(int Rows, unsigned Columns, char C) { + using matrix1_t = int __attribute__((matrix_type(Rows, 1))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix3_t = int __attribute__((matrix_type(C, C))); // expected-error{{matrix_type attribute requires an integer constant}} + using matrix4_t = int __attribute__((matrix_type(-1, 1))); // expected-error{{matrix row size too large}} + using matrix5_t = int __attribute__((matrix_type(1, -1))); // expected-error{{matrix column size too large}} + using matrix6_t = int __attribute__((matrix_type(0, 1))); // expected-error{{zero matrix size}} + using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero matrix size}} + using matrix7_t = int __attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}} + using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}} +} + +struct S1 {}; + +enum TestEnum { + A, + B +}; + +void matrix_unsupported_element_type() { + using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}} + using matrix2_t = S1 
__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'S1'}} + using matrix3_t = bool __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'bool'}} + using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}} +} + +template // expected-note{{declared here}} +void matrix_template_1() { + using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}} +} + +template // expected-note{{declared here}} +void matrix_template_2() { + using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}} +} + +template +void matrix_template_3() { + using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}} +} + +void instantiate_template_3() { + matrix_template_3<1, 10>(); + matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}} +} + +template +void matrix_template_4() { + using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large}} +} + +void instantiate_template_4() { + matrix_template_4<2, 10>(); + matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}} +} + +template +using matrix = T __attribute__((matrix_type(R, C))); + +template +void use_matrix(matrix &m) {} +// expected-note@-1 {{candidate function [with T = float, R = 10]}} + +template +void use_matrix(matrix &m) {} +// expected-note@-1 {{candidate function [with T = float, C = 10]}} + +void test_ambigous_deduction1() { + matrix m; + use_matrix(m); + // expected-error@-1 {{call to 'use_matrix' is ambiguous}} +} + +template +void type_conflict(matrix &m, T x) {} +// expected-note@-1 {{candidate template ignored: deduced conflicting types 
for parameter 'T' ('float' vs. 'char *')}} + +void test_type_conflict(char *p) { + matrix m; + type_conflict(m, p); + // expected-error@-1 {{no matching function for call to 'type_conflict'}} +} + +template +matrix use_matrix_2(matrix &m) {} +// expected-note@-1 {{candidate function template not viable: requires single argument 'm', but 2 arguments were provided}} +// expected-note@-2 {{candidate function template not viable: requires single argument 'm', but 2 arguments were provided}} + +template +void use_matrix_2(matrix &m1, matrix &m2) {} +// expected-note@-1 {{candidate function [with R = 3, C = 11] not viable: no known conversion from 'matrix' (aka 'int __attribute__((matrix_type(5, 6)))') to 'matrix &' (aka 'int __attribute__((matrix_type(5, 5)))&') for 1st argument}} +// expected-note@-2 {{candidate template ignored: deduced type 'matrix' of 2nd parameter does not match adjusted type 'matrix' of argument [with R = 3, C = 4]}} + +template +void use_matrix_2(matrix &m1, matrix &m2) {} +// expected-note@-1 {{candidate template ignored: deduced conflicting types for parameter 'T' ('int' vs. 
'float')}} +// expected-note@-2 {{candidate template ignored: deduced type 'matrix<[...], 3UL + 4UL, 4UL>' of 1st parameter does not match adjusted type 'matrix<[...], 3, 4>' of argument [with T = int, R = 3, C = 4]}} + +template +void use_matrix_3(matrix &m) {} +// expected-note@-1 {{candidate template ignored: deduced type 'matrix<[...], 5UL - 2, 5UL>' of 1st parameter does not match adjusted type 'matrix<[...], 5, 5>' of argument [with T = unsigned int, R = 5]}} + +void test_use_matrix_2() { + matrix m1; + matrix r1 = use_matrix_2(m1); + // expected-error@-1 {{cannot initialize a variable of type 'matrix<[...], 5, 8>' with an rvalue of type 'matrix<[...], 5UL + 1, 6UL + 2>'}} + + matrix m2; + matrix r2 = use_matrix_2(m2); + // expected-error@-1 {{cannot initialize a variable of type 'matrix<[...], 5, 8>' with an rvalue of type 'matrix<[...], 4UL + 1, 5UL + 2>'}} + + matrix m3; + use_matrix_2(m1, m3); + // expected-error@-1 {{no matching function for call to 'use_matrix_2'}} + + matrix m4; + use_matrix_2(m4, m4); + // expected-error@-1 {{no matching function for call to 'use_matrix_2'}} + + matrix m5; + use_matrix_3(m5); + // expected-error@-1 {{no matching function for call to 'use_matrix_3'}} +} diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -1795,6 +1795,8 @@ DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type) DEFAULT_TYPELOC_IMPL(Vector, Type) DEFAULT_TYPELOC_IMPL(ExtVector, VectorType) +DEFAULT_TYPELOC_IMPL(ConstantMatrix, MatrixType) +DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, MatrixType) DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType) DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType) DEFAULT_TYPELOC_IMPL(Record, TagType)