diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -193,6 +193,8 @@ DependentAddressSpaceTypes; mutable llvm::FoldingSet VectorTypes; mutable llvm::FoldingSet DependentVectorTypes; + mutable llvm::FoldingSet MatrixTypes; + mutable llvm::FoldingSet DependentSizedMatrixTypes; mutable llvm::FoldingSet FunctionNoProtoTypes; mutable llvm::ContextualFoldingSet FunctionProtoTypes; @@ -1309,6 +1311,21 @@ Expr *SizeExpr, SourceLocation AttrLoc) const; + /// Return the unique reference to the matrix type of the specified element + /// type and size + /// + /// \pre \p MatrixType must be a built-in type. + QualType getMatrixType(QualType MatrixType, unsigned NumRows, + unsigned NumColumns) const; + + /// Return the unique reference to the matrix type of the specified element + /// type and size + /// + /// \pre \p MatrixElementType must be a built-in type. + QualType getDependentSizedMatrixType(QualType MatrixElementType, + Expr *RowExpr, Expr *ColumnExpr, + SourceLocation AttrLoc) const; + QualType getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1006,6 +1006,16 @@ DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); }) +DEF_TRAVERSE_TYPE(MatrixType, { TRY_TO(TraverseType(T->getElementType())); }) + +DEF_TRAVERSE_TYPE(DependentSizedMatrixType, { + if (T->getRowExpr()) + TRY_TO(TraverseStmt(T->getRowExpr())); + if (T->getColumnExpr()) + TRY_TO(TraverseStmt(T->getColumnExpr())); + TRY_TO(TraverseType(T->getElementType())); +}) + DEF_TRAVERSE_TYPE(FunctionNoProtoType, { TRY_TO(TraverseType(T->getReturnType())); }) @@ -1254,6 +1264,21 @@ 
TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); }) +// Same as VectorType: FIXME: MatrixTypeLoc is unfinished +DEF_TRAVERSE_TYPELOC(MatrixType, { + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + +DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, { + if (TL.getTypePtr()->getRowExpr()) { + TRY_TO(TraverseStmt(TL.getTypePtr()->getRowExpr())); + } + if (TL.getTypePtr()->getColumnExpr()) { + TRY_TO(TraverseStmt(TL.getTypePtr()->getColumnExpr())); + } + TRY_TO(TraverseType(TL.getTypePtr()->getElementType())); +}) + DEF_TRAVERSE_TYPELOC(FunctionNoProtoType, { TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); }) diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -1657,6 +1657,19 @@ enum { MaxNumElements = (1 << (29 - NumTypeBits)) - 1 }; }; + class MatrixTypeBitfields { + friend class MatrixType; + + unsigned : NumTypeBits; + + // Number of rows and columns + unsigned NumRows : 29 - NumTypeBits; + unsigned NumColumns : 29 - NumTypeBits; + + enum { MaxNumRows = (1 << (29 - NumTypeBits)) - 1 }; + enum { MaxNumColumns = (1 << (29 - NumTypeBits)) - 1 }; + }; + class AttributedTypeBitfields { friend class AttributedType; @@ -1766,6 +1779,7 @@ TypeWithKeywordBitfields TypeWithKeywordBits; ElaboratedTypeBitfields ElaboratedTypeBits; VectorTypeBitfields VectorTypeBits; + MatrixTypeBitfields MatrixTypeBits; SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits; TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits; DependentTemplateSpecializationTypeBitfields @@ -2024,6 +2038,7 @@ bool isComplexIntegerType() const; // GCC _Complex integer type. bool isVectorType() const; // GCC vector type. bool isExtVectorType() const; // Extended vector type. 
+ bool isMatrixType() const; bool isDependentAddressSpaceType() const; // value-dependent address space qualifier bool isObjCObjectPointerType() const; // pointer to ObjC object bool isObjCRetainableType() const; // ObjC object or block pointer @@ -3386,6 +3401,114 @@ } }; +/// MatrixType - This type is created using +/// __attribute__((matrix_type(rows, columns))), where "rows" is the +/// number of rows and "columns" is the number of columns. +class MatrixType : public Type, public llvm::FoldingSetNode { +protected: + friend class ASTContext; + + QualType ElementType; + + // MatrixElementType: The type of the elements in the matrix + // NRows: Number of rows + // NColumns: Number of columns + // CanonElementType: Canonical element type (if the matrix type is not + // canonical) + MatrixType(QualType MatrixElementType, unsigned NRows, unsigned NColumns, + QualType CanonElementType); + + // typeClass: The typeclass (defined in TypeNodes.def) + // MatrixElementType: The type of elements in the matrix + // NRows: The number of rows + // NColumns: The number of columns + // CanonElementType: Canonical type (if the matrixType is not canonical) + MatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows, + unsigned NColumns, QualType CanonElementType); + +public: + // The type of the elements being stored in the matrix + QualType getElementType() const { return ElementType; } + + // The number of rows in the matrix + unsigned getNumRows() const { return MatrixTypeBits.NumRows; } + + // The number of columns in the matrix + unsigned getNumColumns() const { return MatrixTypeBits.NumColumns; } + + unsigned getNumElementsFlattened() const { + return MatrixTypeBits.NumRows * MatrixTypeBits.NumColumns; + } + + // Check if the dimensions of the matrix fit in data storage type + static bool tooBig(unsigned NumRows, unsigned NumColumns) { + return NumRows > MatrixTypeBitfields::MaxNumRows || + NumColumns > MatrixTypeBitfields::MaxNumColumns; + } + + bool isSugared() 
const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, getElementType(), getNumRows(), getNumColumns(), + getTypeClass()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType, + unsigned NumRows, unsigned NumColumns, + TypeClass TypeClass) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + ID.AddInteger(NumRows); + ID.AddInteger(NumColumns); + ID.AddInteger(TypeClass); + } + + static bool classof(const Type *T) { + return T->getTypeClass() == Matrix || + T->getTypeClass() == DependentSizedMatrix; + } +}; + +/// DependentSizedMatrixType - Represents a matrix type where the type +/// and size are dependent on a template. +/// +class DependentSizedMatrixType : public Type, public llvm::FoldingSetNode { + friend class ASTContext; + + const ASTContext &Context; + Expr *RowExpr; + Expr *ColumnExpr; + + /// The element type of the matrix + QualType ElementType; + + SourceLocation loc; + + DependentSizedMatrixType(const ASTContext &Context, QualType ElementType, + QualType CanonicalType, Expr *RowExpr, + Expr *ColumnExpr, SourceLocation loc); + +public: + QualType getElementType() const { return ElementType; } + Expr *getRowExpr() const { return RowExpr; } + Expr *getColumnExpr() const { return ColumnExpr; } + SourceLocation getAttributeLoc() const { return loc; } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + static bool classof(const Type *T) { + return T->getTypeClass() == DependentSizedMatrix; + } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context, + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr); +}; + /// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base /// class of FunctionNoProtoType and FunctionProtoType. 
class FunctionType : public Type { @@ -6543,6 +6666,10 @@ return isa(CanonicalType); } +inline bool Type::isMatrixType() const { + return isa(CanonicalType); +} + inline bool Type::isDependentAddressSpaceType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -1774,6 +1774,18 @@ DependentSizedExtVectorType> { }; +// Same as VectorType: FIXME: attribute locations. +class MatrixTypeLoc + : public InheritingConcreteTypeLoc {}; + +// Same as VectorType: FIXME: attribute locations. Also look into making this +// a subtype of the MatrixTypeLoc +class DependentSizedMatrixTypeLoc + : public InheritingConcreteTypeLoc {}; + // FIXME: location of the '_Complex' keyword. class ComplexTypeLoc : public InheritingConcreteTypeLoc; } +let Class = MatrixType in { + def : Property<"elementType", QualType> { + let Read = [{ node->getElementType() }]; + } + def : Property<"numRows", UInt32> { + let Read = [{ node->getNumRows() }]; + } + def : Property<"numColumns", UInt32> { + let Read = [{ node->getNumColumns() }]; + } + + def : Creator<[{ + return ctx.getMatrixType(elementType, numRows, numColumns); + }]>; +} + +let Class = DependentSizedMatrixType in { + def : Property<"elementType", QualType> { + let Read = [{ node->getElementType() }]; + } + def : Property<"rows", ExprRef> { + let Read = [{ node->getRowExpr() }]; + } + def : Property<"columns", ExprRef> { + let Read = [{ node->getColumnExpr() }]; + } + def : Property<"attributeLoc", SourceLocation> { + let Read = [{ node->getAttributeLoc() }]; + } + + def : Creator<[{ + return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc); + }]>; +} + let Class = FunctionType in { def : Property<"returnType", QualType> { let Read = [{ node->getReturnType() }]; diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td --- 
a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -2464,6 +2464,15 @@ let Documentation = [Undocumented]; } +def MatrixType : TypeAttr { + let Spellings = [Clang<"matrix_type">]; + let Subjects = SubjectList<[TypedefName], ErrorDiag>; + let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">]; + let Documentation = [Undocumented]; + let ASTNode = 0; + let PragmaAttributeSupport = 0; +} + def Visibility : InheritableAttr { let Clone = 0; let Spellings = [GCC<"visibility">]; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -2764,6 +2764,7 @@ def err_attribute_too_few_arguments : Error< "%0 attribute takes at least %1 argument%s1">; def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">; +def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">; def err_attribute_bad_neon_vector_size : Error< "Neon vector size must be 64 or 128 bits">; def err_attribute_requires_positive_integer : Error< @@ -10629,6 +10630,9 @@ "%select{non-pointer|function pointer|void pointer}0 argument to " "'__builtin_launder' is not allowed">; +def err_builtin_matrix_disabled: Error< + "Builtin matrix support is disabled. 
Pass -fenable-matrix to enable it.">; + def err_preserve_field_info_not_field : Error< "__builtin_preserve_field_info argument %0 not a field access">; def err_preserve_field_info_not_const: Error< diff --git a/clang/include/clang/Basic/LangOptions.def b/clang/include/clang/Basic/LangOptions.def --- a/clang/include/clang/Basic/LangOptions.def +++ b/clang/include/clang/Basic/LangOptions.def @@ -351,6 +351,8 @@ LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors") +LANGOPT(EnableMatrix, 1, 0, "Enable or disable the builtin matrix type") + COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0") #undef LANGOPT diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td --- a/clang/include/clang/Basic/TypeNodes.td +++ b/clang/include/clang/Basic/TypeNodes.td @@ -65,10 +65,12 @@ def VariableArrayType : TypeNode; def DependentSizedArrayType : TypeNode, AlwaysDependent; def DependentSizedExtVectorType : TypeNode, AlwaysDependent; +def DependentSizedMatrixType : TypeNode, AlwaysDependent; def DependentAddressSpaceType : TypeNode, AlwaysDependent; def VectorType : TypeNode; def DependentVectorType : TypeNode, AlwaysDependent; def ExtVectorType : TypeNode; +def MatrixType : TypeNode; def FunctionType : TypeNode; def FunctionProtoType : TypeNode; def FunctionNoProtoType : TypeNode; diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -1982,6 +1982,10 @@ def fno_strict_return : Flag<["-"], "fno-strict-return">, Group, Flags<[CC1Option]>; +def fenable_matrix : Flag<["-"], "fenable-matrix">, Group, + Flags<[CC1Option]>, + HelpText<"Enable matrix data type and related builtin functions">; + def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">, Group, Flags<[CC1Option]>, HelpText<"Treat editor placeholders as valid source code">; diff --git 
a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -1625,6 +1625,9 @@ QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc); QualType BuildExtVectorType(QualType T, Expr *ArraySize, SourceLocation AttrLoc); + QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns, + SourceLocation AttrLoc); + QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace, SourceLocation AttrLoc); diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def --- a/clang/include/clang/Serialization/TypeBitCodes.def +++ b/clang/include/clang/Serialization/TypeBitCodes.def @@ -58,5 +58,7 @@ TYPE_BIT_CODE(DependentAddressSpace, DEPENDENT_ADDRESS_SPACE, 47) TYPE_BIT_CODE(DependentVector, DEPENDENT_SIZED_VECTOR, 48) TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49) +TYPE_BIT_CODE(Matrix, MATRIX, 50) +TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 51) #undef TYPE_BIT_CODE diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -1926,6 +1926,18 @@ break; } + case Type::Matrix: { + const auto *MT = cast(T); + TypeInfo ElementInfo = getTypeInfo(MT->getElementType()); + // The matrix type is intended to be ABI compatible with arrays with respect + // to alignment and size. We use LLVM's array type for storage. + Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns(); + // Arrays are aligned like their element type, so use the element's + // alignment here rather than its width (the two can differ). 
+ Align = ElementInfo.Align; + break; + } + case Type::Builtin: switch (cast(T)->getKind()) { default: llvm_unreachable("Unknown builtin type!"); @@ -3342,6 +3354,8 @@ case Type::DependentVector: case Type::ExtVector: case Type::DependentSizedExtVector: + case Type::Matrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::ObjCObject: case Type::ObjCInterface: @@ -3753,6 +3767,76 @@ return QualType(New, 0); } +/// getMatrixType - Return the unique reference to a matrix type of the +/// specified element type and size. ElementTy must be a built-in integer or +/// floating point type. +QualType ASTContext::getMatrixType(QualType ElementTy, unsigned NumRows, + unsigned NumColumns) const { + llvm::FoldingSetNodeID ID; + MatrixType::Profile(ID, ElementTy, NumRows, NumColumns, Type::Matrix); + + void *InsertPos = nullptr; + if (MatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos)) { + return QualType(MTP, 0); + } + + QualType Canonical; + if (!ElementTy.isCanonical()) { + Canonical = getMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns); + + MatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!NewIP && "Matrix type shouldn't already exist in the map"); + (void)NewIP; + } + + auto *New = new (*this, TypeAlignment) + MatrixType(ElementTy, NumRows, NumColumns, Canonical); + MatrixTypes.InsertNode(New, InsertPos); + Types.push_back(New); + return QualType(New, 0); +} + +// getDependentSizedMatrixType - Return a unique reference to the +// dependent matrix. 
MatrixElementType must be a builtin type +QualType ASTContext::getDependentSizedMatrixType(QualType MatrixElementType, + Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttrLoc) const { + llvm::FoldingSetNodeID ID; + DependentSizedMatrixType::Profile( + ID, *this, getCanonicalType(MatrixElementType), RowExpr, ColumnExpr); + + void *InsertPos = nullptr; + DependentSizedMatrixType *Canon = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, 
InsertPos); + DependentSizedMatrixType *New; + if (Canon) { + // Already have a canonical version of the matrix type + // Use it as the canonical type for newly-built types + New = new (*this, TypeAlignment) + DependentSizedMatrixType(*this, MatrixElementType, QualType(Canon, 0), + RowExpr, ColumnExpr, AttrLoc); + } else { + QualType CanonicalMatrixElementType = getCanonicalType(MatrixElementType); + if (CanonicalMatrixElementType == MatrixElementType) { + New = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, MatrixElementType, QualType(), RowExpr, ColumnExpr, AttrLoc); + DependentSizedMatrixType *CanonCheck = + DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos); + assert(!CanonCheck && "Dependent-sized matrix canonical type broken"); + (void)CanonCheck; + DependentSizedMatrixTypes.InsertNode(New, InsertPos); + } else { + QualType Canon = getDependentSizedMatrixType( + CanonicalMatrixElementType, RowExpr, ColumnExpr, SourceLocation()); + New = new (*this, TypeAlignment) DependentSizedMatrixType( + *this, MatrixElementType, Canon, RowExpr, ColumnExpr, AttrLoc); + } + } + Types.push_back(New); + return QualType(New, 0); +} + QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType, Expr *AddrSpaceExpr, SourceLocation AttrLoc) const { @@ -7267,6 +7351,11 @@ *NotEncodedT = T; return; + case Type::Matrix: + if (NotEncodedT) + *NotEncodedT = T; + return; + // We could see an undeduced auto type here during error recovery. // Just ignore it. case Type::Auto: @@ -8092,6 +8181,15 @@ LHS->getNumElements() == RHS->getNumElements(); } +/// areCompatMatrixTypes - Return true if the two specified matrix types are +/// compatible. 
+static bool areCompatMatrixTypes(const MatrixType *LHS, const MatrixType *RHS) { + assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified()); + return LHS->getElementType() == RHS->getElementType() && + LHS->getNumRows() == RHS->getNumRows() && + LHS->getNumColumns() == RHS->getNumColumns(); +} + bool ASTContext::areCompatibleVectorTypes(QualType FirstVec, QualType SecondVec) { assert(FirstVec->isVectorType() && "FirstVec should be a vector type"); @@ -9288,6 +9386,11 @@ RHSCan->castAs())) return LHS; return {}; + case Type::Matrix: + if (areCompatMatrixTypes(LHSCan->castAs(), + RHSCan->castAs())) + return LHS; + return {}; case Type::ObjCObject: { // Check if the types are assignment compatible. // FIXME: This should be type compatibility, e.g. whether diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -617,6 +617,39 @@ break; } + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *Mat1 = cast(T1); + const DependentSizedMatrixType *Mat2 = cast(T2); + // Rows + if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(), + Mat2->getRowExpr())) { + return false; + } + // Columns + if (!IsStructurallyEquivalent(Context, Mat1->getColumnExpr(), + Mat2->getColumnExpr())) { + return false; + } + // Element Type + if (!IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType())) { + return false; + } + return true; + } + + case Type::Matrix: { + const MatrixType *Mat1 = cast(T1); + const MatrixType *Mat2 = cast(T2); + if (!IsStructurallyEquivalent(Context, Mat1->getElementType(), + Mat2->getElementType())) + return false; + if (Mat1->getNumRows() != Mat2->getNumRows() || + Mat1->getNumColumns() != Mat2->getNumColumns()) + return false; + break; + } + case Type::FunctionProto: { const auto *Proto1 = cast(T1); const auto *Proto2 = cast(T2); diff --git 
a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -10286,6 +10286,7 @@ case Type::BlockPointer: case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::ObjCObject: case Type::ObjCInterface: case Type::ObjCObjectPointer: diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2065,6 +2065,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::Matrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -3327,6 +3329,20 @@ mangleType(T->getElementType()); } +void CXXNameMangler::mangleType(const MatrixType *T) { + Out << "Dm" << T->getNumRows() << "_" << T->getNumColumns() << '_'; + mangleType(T->getElementType()); +} + +void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) { + Out << "Dm"; + mangleExpression(T->getRowExpr()); + Out << '_'; + mangleExpression(T->getColumnExpr()); + Out << '_'; + mangleType(T->getElementType()); +} + void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) { SplitQualType split = T->getPointeeType().split(); mangleQualifiers(split.Quals, T); diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -2755,6 +2755,23 @@ << Range; } +void MicrosoftCXXNameMangler::mangleType(const MatrixType *T, Qualifiers quals, + SourceRange Range) { + DiagnosticsEngine &Diags = Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, + "Cannot mangle this matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + +void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T, + Qualifiers quals, SourceRange Range) { + DiagnosticsEngine &Diags = 
Context.getDiags(); + unsigned DiagID = Diags.getCustomDiagID( + DiagnosticsEngine::Error, + "Cannot mangle this dependent-sized matrix type yet"); + Diags.Report(Range.getBegin(), DiagID) << Range; +} + void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T, Qualifiers, SourceRange Range) { DiagnosticsEngine &Diags = Context.getDiags(); diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -282,6 +282,45 @@ AddrSpaceExpr->Profile(ID, Context, true); } +MatrixType::MatrixType(QualType matrixType, unsigned nRows, unsigned nColumns, + QualType canonType) + : MatrixType(Matrix, matrixType, nRows, nColumns, canonType) {} + +MatrixType::MatrixType(TypeClass tc, QualType matrixType, unsigned nRows, + unsigned nColumns, QualType canonType) + : Type(tc, canonType, matrixType->getDependence()), + ElementType(matrixType) { + MatrixTypeBits.NumRows = nRows; + MatrixTypeBits.NumColumns = nColumns; +} + +DependentSizedMatrixType::DependentSizedMatrixType( + const ASTContext &CTX, QualType ElementType, QualType CanonicalType, + Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc) + : Type(DependentSizedMatrix, CanonicalType, + TypeDependence::Dependent | TypeDependence::Instantiation | + (ElementType->isVariablyModifiedType() + ? TypeDependence::VariablyModified + : TypeDependence::None) | + (ElementType->containsUnexpandedParameterPack() || + (RowExpr && + RowExpr->containsUnexpandedParameterPack()) || + (ColumnExpr && + ColumnExpr->containsUnexpandedParameterPack()) + ? 
TypeDependence::UnexpandedPack + : TypeDependence::None)), + Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), + ElementType(ElementType), loc(loc) {} + +void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID, + const ASTContext &CTX, + QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr) { + ID.AddPointer(ElementType.getAsOpaquePtr()); + RowExpr->Profile(ID, CTX, true); + ColumnExpr->Profile(ID, CTX, true); +} + VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType, VectorKind vecKind) : VectorType(Vector, vecType, nElements, canonType, vecKind) {} @@ -938,6 +977,16 @@ return Ctx.getExtVectorType(elementType, T->getNumElements()); } + QualType VisitMatrixType(const MatrixType *T) { + QualType elementType = recurse(T->getElementType()); + if (elementType.isNull()) + return {}; + if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr()) + return QualType(T, 0); + + return Ctx.getMatrixType(elementType, T->getNumRows(), T->getNumColumns()); + } + QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) { QualType returnType = recurse(T->getReturnType()); if (returnType.isNull()) @@ -1757,6 +1806,14 @@ return Visit(T->getElementType()); } + Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); + } + + Type *VisitMatrixType(const MatrixType *T) { + return Visit(T->getElementType()); + } + Type *VisitFunctionProtoType(const FunctionProtoType *T) { if (Syntactic && T->hasTrailingReturn()) return const_cast(T); @@ -3675,6 +3732,8 @@ case Type::Vector: case Type::ExtVector: return Cache::get(cast(T)->getElementType()); + case Type::Matrix: + return Cache::get(cast(T)->getElementType()); case Type::FunctionNoProto: return Cache::get(cast(T)->getReturnType()); case Type::FunctionProto: { @@ -3760,6 +3819,8 @@ case Type::Vector: case Type::ExtVector: return computeTypeLinkageInfo(cast(T)->getElementType()); + case Type::Matrix: + return 
computeTypeLinkageInfo(cast(T)->getElementType()); case Type::FunctionNoProto: return computeTypeLinkageInfo(cast(T)->getReturnType()); case Type::FunctionProto: { @@ -3921,6 +3982,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::Matrix: + case Type::DependentSizedMatrix: case Type::DependentAddressSpace: case Type::FunctionProto: case Type::FunctionNoProto: diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -254,6 +254,8 @@ case Type::DependentSizedExtVector: case Type::Vector: case Type::ExtVector: + case Type::Matrix: + case Type::DependentSizedMatrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Paren: @@ -718,6 +720,37 @@ OS << ")))"; } +void TypePrinter::printMatrixBefore(const MatrixType *T, raw_ostream &OS) { + // TODO: Fix the spacing between the element type and the __attribute__ + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + OS << T->getNumRows() << ", " << T->getNumColumns(); + OS << ")))"; +} + +void TypePrinter::printMatrixAfter(const MatrixType *T, raw_ostream &OS) { + printAfter(T->getElementType(), OS); +} + +void TypePrinter::printDependentSizedMatrixBefore( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printBefore(T->getElementType(), OS); + OS << " __attribute__((matrix_type("; + if (T->getRowExpr()) { + T->getRowExpr()->printPretty(OS, nullptr, Policy); + } + OS << ", "; + if (T->getColumnExpr()) { + T->getColumnExpr()->printPretty(OS, nullptr, Policy); + } + OS << ")))"; +} + +void TypePrinter::printDependentSizedMatrixAfter( + const DependentSizedMatrixType *T, raw_ostream &OS) { + printAfter(T->getElementType(), OS); +} + void FunctionProtoType::printExceptionSpecification(raw_ostream &OS, const PrintingPolicy &Policy) diff --git a/clang/lib/Basic/Targets/OSTargets.cpp b/clang/lib/Basic/Targets/OSTargets.cpp --- 
a/clang/lib/Basic/Targets/OSTargets.cpp +++ b/clang/lib/Basic/Targets/OSTargets.cpp @@ -133,6 +133,9 @@ Builder.defineMacro("__MACH__"); PlatformMinVersion = VersionTuple(Maj, Min, Rev); + + if (Opts.EnableMatrix) + Builder.defineMacro("__MATRIX_EXTENSION__", "1"); } static void addMinGWDefines(const llvm::Triple &Triple, const LangOptions &Opts, diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h --- a/clang/lib/CodeGen/CGDebugInfo.h +++ b/clang/lib/CodeGen/CGDebugInfo.h @@ -190,6 +190,7 @@ llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit); llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const MatrixType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -2704,6 +2704,23 @@ return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray); } +llvm::DIType *CGDebugInfo::CreateType(const MatrixType *Ty, + llvm::DIFile *Unit) { + llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit); + uint64_t Size = CGM.getContext().getTypeSize(Ty); + uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext()); + + // Number of Columns, followed by rows + llvm::SmallVector Subscripts; + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns())); + Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows())); + llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts); + + // FIXME: Create another debug type for matrices + // For the time being, it treats it like a 2D array + return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray); +} + llvm::DIType 
*CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) { uint64_t Size; uint32_t Align; @@ -3097,6 +3114,8 @@ case Type::ExtVector: case Type::Vector: return CreateType(cast(Ty), Unit); + case Type::Matrix: + return CreateType(cast(Ty), Unit); case Type::ObjCObjectPointer: return CreateType(cast(Ty), Unit); case Type::ObjCObject: diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -145,8 +145,19 @@ Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align, const Twine &Name, Address *Alloca) { - return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, - /*ArraySize=*/nullptr, Alloca); + Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name, + /*ArraySize=*/nullptr, Alloca); + + if (Ty->isMatrixType()) { + auto *ArrayTy = cast(Result.getType()->getElementType()); + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + Result = Address( + Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()), + Result.getAlignment()); + } + return Result; } Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align, @@ -1759,6 +1770,20 @@ } } + if (Ty->isMatrixType()) { + auto *ArrayTy = dyn_cast( + cast(Addr.getPointer()->getType()) + ->getElementType()); + if (ArrayTy) { + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + Addr = Address( + Builder.CreateBitCast(Addr.getPointer(), VectorTy->getPointerTo()), + Addr.getAlignment()); + } + } + Value = EmitToMemory(Value, Ty); LValue AtomicLValue = @@ -1812,6 +1837,20 @@ if (LV.isSimple()) { assert(!LV.getType()->isFunctionType()); + if (LV.getType()->isMatrixType()) { + auto *ArrayTy = dyn_cast( + cast(LV.getPointer(*this)->getType()) + ->getElementType()); + if (ArrayTy) { + auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(), + ArrayTy->getNumElements()); + + 
LV.setAddress(Address(Builder.CreateBitCast(LV.getPointer(*this), + VectorTy->getPointerTo()), + LV.getAlignment())); + } + } + // Everything needs a load. return RValue::get(EmitLoadOfScalar(LV, Loc)); } diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -268,6 +268,7 @@ case Type::MemberPointer: case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::FunctionProto: case Type::FunctionNoProto: case Type::Enum: @@ -2019,6 +2020,7 @@ case Type::Complex: case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::Record: case Type::Enum: case Type::Elaborated: diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -84,6 +84,13 @@ /// a type. For example, the scalar representation for _Bool is i1, but the /// memory representation is usually i8 or i32, depending on the target. llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) { + if (T->isMatrixType()) { + const Type *Ty = Context.getCanonicalType(T).getTypePtr(); + const MatrixType *MT = cast(Ty); + return llvm::ArrayType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + } + llvm::Type *R = ConvertType(T); // If this is a non-bool type, don't map it. 
@@ -630,6 +637,12 @@ VT->getNumElements()); break; } + case Type::Matrix: { + const MatrixType *MT = cast(Ty); + ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()), + MT->getNumRows() * MT->getNumColumns()); + break; + } case Type::FunctionNoProto: case Type::FunctionProto: ResultType = ConvertFunctionTypeInternal(T); diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3222,6 +3222,7 @@ // GCC treats vector and complex types as fundamental types. case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::Complex: case Type::Atomic: // FIXME: GCC treats block pointers as fundamental types?! @@ -3457,6 +3458,7 @@ case Type::Builtin: case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::Complex: case Type::BlockPointer: // Itanium C++ ABI 2.9.5p4: diff --git a/clang/lib/Driver/ToolChains/Clang.cpp b/clang/lib/Driver/ToolChains/Clang.cpp --- a/clang/lib/Driver/ToolChains/Clang.cpp +++ b/clang/lib/Driver/ToolChains/Clang.cpp @@ -4553,6 +4553,13 @@ if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false)) CmdArgs.push_back("-fdefault-calling-conv=stdcall"); + if (Args.hasArg(options::OPT_fenable_matrix)) { + // enable-matrix is needed by both the LangOpts and by LLVM. 
+ CmdArgs.push_back("-fenable-matrix"); + CmdArgs.push_back("-mllvm"); + CmdArgs.push_back("-enable-matrix"); + } + CodeGenOptions::FramePointerKind FPKeepKind = getFramePointerKind(Args, RawTriple); const char *FPKeepKindStr = nullptr; diff --git a/clang/lib/Frontend/CompilerInvocation.cpp b/clang/lib/Frontend/CompilerInvocation.cpp --- a/clang/lib/Frontend/CompilerInvocation.cpp +++ b/clang/lib/Frontend/CompilerInvocation.cpp @@ -3346,6 +3346,8 @@ Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers); Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj); + Opts.EnableMatrix = Args.hasArg(OPT_fenable_matrix); + Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags); } @@ -3567,7 +3569,7 @@ InputArgList Args = Opts.ParseArgs(CommandLineArgs, MissingArgIndex, MissingArgCount, IncludedFlagsBitmask); LangOptions &LangOpts = *Res.getLangOpts(); - + // // Check for missing argument error. if (MissingArgCount) { Diags.Report(diag::err_drv_missing_argument) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4248,6 +4248,7 @@ case Type::Complex: case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::Record: case Type::Enum: case Type::Elaborated: diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -2966,6 +2966,7 @@ // These are fundamental types. 
case Type::Vector: case Type::ExtVector: + case Type::Matrix: case Type::Complex: break; diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -5829,6 +5829,11 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType( + const DependentSizedMatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType( const DependentAddressSpaceType *T) { return Visit(T->getPointeeType()); @@ -5847,6 +5852,10 @@ return Visit(T->getElementType()); } +bool UnnamedLocalNoLinkageFinder::VisitMatrixType(const MatrixType *T) { + return Visit(T->getElementType()); +} + bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType( const FunctionProtoType* T) { for (const auto &A : T->param_types()) { diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp --- a/clang/lib/Sema/SemaTemplateDeduction.cpp +++ b/clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2054,6 +2054,89 @@ return Sema::TDK_NonDeducedMismatch; } + // (clang extension) + // + // T __attribute__((matrix_type(, ))) + // TODO: Allow deduction from matrix type to vector type + // TODO: Decide on deduction from vector type to matrix type + case Type::Matrix: { + const MatrixType *MatrixParam = cast(Param); + // Matrix-DepSizedMatrix deduction + if (const DependentSizedMatrixType *MatrixArg = + dyn_cast(Arg)) { + // can't check number of elements since the argument is dependent + return DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF); + } + // Matrix-Matrix deduction + if (const MatrixType *MatrixArg = dyn_cast(Arg)) { + // Check that the dimensions are the same + if (MatrixParam->getNumRows() != MatrixArg->getNumRows() || + MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) { + return 
Sema::TDK_NonDeducedMismatch; + } + // Perform deduction on element types + return DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF); + } + return Sema::TDK_NonDeducedMismatch; + } + + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *MatrixParam = + cast(Param); + // DepSizedMatrix - DepSizedMatrix deduction + // DepSizedMatrix - Matrix deduction + if (const MatrixType *MatrixArg = dyn_cast(Arg)) { + // Do deduction on the element types + if (Sema::TemplateDeductionResult Result = + DeduceTemplateArgumentsByTypeMatch( + S, TemplateParams, MatrixParam->getElementType(), + MatrixArg->getElementType(), Info, Deduced, TDF)) { + return Result; + } + + // Deduce matrix size if possible + NonTypeTemplateParmDecl *RowExprTemplateParam = + getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr()); + NonTypeTemplateParmDecl *ColumnExprTemplateParam = + getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr()); + + // TODO: Allow one to fail and the other to succeed in the deduction + // Can't deduce either rows or columns, just say everything is fine + if (!RowExprTemplateParam || !ColumnExprTemplateParam) { + return Sema::TDK_Success; + } + + // Unsigned might make more sense + llvm::APSInt ArgRows(S.Context.getTypeSize(S.Context.IntTy)); + ArgRows = MatrixArg->getNumRows(); + + // Deduce Rows + { + Sema::TemplateDeductionResult Res = DeduceNonTypeTemplateArgument( + S, TemplateParams, RowExprTemplateParam, ArgRows, S.Context.IntTy, + true, Info, Deduced); + if (Res != Sema::TDK_Success) { + return Res; + } + } + + // Deduce Columns + llvm::APSInt ArgColumns(S.Context.getTypeSize(S.Context.IntTy)); + ArgColumns = MatrixArg->getNumColumns(); + + // Deduce columns + return DeduceNonTypeTemplateArgument( + S, TemplateParams, ColumnExprTemplateParam, ArgColumns, + S.Context.IntTy, true, Info, Deduced); + } + return Sema::TDK_NonDeducedMismatch; + } + // (clang 
extension) // // T __attribute__(((address_space(N)))) @@ -5695,6 +5778,24 @@ break; } + case Type::Matrix: { + const MatrixType *MatType = cast(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth, Used); + break; + } + + case Type::DependentSizedMatrix: { + const DependentSizedMatrixType *MatType = cast(T); + MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced, + Depth, Used); + MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth, + Used); + MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced, + Depth, Used); + break; + } + case Type::FunctionProto: { const FunctionProtoType *Proto = cast(T); MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth, diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -2505,6 +2505,101 @@ return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc); } +/// \brief Build a Matrix Type +/// +/// Run the required checks for the matrix type +QualType Sema::BuildMatrixType(QualType T, Expr *NumRows, Expr *NumCols, + SourceLocation AttrLoc) { + assert(Context.getLangOpts().EnableMatrix && + "Should never build a matrix type when it is disabled"); + + if (NumRows->isTypeDependent() || NumCols->isTypeDependent() || + NumRows->isValueDependent() || NumCols->isValueDependent()) { + return Context.getDependentSizedMatrixType(T, NumRows, NumCols, AttrLoc); + } + + unsigned MatrixRows = 0; + unsigned MatrixColumns = 0; + + { // Handle parameter error checking + // Invalid matrix type (must be float or integer) + if (!(T->isIntegerType() || T->isRealFloatingType() || + T->isDependentType())) { + Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << T; + return QualType(); + } + + // Should this be kept at 32bit even though we're deprecating it? 
+ llvm::APSInt ValueRows(32), ValueColumns(32); + + bool const RowsIsInteger = + NumRows->isIntegerConstantExpr(ValueRows, Context); + bool const ColumnsIsInteger = + NumCols->isIntegerConstantExpr(ValueColumns, Context); + + auto const RowRange = NumRows->getSourceRange(); + auto const ColRange = NumCols->getSourceRange(); + + // Both are invalid types + if (!RowsIsInteger && !ColumnsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange + << ColRange; + return QualType(); + } + + // One or the other are invalid + if (!RowsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange; + return QualType(); + } + + // Getting the wrong source range + if (!ColumnsIsInteger) { + Diag(AttrLoc, diag::err_attribute_argument_type) + << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange; + return QualType(); + } + + MatrixRows = static_cast(ValueRows.getZExtValue()); + MatrixColumns = static_cast(ValueColumns.getZExtValue()); + + // Check Matrix size + if (MatrixRows == 0 && MatrixColumns == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) + << "matrix" << RowRange << ColRange; + return QualType(); + } + if (MatrixRows == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange; + return QualType(); + } + if (MatrixColumns == 0) { + Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange; + return QualType(); + } + + if (VectorType::isVectorSizeTooLarge(MatrixRows) && + VectorType::isVectorSizeTooLarge(MatrixColumns)) { + Diag(AttrLoc, diag::err_attribute_size_too_large) + << "matrix" << RowRange << ColRange; + return QualType(); + } + + if (VectorType::isVectorSizeTooLarge(MatrixRows)) { + Diag(AttrLoc, diag::err_attribute_size_too_large) << "matrix" << RowRange; + return QualType(); + } + + if (VectorType::isVectorSizeTooLarge(MatrixColumns)) { + Diag(AttrLoc, diag::err_attribute_size_too_large) 
<< "matrix" << ColRange; + return QualType(); + } + } + return Context.getMatrixType(T, MatrixRows, MatrixColumns); +} + bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) { if (T->isArrayType() || T->isFunctionType()) { Diag(Loc, diag::err_func_returning_array_function) @@ -7632,6 +7727,71 @@ } } +/// HandleMatrixTypeAttr - "matrix_type" attribute, like ext_vector_type +static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr, + Sema &S) { + if (!S.getLangOpts().EnableMatrix) { + S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled); + return; + } + + if (Attr.getNumArgs() != 2) { + S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments) + << Attr << 2; + return; + } + + Expr *rowsExpr = nullptr; + Expr *colsExpr = nullptr; + + // TODO: Refactor parameter extraction into separate function + // Get the number of rows + if (Attr.isArgIdent(0)) { + CXXScopeSpec SS; + SourceLocation TemplateKeywordLoc; + UnqualifiedId id; + id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc()); + ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS, + TemplateKeywordLoc, id, false, false); + + if (Rows.isInvalid()) { + // TODO: maybe a good error message would be nice here + return; + } + rowsExpr = Rows.get(); + } else { + assert(Attr.isArgExpr(0) && + "Argument to should either be an identity or expression"); + rowsExpr = Attr.getArgAsExpr(0); + } + + // Get the number of columns + if (Attr.isArgIdent(1)) { + CXXScopeSpec SS; + SourceLocation TemplateKeywordLoc; + UnqualifiedId id; + id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc()); + ExprResult Columns = S.ActOnIdExpression( + S.getCurScope(), SS, TemplateKeywordLoc, id, false, false); + + if (Columns.isInvalid()) { + // TODO: a good error message would be nice here + return; + } + colsExpr = Columns.get(); + } else { + assert(Attr.isArgExpr(1) && + "Argument to should either be an identity or expression"); + colsExpr = Attr.getArgAsExpr(1); + } + + // Create 
Matrix Type + QualType T = S.BuildMatrixType(CurType, rowsExpr, colsExpr, Attr.getLoc()); + if (!T.isNull()) { + CurType = T; + } +} + static void HandleLifetimeBoundAttr(TypeProcessingState &State, QualType &CurType, ParsedAttr &Attr) { @@ -7783,6 +7943,11 @@ break; } + case ParsedAttr::AT_MatrixType: + HandleMatrixTypeAttr(type, attr, state.getSema()); + attr.setUsedAsTypeAttr(); + break; + MS_TYPE_ATTRS_CASELIST: if (!handleMSPointerTypeQualifierAttr(state, attr, type)) attr.setUsedAsTypeAttr(); diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -894,6 +894,16 @@ Expr *SizeExpr, SourceLocation AttributeLoc); + /// Build a new matrix type given the element type and dimensions. + QualType RebuildMatrixType(QualType ElementType, unsigned NumRows, + unsigned NumColumns); + + /// Build a new matrix type given the type and dependently-defined + /// dimensions. + QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr, + Expr *ColumnExpr, + SourceLocation AttributeLoc); + /// Build a new DependentAddressSpaceType or return the pointee /// type variable with the correct address space (retrieved from /// AddrSpaceExpr) applied to it. 
The former will be returned in cases @@ -5136,6 +5146,65 @@ return Result; } +template +QualType TreeTransform::TransformMatrixType(TypeLocBuilder &TLB, + MatrixTypeLoc TL) { + const MatrixType *T = TL.getTypePtr(); + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) + return QualType(); + + QualType Result = TL.getType(); + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) { + Result = getDerived().RebuildMatrixType(ElementType, T->getNumRows(), + T->getNumColumns()); + if (Result.isNull()) + return QualType(); + } + + MatrixTypeLoc NewTL = TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + + return Result; +} + +template +QualType TreeTransform::TransformDependentSizedMatrixType( + TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) { + const DependentSizedMatrixType *T = TL.getTypePtr(); + + QualType ElementType = getDerived().TransformType(T->getElementType()); + if (ElementType.isNull()) + return QualType(); + + EnterExpressionEvaluationContext Unevaluated( + SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated); + ExprResult Rows = getDerived().TransformExpr(T->getRowExpr()); + ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr()); + if (Rows.isInvalid() || Cols.isInvalid()) + return QualType(); + QualType Result = TL.getType(); + // TODO: Finish this + if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() || + Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) { + Result = getDerived().RebuildDependentSizedMatrixType( + ElementType, Rows.get(), Cols.get(), T->getAttributeLoc()); + + if (Result.isNull()) + return QualType(); + } + + if (isa(Result)) { + DependentSizedMatrixTypeLoc NewTL = + TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + } else { + MatrixTypeLoc NewTL = TLB.push(Result); + NewTL.setNameLoc(TL.getNameLoc()); + } + return Result; +} + template QualType TreeTransform::TransformDependentAddressSpaceType( TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) { @@ 
-13546,6 +13615,21 @@ return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc); } +template +QualType TreeTransform::RebuildMatrixType(QualType ElementType, + unsigned NumRows, + unsigned NumColumns) { + return SemaRef.Context.getMatrixType(ElementType, NumRows, NumColumns); +} + +template +QualType TreeTransform::RebuildDependentSizedMatrixType( + QualType ElementType, Expr *RowExpr, Expr *ColumnExpr, + SourceLocation AttributeLoc) { + return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr, + AttributeLoc); +} + template QualType TreeTransform::RebuildFunctionProtoType( QualType T, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -6525,6 +6525,15 @@ TL.setNameLoc(readSourceLocation()); } +void TypeLocReader::VisitMatrixTypeLoc(MatrixTypeLoc TL) { + TL.setNameLoc(readSourceLocation()); +} + +void TypeLocReader::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + TL.setNameLoc(readSourceLocation()); +} + void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) { TL.setLocalRangeBegin(readSourceLocation()); TL.setLParenLoc(readSourceLocation()); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -288,6 +288,15 @@ Record.AddSourceLocation(TL.getNameLoc()); } +void TypeLocWriter::VisitMatrixTypeLoc(MatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getNameLoc()); +} + +void TypeLocWriter::VisitDependentSizedMatrixTypeLoc( + DependentSizedMatrixTypeLoc TL) { + Record.AddSourceLocation(TL.getNameLoc()); +} + void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) { Record.AddSourceLocation(TL.getLocalRangeBegin()); Record.AddSourceLocation(TL.getLParenLoc()); diff --git a/clang/test/CodeGen/matrix-type.c b/clang/test/CodeGen/matrix-type.c new file mode 100644 
--- /dev/null +++ b/clang/test/CodeGen/matrix-type.c @@ -0,0 +1,79 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +typedef float fx3x4_t __attribute__((matrix_type(3, 4))); + +// CHECK: %struct.Matrix = type { i8, [12 x float], float } + +void load_store(dx5x5_t *a, dx5x5_t *b) { + // CHECK-LABEL: define void @load_store( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8 + // CHECK-NEXT: ret void + + *a = *b; +} + +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); + +void parameter_passing(fx3x3_t a, fx3x3_t *b) { + // CHECK-LABEL: define void @parameter_passing( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4 + // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4 + // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8 + // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>* + // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4 + // CHECK-NEXT: ret 
void + *b = a; +} + +fx3x3_t return_matrix(fx3x3_t *a) { + // CHECK-LABEL: define <9 x float> @return_matrix + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8 + // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>* + // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: ret <9 x float> %2 + return *a; +} + +typedef struct { + char Tmp1; + fx3x4_t Data; + float Tmp2; +} Matrix; + +void matrix_struct(Matrix *a, Matrix *b) { + // CHECK-LABEL: define void @matrix_struct( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8 + // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8 + // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1 + // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>* + // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8 + // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1 + // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>* + // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4 + // CHECK-NEXT: ret void + b->Data = a->Data; +} diff --git a/clang/test/CodeGenCXX/matrix-type.cpp b/clang/test/CodeGenCXX/matrix-type.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCXX/matrix-type.cpp @@ -0,0 +1,176 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx5x5_t 
__attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+  // CHECK-LABEL: define void @_Z10load_storePDm5_5_dS0_(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8
+  // CHECK-NEXT: ret void
+
+  *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+  // CHECK-LABEL: define void @_Z17parameter_passingDm3_3_fPS_(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [9 x float], align 4
+  // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+  // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4
+  // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4
+  // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+  // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>*
+  // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4
+  // CHECK-NEXT: ret void
+  *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+  // CHECK-LABEL: define <9 x float> @_Z13return_matrixPDm3_3_f(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8
+  // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>*
+  // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4
+  // CHECK-NEXT: ret <9 x float> %2
+  return *a;
+}
+
+struct Matrix {
+  char Tmp1;
+  fx3x4_t Data;
+  float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+  // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT: ret void
+  b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+  // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+  // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+  // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+  // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+  // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+  // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT: ret void
+  b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+  int Tmp1;
+  fx3x4_t Data;
+  long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+  // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT: %b.addr = alloca %class.MatrixClass*, align 8
+  // CHECK-NEXT: store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT: store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT: %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+  // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+  // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+  // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+  // CHECK-NEXT: %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+  // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+  // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+  // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+  // CHECK-NEXT: ret void
+  b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+  using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+  int Tmp1;
+  MatrixTy Data;
+  long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+  b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+  // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %Data.addr = alloca float*, align 8
+  // CHECK-NEXT: %Arg = alloca %class.MatrixClassTemplate, align 8
+  // CHECK-NEXT: store float* %Data, float** %Data.addr, align 8
+  // CHECK-NEXT: %0 = load float*, float** %Data.addr, align 8
+  // CHECK-NEXT: %1 = bitcast float* %0 to [150 x float]*
+  // CHECK-NEXT: %2 = bitcast [150 x float]* %1 to <150 x float>*
+  // CHECK-NEXT: %3 = load <150 x float>, <150 x float>* %2, align 4
+  // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+  // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT: store <150 x float> %3, <150 x float>* %4, align 4
+  // CHECK-NEXT: call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+  // CHECK-NEXT: ret void
+
+  // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT: %b.addr = alloca %class.MatrixClassTemplate*, align 8
+  // CHECK-NEXT: store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT: store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT: %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+  // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+  // CHECK-NEXT: %1 = bitcast [150 x float]* %Data to <150 x float>*
+  // CHECK-NEXT: %2 = load <150 x float>, <150 x float>* %1, align 4
+  // CHECK-NEXT: %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+  // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+  // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+  // CHECK-NEXT: store <150 x float> %2, <150 x float>* %4, align 4
+  // CHECK-NEXT: ret void
+
+  MatrixClassTemplate<float, 10, 15> Result, Arg;
+  Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+  matrix_template_reference(Arg, Result);
+  return Result;
+}
diff --git a/clang/test/SemaCXX/matrix-type.cpp b/clang/test/SemaCXX/matrix-type.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,53 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+  using matrix1_t = int __attribute__((matrix_type(Rows, 1))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix3_t = int __attribute__((matrix_type(C, C))); // expected-error{{matrix_type attribute requires an integer constant}}
+  using matrix4_t = int __attribute__((matrix_type(-1, 1))); // expected-error{{vector size too large}}
+  using matrix5_t = int __attribute__((matrix_type(1, -1))); // expected-error{{vector size too large}}
+  using matrix6_t = int __attribute__((matrix_type(0, 1))); // expected-error{{zero vector size}}
+  using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero vector size}}
+  using matrix7_t = int __attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}}
+}
+
+struct S1 {};
+
+void matrix_unsupported_element_type() {
+  using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}}
+  using matrix2_t = S1 __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'S1'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+  using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+  using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero vector size}}
+}
+
+void instantiate_template_3() {
+  matrix_template_3<1, 10>();
+  matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+  using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{vector size too large}}
+}
+
+void instantiate_template_4() {
+  matrix_template_4<2, 10>();
+  matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp
--- a/clang/tools/libclang/CIndex.cpp
+++ b/clang/tools/libclang/CIndex.cpp
@@ -1786,6 +1786,8 @@
 DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
 DEFAULT_TYPELOC_IMPL(Vector, Type)
 DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(Matrix, Type)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, Type)
 DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
 DEFAULT_TYPELOC_IMPL(Record, TagType)