diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -978,6 +978,7 @@ #include "clang/Basic/OpenCLImageTypes.def" CanQualType OCLSamplerTy, OCLEventTy, OCLClkEventTy; CanQualType OCLQueueTy, OCLReserveIDTy; + CanQualType IncompleteMatrixIdxTy; CanQualType OMPArraySectionTy, OMPArrayShapingTy, OMPIteratorTy; #define EXT_OPAQUE_TYPE(ExtType, Id, Ext) \ CanQualType Id##Ty; diff --git a/clang/include/clang/AST/BuiltinTypes.def b/clang/include/clang/AST/BuiltinTypes.def --- a/clang/include/clang/AST/BuiltinTypes.def +++ b/clang/include/clang/AST/BuiltinTypes.def @@ -310,6 +310,9 @@ // context. PLACEHOLDER_TYPE(ARCUnbridgedCast, ARCUnbridgedCastTy) +// A placeholder type for incomplete matrix index expressions. +PLACEHOLDER_TYPE(IncompleteMatrixIdx, IncompleteMatrixIdxTy) + // A placeholder type for OpenMP array sections. PLACEHOLDER_TYPE(OMPArraySection, OMPArraySectionTy) diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h --- a/clang/include/clang/AST/ComputeDependence.h +++ b/clang/include/clang/AST/ComputeDependence.h @@ -28,6 +28,7 @@ class UnaryOperator; class UnaryExprOrTypeTraitExpr; class ArraySubscriptExpr; +class MatrixSubscriptExpr; class CompoundLiteralExpr; class CastExpr; class BinaryOperator; @@ -108,6 +109,7 @@ ExprDependence computeDependence(UnaryOperator *E); ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E); ExprDependence computeDependence(ArraySubscriptExpr *E); +ExprDependence computeDependence(MatrixSubscriptExpr *E); ExprDependence computeDependence(CompoundLiteralExpr *E); ExprDependence computeDependence(CastExpr *E); ExprDependence computeDependence(BinaryOperator *E); diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h --- a/clang/include/clang/AST/Expr.h +++ b/clang/include/clang/AST/Expr.h @@ -2584,7 +2584,7 @@ : 
Expr(ArraySubscriptExprClass, t, VK, OK) { SubExprs[LHS] = lhs; SubExprs[RHS] = rhs; - ArraySubscriptExprBits.RBracketLoc = rbracketloc; + ArrayOrMatrixSubscriptExprBits.RBracketLoc = rbracketloc; setDependence(computeDependence(this)); } @@ -2621,10 +2621,10 @@ SourceLocation getEndLoc() const { return getRBracketLoc(); } SourceLocation getRBracketLoc() const { - return ArraySubscriptExprBits.RBracketLoc; + return ArrayOrMatrixSubscriptExprBits.RBracketLoc; } void setRBracketLoc(SourceLocation L) { - ArraySubscriptExprBits.RBracketLoc = L; + ArrayOrMatrixSubscriptExprBits.RBracketLoc = L; } SourceLocation getExprLoc() const LLVM_READONLY { @@ -2644,6 +2644,72 @@ } }; +/// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType +/// extension. +class MatrixSubscriptExpr : public Expr { + enum { BASE, ROW_IDX, COLUMN_IDX, END_EXPR }; + Stmt *SubExprs[END_EXPR]; + +public: + MatrixSubscriptExpr(Expr *Base, Expr *RowIdx, Expr *ColumnIdx, QualType T, + SourceLocation RBracketLoc) + : Expr(MatrixSubscriptExprClass, T, Base->getValueKind(), + OK_MatrixElement) { + SubExprs[BASE] = Base; + SubExprs[ROW_IDX] = RowIdx; + SubExprs[COLUMN_IDX] = ColumnIdx; + ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc; + setDependence(computeDependence(this)); + } + + /// Create an empty matrix subscript expression. 
+ explicit MatrixSubscriptExpr(EmptyShell Shell) + : Expr(MatrixSubscriptExprClass, Shell) {} + + Expr *getBase() { return cast(SubExprs[BASE]); } + const Expr *getBase() const { return cast(SubExprs[BASE]); } + void setBase(Expr *E) { SubExprs[BASE] = E; } + + Expr *getRowIdx() { return cast(SubExprs[ROW_IDX]); } + const Expr *getRowIdx() const { return cast(SubExprs[ROW_IDX]); } + void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; } + + Expr *getColumnIdx() { return cast_or_null(SubExprs[COLUMN_IDX]); } + const Expr *getColumnIdx() const { + return cast_or_null(SubExprs[COLUMN_IDX]); + } + void setColumnIdx(Expr *E) { SubExprs[COLUMN_IDX] = E; } + + SourceLocation getBeginLoc() const LLVM_READONLY { + return getBase()->getBeginLoc(); + } + + SourceLocation getEndLoc() const { return getRBracketLoc(); } + + SourceLocation getExprLoc() const LLVM_READONLY { + return getBase()->getExprLoc(); + } + + SourceLocation getRBracketLoc() const { + return ArrayOrMatrixSubscriptExprBits.RBracketLoc; + } + void setRBracketLoc(SourceLocation L) { + ArrayOrMatrixSubscriptExprBits.RBracketLoc = L; + } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == MatrixSubscriptExprClass; + } + + // Iterators + child_range children() { + return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR); + } + const_child_range children() const { + return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR); + } +}; + /// CallExpr - Represents a function call (C99 6.5.2.2, C++ [expr.call]). /// CallExpr itself represents a normal function call, e.g., "f(x, 2)", /// while its subclasses may represent alternative syntax that (semantically) diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -2588,6 +2588,7 @@ // over the children. 
DEF_TRAVERSE_STMT(AddrLabelExpr, {}) DEF_TRAVERSE_STMT(ArraySubscriptExpr, {}) +DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {}) DEF_TRAVERSE_STMT(OMPArraySectionExpr, {}) DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {}) DEF_TRAVERSE_STMT(OMPIteratorExpr, {}) diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h --- a/clang/include/clang/AST/Stmt.h +++ b/clang/include/clang/AST/Stmt.h @@ -445,8 +445,9 @@ unsigned IsType : 1; // true if operand is a type, false if an expression. }; - class ArraySubscriptExprBitfields { + class ArrayOrMatrixSubscriptExprBitfields { friend class ArraySubscriptExpr; + friend class MatrixSubscriptExpr; unsigned : NumExprBits; @@ -999,7 +1000,7 @@ CharacterLiteralBitfields CharacterLiteralBits; UnaryOperatorBitfields UnaryOperatorBits; UnaryExprOrTypeTraitExprBitfields UnaryExprOrTypeTraitExprBits; - ArraySubscriptExprBitfields ArraySubscriptExprBits; + ArrayOrMatrixSubscriptExprBitfields ArrayOrMatrixSubscriptExprBits; CallExprBitfields CallExprBits; MemberExprBitfields MemberExprBits; CastExprBitfields CastExprBits; diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2035,7 +2035,8 @@ bool isComplexIntegerType() const; // GCC _Complex integer type. bool isVectorType() const; // GCC vector type. bool isExtVectorType() const; // Extended vector type. - bool isConstantMatrixType() const; // Matrix type. + bool isMatrixType() const; // Matrix type. + bool isConstantMatrixType() const; // Constant matrix type. 
bool isDependentAddressSpaceType() const; // value-dependent address space qualifier bool isObjCObjectPointerType() const; // pointer to ObjC object bool isObjCRetainableType() const; // ObjC object or block pointer @@ -6745,6 +6746,10 @@ return isa(CanonicalType); } +inline bool Type::isMatrixType() const { + return isa(CanonicalType); +} + inline bool Type::isConstantMatrixType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10747,6 +10747,12 @@ def err_builtin_matrix_disabled: Error< "matrix types extension is disabled. Pass -fenable-matrix to enable it">; +def err_matrix_index_not_integer: Error< + "matrix %select{row|column}0 index is not an integer">; +def err_matrix_index_outside_range: Error< + "matrix %select{row|column}0 index is outside the allowed range [0, %1)">; +def err_matrix_incomplete_index: Error< + "single subscript expressions are not allowed for matrix values">; def err_preserve_field_info_not_field : Error< "__builtin_preserve_field_info argument %0 not a field access">; diff --git a/clang/include/clang/Basic/Specifiers.h b/clang/include/clang/Basic/Specifiers.h --- a/clang/include/clang/Basic/Specifiers.h +++ b/clang/include/clang/Basic/Specifiers.h @@ -154,7 +154,10 @@ /// An Objective-C array/dictionary subscripting which reads an /// object or writes at the subscripted array/dictionary element via /// Objective-C method calls. - OK_ObjCSubscript + OK_ObjCSubscript, + + /// A single matrix element of a matrix. + OK_MatrixElement }; /// The reason why a DeclRefExpr does not constitute an odr-use. 
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -69,6 +69,7 @@ def OffsetOfExpr : StmtNode; def UnaryExprOrTypeTraitExpr : StmtNode; def ArraySubscriptExpr : StmtNode; +def MatrixSubscriptExpr : StmtNode; def OMPArraySectionExpr : StmtNode; def OMPIteratorExpr : StmtNode; def CallExpr : StmtNode; diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -4903,6 +4903,9 @@ Expr *Idx, SourceLocation RLoc); ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc, Expr *Idx, SourceLocation RLoc); + ExprResult ActOnMatrixSubscriptExpr(Scope *S, Expr *Base, Expr *RowIdx, + Expr *ColumnIdx, SourceLocation RBLoc); + ExprResult ActOnOMPArraySectionExpr(Expr *Base, SourceLocation LBLoc, Expr *LowerBound, SourceLocation ColonLoc, Expr *Length, SourceLocation RBLoc); diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h --- a/clang/include/clang/Serialization/ASTBitCodes.h +++ b/clang/include/clang/Serialization/ASTBitCodes.h @@ -1057,7 +1057,10 @@ /// The placeholder type for OpenMP iterator expression. PREDEF_TYPE_OMP_ITERATOR = 71, - /// OpenCL image types with auto numeration + /// A placeholder type for incomplete matrix index operations. + PREDEF_TYPE_INCOMPLETE_MATRIX_IDX = 72, + + /// OpenCL image types with auto numeration #define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) \ PREDEF_TYPE_##Id##_ID, #include "clang/Basic/OpenCLImageTypes.def" @@ -1597,6 +1600,9 @@ /// An ArraySubscriptExpr record. EXPR_ARRAY_SUBSCRIPT, + /// A MatrixSubscriptExpr record. + EXPR_MATRIX_SUBSCRIPT, + /// A CallExpr record. 
EXPR_CALL, diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -1388,6 +1388,8 @@ InitBuiltinType(OMPArrayShapingTy, BuiltinType::OMPArrayShaping); InitBuiltinType(OMPIteratorTy, BuiltinType::OMPIterator); } + if (LangOpts.MatrixTypes) + InitBuiltinType(IncompleteMatrixIdxTy, BuiltinType::IncompleteMatrixIdx); // C99 6.2.5p11. FloatComplexTy = getComplexType(FloatTy); diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp --- a/clang/lib/AST/ComputeDependence.cpp +++ b/clang/lib/AST/ComputeDependence.cpp @@ -83,6 +83,12 @@ return E->getLHS()->getDependence() | E->getRHS()->getDependence(); } +ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) { + return E->getBase()->getDependence() | E->getRowIdx()->getDependence() | + (E->getColumnIdx() ? E->getColumnIdx()->getDependence() + : ExprDependence::None); +} + ExprDependence clang::computeDependence(CompoundLiteralExpr *E) { return toExprDependence(E->getTypeSourceInfo()->getType()->getDependence()) | turnTypeToValueDependence(E->getInitializer()->getDependence()); diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -3425,6 +3425,7 @@ case ParenExprClass: case ArraySubscriptExprClass: + case MatrixSubscriptExprClass: case OMPArraySectionExprClass: case OMPArrayShapingExprClass: case OMPIteratorExprClass: diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp --- a/clang/lib/AST/ExprClassification.cpp +++ b/clang/lib/AST/ExprClassification.cpp @@ -224,6 +224,10 @@ } return Cl::CL_LValue; + // Subscripting matrix types behaves like member accesses. + case Expr::MatrixSubscriptExprClass: + return ClassifyInternal(Ctx, cast(E)->getBase()); + // C++ [expr.prim.general]p3: The result is an lvalue if the entity is a // function or variable and a prvalue otherwise. 
case Expr::DeclRefExprClass: diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -14184,6 +14184,7 @@ case Expr::ImaginaryLiteralClass: case Expr::StringLiteralClass: case Expr::ArraySubscriptExprClass: + case Expr::MatrixSubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::OMPArrayShapingExprClass: case Expr::OMPIteratorExprClass: diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -4234,6 +4234,9 @@ break; } + case Expr::MatrixSubscriptExprClass: + llvm_unreachable("matrix subscript expressions not supported yet"); + case Expr::CompoundAssignOperatorClass: // fallthrough case Expr::BinaryOperatorClass: { const BinaryOperator *BO = cast(E); diff --git a/clang/lib/AST/NSAPI.cpp b/clang/lib/AST/NSAPI.cpp --- a/clang/lib/AST/NSAPI.cpp +++ b/clang/lib/AST/NSAPI.cpp @@ -482,6 +482,7 @@ case BuiltinType::Half: case BuiltinType::PseudoObject: case BuiltinType::BuiltinFn: + case BuiltinType::IncompleteMatrixIdx: case BuiltinType::OMPArraySection: case BuiltinType::OMPArrayShaping: case BuiltinType::OMPIterator: diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp --- a/clang/lib/AST/StmtPrinter.cpp +++ b/clang/lib/AST/StmtPrinter.cpp @@ -1337,6 +1337,16 @@ OS << "]"; } +void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) { + PrintExpr(Node->getBase()); + OS << "["; + PrintExpr(Node->getRowIdx()); + OS << "]"; + OS << "["; + PrintExpr(Node->getColumnIdx()); + OS << "]"; +} + void StmtPrinter::VisitOMPArraySectionExpr(OMPArraySectionExpr *Node) { PrintExpr(Node->getBase()); OS << "["; diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -1204,6 +1204,10 @@ VisitExpr(S); } +void 
StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) { + VisitExpr(S); +} + void StmtProfiler::VisitOMPArraySectionExpr(const OMPArraySectionExpr *S) { VisitExpr(S); } diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -162,6 +162,9 @@ case OK_VectorComponent: OS << " vectorcomponent"; break; + case OK_MatrixElement: + OS << " matrixelement"; + break; } } } diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -3025,6 +3025,8 @@ return "queue_t"; case OCLReserveID: return "reserve_id_t"; + case IncompleteMatrixIdx: + return ""; case OMPArraySection: return ""; case OMPArrayShaping: @@ -4045,6 +4047,7 @@ #include "clang/Basic/AArch64SVEACLETypes.def" case BuiltinType::BuiltinFn: case BuiltinType::NullPtr: + case BuiltinType::IncompleteMatrixIdx: case BuiltinType::OMPArraySection: case BuiltinType::OMPArrayShaping: case BuiltinType::OMPIterator: diff --git a/clang/lib/AST/TypeLoc.cpp b/clang/lib/AST/TypeLoc.cpp --- a/clang/lib/AST/TypeLoc.cpp +++ b/clang/lib/AST/TypeLoc.cpp @@ -403,6 +403,7 @@ case BuiltinType::Id: #include "clang/Basic/AArch64SVEACLETypes.def" case BuiltinType::BuiltinFn: + case BuiltinType::IncompleteMatrixIdx: case BuiltinType::OMPArraySection: case BuiltinType::OMPArrayShaping: case BuiltinType::OMPIterator: diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -1369,6 +1369,8 @@ return EmitUnaryOpLValue(cast(E)); case Expr::ArraySubscriptExprClass: return EmitArraySubscriptExpr(cast(E)); + case Expr::MatrixSubscriptExprClass: + return EmitMatrixSubscriptExpr(cast(E)); case Expr::OMPArraySectionExprClass: return EmitOMPArraySectionExpr(cast(E)); case Expr::ExtVectorElementExprClass: @@ -1894,6 +1896,8 @@ if (LV.isGlobalReg()) return EmitLoadOfGlobalRegLValue(LV); + 
assert(!LV.isMatrixElt() && + "loads of matrix element LValues should be handled elsewhere"); assert(LV.isBitField() && "Unknown LValue type!"); return EmitLoadOfBitfieldLValue(LV, Loc); } @@ -1999,6 +2003,19 @@ return RValue::get(Call); } +// Store the specified rvalue into the specified matrix element. +static void EmitStoreThroughMatrixEltLValue(RValue Src, LValue Dst, + CodeGenFunction &CGF) { + Address DstAddr = MaybeConvertMatrixAddress( + Address(Dst.getVectorPointer(), + CGF.getContext().getTypeAlignInChars(Dst.getType())), + CGF); + llvm::Value *Vec = CGF.Builder.CreateLoad(DstAddr); + Vec = CGF.Builder.CreateInsertElement(Vec, Src.getScalarVal(), + Dst.getVectorIdx(), "matins"); + + CGF.Builder.CreateStore(Vec, DstAddr, Dst.isVolatileQualified()); +} /// EmitStoreThroughLValue - Store the specified rvalue into the specified /// lvalue, where both are guaranteed to the have the same type, and that type @@ -2025,6 +2042,9 @@ if (Dst.isGlobalReg()) return EmitStoreThroughGlobalRegLValue(Src, Dst); + if (Dst.isMatrixElt()) + return EmitStoreThroughMatrixEltLValue(Src, Dst, *this); + assert(Dst.isBitField() && "Unknown LValue type"); return EmitStoreThroughBitfieldLValue(Src, Dst); } @@ -3755,6 +3775,25 @@ return LV; } +LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) { + LValue Base = EmitLValue(E->getBase()); + llvm::Value *RowIdx = EmitScalarExpr(E->getRowIdx()); + llvm::Value *ColIdx = EmitScalarExpr(E->getColumnIdx()); + unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), + ColIdx->getType()->getScalarSizeInBits()); + llvm::Type *IntTy = llvm::IntegerType::get(getLLVMContext(), MaxWidth); + RowIdx = Builder.CreateZExt(RowIdx, IntTy); + ColIdx = Builder.CreateZExt(ColIdx, IntTy); + llvm::Value *NumRows = Builder.getIntN( + MaxWidth, + E->getBase()->getType()->getAs()->getNumRows()); + llvm::Value *FinalIdx = + Builder.CreateAdd(Builder.CreateMul(ColIdx, NumRows), RowIdx); + return 
LValue::MakeMatrixElt(Base.getAddress(*this), FinalIdx, + E->getBase()->getType(), Base.getBaseInfo(), + TBAAAccessInfo()); +} + static Address emitOMPArraySectionBase(CodeGenFunction &CGF, const Expr *Base, LValueBaseInfo &BaseInfo, TBAAAccessInfo &TBAAInfo, diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsPowerPC.h" +#include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/Module.h" #include @@ -577,6 +578,7 @@ } Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E); + Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E); Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E); Value *VisitConvertVectorExpr(ConvertVectorExpr *E); Value *VisitMemberExpr(MemberExpr *E); @@ -1808,6 +1810,22 @@ return Builder.CreateExtractElement(Base, Idx, "vecext"); } +Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) { + TestAndClearIgnoreResultAssign(); + + // Emit the row index, the column index and then the matrix operand; the + // matrix element they select is extracted below. + Value *RowIdx = Visit(E->getRowIdx()); + Value *ColumnIdx = Visit(E->getColumnIdx()); + Value *Matrix = Visit(E->getBase()); + + // TODO: Should we emit bounds checks with SanitizerKind::ArrayBounds? + llvm::MatrixBuilder MB(Builder); + return MB.CreateExtractElement( + Matrix, RowIdx, ColumnIdx, + E->getBase()->getType()->getAs()->getNumRows()); +} + static int getMaskElt(llvm::ShuffleVectorInst *SVI, unsigned Idx, unsigned Off) { int MV = SVI->getMaskValue(Idx); diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h --- a/clang/lib/CodeGen/CGValue.h +++ b/clang/lib/CodeGen/CGValue.h @@ -170,7 +170,8 @@ VectorElt, // This is a vector element l-value (V[i]), use getVector* BitField, // This is a bitfield l-value, use getBitfield*. 
ExtVectorElt, // This is an extended vector subset, use getExtVectorComp - GlobalReg // This is a register l-value, use getGlobalReg() + GlobalReg, // This is a register l-value, use getGlobalReg() + MatrixElt // This is a matrix element, use getVector* } LVType; llvm::Value *V; @@ -254,6 +255,7 @@ bool isBitField() const { return LVType == BitField; } bool isExtVectorElt() const { return LVType == ExtVectorElt; } bool isGlobalReg() const { return LVType == GlobalReg; } + bool isMatrixElt() const { return LVType == MatrixElt; } bool isVolatileQualified() const { return Quals.hasVolatile(); } bool isRestrictQualified() const { return Quals.hasRestrict(); } @@ -337,8 +339,14 @@ Address getVectorAddress() const { return Address(getVectorPointer(), getAlignment()); } - llvm::Value *getVectorPointer() const { assert(isVectorElt()); return V; } - llvm::Value *getVectorIdx() const { assert(isVectorElt()); return VectorIdx; } + llvm::Value *getVectorPointer() const { + assert(isVectorElt() || isMatrixElt()); + return V; + } + llvm::Value *getVectorIdx() const { + assert(isVectorElt() || isMatrixElt()); + return VectorIdx; + } // extended vector elements. 
Address getExtVectorAddress() const { @@ -430,6 +438,18 @@ return R; } + static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx, + QualType type, LValueBaseInfo BaseInfo, + TBAAAccessInfo TBAAInfo) { + LValue R; + R.LVType = MatrixElt; + R.V = matAddress.getPointer(); + R.VectorIdx = Idx; + R.Initialize(type, type.getQualifiers(), matAddress.getAlignment(), + BaseInfo, TBAAInfo); + return R; + } + RValue asAggregateRValue(CodeGenFunction &CGF) const { return RValue::getAggregate(getAddress(CGF), isVolatileQualified()); } diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -3625,6 +3625,7 @@ LValue EmitUnaryOpLValue(const UnaryOperator *E); LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E, bool Accessed = false); + LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E); LValue EmitOMPArraySectionExpr(const OMPArraySectionExpr *E, bool IsLowerBound = true); LValue EmitExtVectorElementExpr(const ExtVectorElementExpr *E); diff --git a/clang/lib/Sema/SemaCast.cpp b/clang/lib/Sema/SemaCast.cpp --- a/clang/lib/Sema/SemaCast.cpp +++ b/clang/lib/Sema/SemaCast.cpp @@ -2089,6 +2089,7 @@ return TC_NotApplicable; // FIXME: Use a specific diagnostic for the rest of these cases. case OK_VectorComponent: inappropriate = "vector element"; break; + case OK_MatrixElement: inappropriate = "matrix element"; break; case OK_ObjCProperty: inappropriate = "property expression"; break; case OK_ObjCSubscript: inappropriate = "container subscripting expression"; break; diff --git a/clang/lib/Sema/SemaExceptionSpec.cpp b/clang/lib/Sema/SemaExceptionSpec.cpp --- a/clang/lib/Sema/SemaExceptionSpec.cpp +++ b/clang/lib/Sema/SemaExceptionSpec.cpp @@ -1299,6 +1299,7 @@ // Some might be dependent for other reasons. 
case Expr::ArraySubscriptExprClass: + case Expr::MatrixSubscriptExprClass: case Expr::OMPArraySectionExprClass: case Expr::OMPArrayShapingExprClass: case Expr::OMPIteratorExprClass: diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4546,6 +4546,13 @@ base = result.get(); } + auto *SubscriptE = dyn_cast(base); + if (SubscriptE) + return ActOnMatrixSubscriptExpr(S, SubscriptE->getBase(), + SubscriptE->getRowIdx(), idx, rbLoc); + if (base->getType()->isMatrixType()) + return ActOnMatrixSubscriptExpr(S, base, idx, nullptr, rbLoc); + // A comma-expression as the index is deprecated in C++2a onwards. if (getLangOpts().CPlusPlus20 && ((isa(idx) && cast(idx)->isCommaOp()) || @@ -4561,7 +4568,9 @@ // resolution for the operator overload should get the first crack // at the overload. bool IsMSPropertySubscript = false; - if (base->getType()->isNonOverloadPlaceholderType()) { + auto BaseTy = base->getType(); + if (BaseTy->isNonOverloadPlaceholderType() && + !BaseTy->isSpecificPlaceholderType(BuiltinType::IncompleteMatrixIdx)) { IsMSPropertySubscript = isMSPropertySubscriptExpr(*this, base); if (!IsMSPropertySubscript) { ExprResult result = CheckPlaceholderExpr(base); @@ -4621,6 +4630,47 @@ return Res; } +ExprResult Sema::ActOnMatrixSubscriptExpr(Scope *S, Expr *Base, Expr *RowIdx, + Expr *ColumnIdx, + SourceLocation RBLoc) { + // Check that IndexExpr is an integer expression. If it is a constant + // expression, check that it is less than Dim (= the number of elements in the + // corresponding dimension). 
+ auto IsIndexValid = [&](Expr *IndexExpr, unsigned Dim, bool IsColumnIdx) { + if (!IndexExpr->getType()->isIntegerType() && + !IndexExpr->isTypeDependent()) { + Diag(IndexExpr->getBeginLoc(), diag::err_matrix_index_not_integer) + << IsColumnIdx; + return false; + } + + llvm::APSInt Idx; + if (IndexExpr->isIntegerConstantExpr(Idx, Context) && + (Idx < 0 || Idx >= Dim)) { + Diag(IndexExpr->getBeginLoc(), diag::err_matrix_index_outside_range) + << IsColumnIdx << Dim; + return false; + } + return true; + }; + auto *MTy = Base->getType()->getAs(); + bool IsRowValid = IsIndexValid(RowIdx, MTy->getNumRows(), false); + + if (!ColumnIdx) { + if (!IsRowValid) + return ExprError(); + + return new (Context) MatrixSubscriptExpr( + Base, RowIdx, ColumnIdx, Context.IncompleteMatrixIdxTy, RBLoc); + } + + if (!IsRowValid || !IsIndexValid(ColumnIdx, MTy->getNumColumns(), true)) + return ExprError(); + + return new (Context) MatrixSubscriptExpr(Base, RowIdx, ColumnIdx, + MTy->getElementType(), RBLoc); +} + void Sema::CheckAddressOfNoDeref(const Expr *E) { ExpressionEvaluationContextRecord &LastRecord = ExprEvalContexts.back(); const Expr *StrippedExpr = E->IgnoreParenImpCasts(); @@ -5935,6 +5985,7 @@ // These are always invalid as call arguments and should be reported. case BuiltinType::BoundMember: case BuiltinType::BuiltinFn: + case BuiltinType::IncompleteMatrixIdx: case BuiltinType::OMPArraySection: case BuiltinType::OMPArrayShaping: case BuiltinType::OMPIterator: @@ -18864,6 +18915,11 @@ return ExprError(); } + case BuiltinType::IncompleteMatrixIdx: + Diag(cast(E)->getRowIdx()->getBeginLoc(), + diag::err_matrix_incomplete_index); + return ExprError(); + // Expressions of unknown type. 
case BuiltinType::OMPArraySection: Diag(E->getBeginLoc(), diag::err_omp_array_section_use); diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2410,6 +2410,17 @@ RBracketLoc); } + /// Build a new matrix subscript expression. + /// + /// By default, performs semantic analysis to build the new expression. + /// Subclasses may override this routine to provide different behavior. + ExprResult RebuildMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, + Expr *ColumnIdx, + SourceLocation RBracketLoc) { + return getSema().ActOnMatrixSubscriptExpr( + /*Scope=*/nullptr, Base, RowIdx, ColumnIdx, RBracketLoc); + } + /// Build a new array section expression. /// /// By default, performs semantic analysis to build the new expression. @@ -10253,6 +10264,29 @@ /*FIXME:*/ E->getLHS()->getBeginLoc(), RHS.get(), E->getRBracketLoc()); } +template +ExprResult +TreeTransform::TransformMatrixSubscriptExpr(MatrixSubscriptExpr *E) { + ExprResult Base = getDerived().TransformExpr(E->getBase()); + if (Base.isInvalid()) + return ExprError(); + + ExprResult RowIdx = getDerived().TransformExpr(E->getRowIdx()); + if (RowIdx.isInvalid()) + return ExprError(); + + ExprResult ColumnIdx = getDerived().TransformExpr(E->getColumnIdx()); + if (ColumnIdx.isInvalid()) + return ExprError(); + + if (!getDerived().AlwaysRebuild() && Base.get() == E->getBase() && + RowIdx.get() == E->getRowIdx() && ColumnIdx.get() == E->getColumnIdx()) + return E; + + return getDerived().RebuildMatrixSubscriptExpr( + Base.get(), RowIdx.get(), ColumnIdx.get(), E->getRBracketLoc()); +} + template ExprResult TreeTransform::TransformOMPArraySectionExpr(OMPArraySectionExpr *E) { diff --git a/clang/lib/Serialization/ASTCommon.cpp b/clang/lib/Serialization/ASTCommon.cpp --- a/clang/lib/Serialization/ASTCommon.cpp +++ b/clang/lib/Serialization/ASTCommon.cpp @@ -240,6 +240,9 @@ case BuiltinType::BuiltinFn: ID = PREDEF_TYPE_BUILTIN_FN; 
break; + case BuiltinType::IncompleteMatrixIdx: + ID = PREDEF_TYPE_INCOMPLETE_MATRIX_IDX; + break; case BuiltinType::OMPArraySection: ID = PREDEF_TYPE_OMP_ARRAY_SECTION; break; diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -7007,6 +7007,9 @@ case PREDEF_TYPE_BUILTIN_FN: T = Context.BuiltinFnTy; break; + case PREDEF_TYPE_INCOMPLETE_MATRIX_IDX: + T = Context.IncompleteMatrixIdxTy; + break; case PREDEF_TYPE_OMP_ARRAY_SECTION: T = Context.OMPArraySectionTy; break; diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -907,6 +907,14 @@ E->setRBracketLoc(readSourceLocation()); } +void ASTStmtReader::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) { + VisitExpr(E); + E->setBase(Record.readSubExpr()); + E->setRowIdx(Record.readSubExpr()); + E->setColumnIdx(Record.readSubExpr()); + E->setRBracketLoc(readSourceLocation()); +} + void ASTStmtReader::VisitOMPArraySectionExpr(OMPArraySectionExpr *E) { VisitExpr(E); E->setBase(Record.readSubExpr()); @@ -2926,6 +2934,10 @@ S = new (Context) ArraySubscriptExpr(Empty); break; + case EXPR_MATRIX_SUBSCRIPT: + S = new (Context) MatrixSubscriptExpr(Empty); + break; + case EXPR_OMP_ARRAY_SECTION: S = new (Context) OMPArraySectionExpr(Empty); break; diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -772,6 +772,15 @@ Code = serialization::EXPR_ARRAY_SUBSCRIPT; } +void ASTStmtWriter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) { + VisitExpr(E); + Record.AddStmt(E->getBase()); + Record.AddStmt(E->getRowIdx()); + Record.AddStmt(E->getColumnIdx()); + Record.AddSourceLocation(E->getRBracketLoc()); + Code = 
serialization::EXPR_MATRIX_SUBSCRIPT; +} + void ASTStmtWriter::VisitOMPArraySectionExpr(OMPArraySectionExpr *E) { VisitExpr(E); Record.AddStmt(E->getBase()); diff --git a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp --- a/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ b/clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1515,6 +1515,10 @@ Bldr.addNodes(Dst); break; + case Stmt::MatrixSubscriptExprClass: + llvm_unreachable("Support for MatrixSubscriptExpr is not implemented."); + break; + case Stmt::GCCAsmStmtClass: Bldr.takeNodes(Pred); VisitGCCAsmStmt(cast(S), Pred, Dst); diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -0,0 +1,304 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +// Tests for the matrix type operators. + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +typedef float fx2x3_t __attribute__((matrix_type(2, 3))); + +// Check that we can use matrix index expression on different floating point +// matrixes and indices. 
+void insert_fp(dx5x5_t a, double d, fx2x3_t b, float e, int j, unsigned k) { + // CHECK-LABEL: define void @insert_fp(<25 x double> %a, double %d, <6 x float> %b, float %e, i32 %j, i32 %k) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %d.addr = alloca double, align 8 + // CHECK-NEXT: %b.addr = alloca [6 x float], align 4 + // CHECK-NEXT: %e.addr = alloca float, align 4 + // CHECK-NEXT: %j.addr = alloca i32, align 4 + // CHECK-NEXT: %k.addr = alloca i32, align 4 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: store double %d, double* %d.addr, align 8 + // CHECK-NEXT: %1 = bitcast [6 x float]* %b.addr to <6 x float>* + // CHECK-NEXT: store <6 x float> %b, <6 x float>* %1, align 4 + // CHECK-NEXT: store float %e, float* %e.addr, align 4 + // CHECK-NEXT: store i32 %j, i32* %j.addr, align 4 + // CHECK-NEXT: store i32 %k, i32* %k.addr, align 4 + // CHECK-NEXT: %2 = load double, double* %d.addr, align 8 + // CHECK-NEXT: %3 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %matins = insertelement <25 x double> %3, double %2, i64 5 + // CHECK-NEXT: store <25 x double> %matins, <25 x double>* %0, align 8 + a[0ll][1u] = d; + + // CHECK-NEXT: %4 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %5 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %matins1 = insertelement <6 x float> %5, float %4, i32 1 + // CHECK-NEXT: store <6 x float> %matins1, <6 x float>* %1, align 4 + b[1][0] = e; + + // CHECK-NEXT: %6 = load double, double* %d.addr, align 8 + // CHECK-NEXT: %7 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %matins2 = insertelement <25 x double> %7, double %6, i32 1 + // CHECK-NEXT: store <25 x double> %matins2, <25 x double>* %0, align 8 + a[1][0u] = d; + + // CHECK-NEXT: %8 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %9 = load <6 x float>, <6 x 
float>* %1, align 4 + // CHECK-NEXT: %matins3 = insertelement <6 x float> %9, float %8, i64 3 + // CHECK-NEXT: store <6 x float> %matins3, <6 x float>* %1, align 4 + b[1ull][1] = e; + + // CHECK-NEXT: %10 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %11 = load i32, i32* %j.addr, align 4 + // CHECK-NEXT: %12 = load i32, i32* %k.addr, align 4 + // CHECK-NEXT: %13 = mul i32 %12, 2 + // CHECK-NEXT: %14 = add i32 %13, %11 + // CHECK-NEXT: %15 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %matins4 = insertelement <6 x float> %15, float %10, i32 %14 + // CHECK-NEXT: store <6 x float> %matins4, <6 x float>* %1, align 4 + // CHECK-NEXT: ret void + b[j][k] = e; +} + +// Check that we can use matrix index expressions on integer matrixes. +typedef int ix9x3_t __attribute__((matrix_type(9, 3))); +void insert_int(ix9x3_t a, int i) { + // CHECK-LABEL: define void @insert_int(<27 x i32> %a, i32 %i) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %i.addr = alloca i32, align 4 + // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4 + // CHECK-NEXT: store i32 %i, i32* %i.addr, align 4 + // CHECK-NEXT: %1 = load i32, i32* %i.addr, align 4 + // CHECK-NEXT: %2 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %matins = insertelement <27 x i32> %2, i32 %1, i32 13 + // CHECK-NEXT: store <27 x i32> %matins, <27 x i32>* %0, align 4 + // CHECK-NEXT: ret void + + a[4u][1u] = i; +} + +// Check that we can use matrix index expressions on FP and integer +// matrixes.
+typedef int ix9x3_t __attribute__((matrix_type(9, 3))); +void insert_int_fp(ix9x3_t *a, int i, fx2x3_t b, float e, short j, unsigned long long k) { + // CHECK-LABEL: define void @insert_int_fp([27 x i32]* %a, i32 %i, <6 x float> %b, float %e, i16 signext %j, i64 %k) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32]*, align 8 + // CHECK-NEXT: %i.addr = alloca i32, align 4 + // CHECK-NEXT: %b.addr = alloca [6 x float], align 4 + // CHECK-NEXT: %e.addr = alloca float, align 4 + // CHECK-NEXT: %j.addr = alloca i16, align 2 + // CHECK-NEXT: %k.addr = alloca i64, align 8 + // CHECK-NEXT: store [27 x i32]* %a, [27 x i32]** %a.addr, align 8 + // CHECK-NEXT: store i32 %i, i32* %i.addr, align 4 + // CHECK-NEXT: %0 = bitcast [6 x float]* %b.addr to <6 x float>* + // CHECK-NEXT: store <6 x float> %b, <6 x float>* %0, align 4 + // CHECK-NEXT: store float %e, float* %e.addr, align 4 + // CHECK-NEXT: store i16 %j, i16* %j.addr, align 2 + // CHECK-NEXT: store i64 %k, i64* %k.addr, align 8 + // CHECK-NEXT: %1 = load i32, i32* %i.addr, align 4 + // CHECK-NEXT: %2 = load [27 x i32]*, [27 x i32]** %a.addr, align 8 + // CHECK-NEXT: %3 = bitcast [27 x i32]* %2 to <27 x i32>* + // CHECK-NEXT: %4 = load <27 x i32>, <27 x i32>* %3, align 4 + // CHECK-NEXT: %matins = insertelement <27 x i32> %4, i32 %1, i32 13 + // CHECK-NEXT: store <27 x i32> %matins, <27 x i32>* %3, align 4 + (*a)[4u][1u] = i; + + // CHECK-NEXT: %5 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %6 = load <6 x float>, <6 x float>* %0, align 4 + // CHECK-NEXT: %matins1 = insertelement <6 x float> %6, float %5, i32 3 + // CHECK-NEXT: store <6 x float> %matins1, <6 x float>* %0, align 4 + b[1u][1u] = e; + + // CHECK-NEXT: %7 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %8 = load i16, i16* %j.addr, align 2 + // CHECK-NEXT: %9 = load i64, i64* %k.addr, align 8 + // CHECK-NEXT: %10 = zext i16 %8 to i64 + // CHECK-NEXT: %11 = mul i64 %9, 2 + // CHECK-NEXT: %12 = add i64 %11, %10 + // 
CHECK-NEXT: %13 = load <6 x float>, <6 x float>* %0, align 4 + // CHECK-NEXT: %matins2 = insertelement <6 x float> %13, float %7, i64 %12 + // CHECK-NEXT: store <6 x float> %matins2, <6 x float>* %0, align 4 + // CHECK-NEXT: ret void + b[j][k] = e; +} + +// Check that we can use overloaded matrix index expressions on matrixes with +// matching dimensions, but different element types. +typedef double dx3x3_t __attribute__((matrix_type(3, 3))); +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); +void insert_matching_dimensions(dx3x3_t a, double i, fx3x3_t b, float e, long int j, char k) { + // CHECK-LABEL: define void @insert_matching_dimensions(<9 x double> %a, double %i, <9 x float> %b, float %e, i64 %j, i8 signext %k) #3 { + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [9 x double], align 8 + // CHECK-NEXT: %i.addr = alloca double, align 8 + // CHECK-NEXT: %b.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %e.addr = alloca float, align 4 + // CHECK-NEXT: %j.addr = alloca i64, align 8 + // CHECK-NEXT: %k.addr = alloca i8, align 1 + // CHECK-NEXT: %0 = bitcast [9 x double]* %a.addr to <9 x double>* + // CHECK-NEXT: store <9 x double> %a, <9 x double>* %0, align 8 + // CHECK-NEXT: store double %i, double* %i.addr, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %b.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %b, <9 x float>* %1, align 4 + // CHECK-NEXT: store float %e, float* %e.addr, align 4 + // CHECK-NEXT: store i64 %j, i64* %j.addr, align 8 + // CHECK-NEXT: store i8 %k, i8* %k.addr, align 1 + // CHECK-NEXT: %2 = load double, double* %i.addr, align 8 + // CHECK-NEXT: %3 = load <9 x double>, <9 x double>* %0, align 8 + // CHECK-NEXT: %matins = insertelement <9 x double> %3, double %2, i32 5 + // CHECK-NEXT: store <9 x double> %matins, <9 x double>* %0, align 8 + a[2u][1u] = i; + + // CHECK-NEXT: %4 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %5 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matins1 = 
insertelement <9 x float> %5, float %4, i32 7 + // CHECK-NEXT: store <9 x float> %matins1, <9 x float>* %1, align 4 + b[1u][2u] = e; + + // CHECK-NEXT: %6 = load double, double* %i.addr, align 8 + // CHECK-NEXT: %conv = fptrunc double %6 to float + // CHECK-NEXT: %7 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %8 = load i8, i8* %k.addr, align 1 + // CHECK-NEXT: %9 = zext i8 %8 to i64 + // CHECK-NEXT: %10 = mul i64 %9, 3 + // CHECK-NEXT: %11 = add i64 %10, %7 + // CHECK-NEXT: %12 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matins2 = insertelement <9 x float> %12, float %conv, i64 %11 + // CHECK-NEXT: store <9 x float> %matins2, <9 x float>* %1, align 4 + // CHECK-NEXT: ret void + b[j][k] = i; +} + +void extract1(dx5x5_t a, fx3x3_t b, ix9x3_t c, unsigned long j) { + // CHECK-LABEL: @extract1( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %c.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %j.addr = alloca i64, align 8 + // CHECK-NEXT: %v1 = alloca double, align 8 + // CHECK-NEXT: %v2 = alloca float, align 4 + // CHECK-NEXT: %v3 = alloca i32, align 4 + // CHECK-NEXT: %v4 = alloca i32, align 4 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %b.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %b, <9 x float>* %1, align 4 + double v1 = a[2][3]; + + // CHECK-NEXT: %2 = bitcast [27 x i32]* %c.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %c, <27 x i32>* %2, align 4 + // CHECK-NEXT: store i64 %j, i64* %j.addr, align 8 + // CHECK-NEXT: %3 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %matext = extractelement <25 x double> %3, i32 17 + // CHECK-NEXT: store double %matext, double* %v1, align 8 + float v2 = b[2][1]; + + // CHECK-NEXT: %4 = load <9 x float>, <9 x float>* %1, align 
4 + // CHECK-NEXT: %matext1 = extractelement <9 x float> %4, i32 5 + // CHECK-NEXT: store float %matext1, float* %v2, align 4 + // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %2, align 4 + // CHECK-NEXT: %matext2 = extractelement <27 x i32> %5, i32 10 + // CHECK-NEXT: store i32 %matext2, i32* %v3, align 4 + int v3 = c[1][1]; + + // CHECK-NEXT: %6 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %7 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %8 = load <27 x i32>, <27 x i32>* %2, align 4 + // CHECK-NEXT: %9 = mul i64 %7, 9 + // CHECK-NEXT: %10 = add i64 %9, %6 + // CHECK-NEXT: %matext3 = extractelement <27 x i32> %8, i64 %10 + // CHECK-NEXT: store i32 %matext3, i32* %v4, align 4 + // CHECK-NEXT: ret void + int v4 = c[j][j]; +} + +typedef double dx3x2_t __attribute__((matrix_type(3, 2))); +double test_extract_matrix_pointer(dx5x5_t *ptr, dx3x2_t **ptr2) { + // CHECK-LABEL: define double @test_extract_matrix_pointer([25 x double]* %ptr, [6 x double]** %ptr2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %ptr.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %ptr2.addr = alloca [6 x double]**, align 8 + // CHECK-NEXT: store [25 x double]* %ptr, [25 x double]** %ptr.addr, align 8 + // CHECK-NEXT: store [6 x double]** %ptr2, [6 x double]*** %ptr2.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %ptr.addr, align 8 + // CHECK-NEXT: %arrayidx = getelementptr inbounds [25 x double], [25 x double]* %0, i64 0 + // CHECK-NEXT: %1 = bitcast [25 x double]* %arrayidx to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %matext = extractelement <25 x double> %2, i32 17 + // CHECK-NEXT: %3 = load [6 x double]**, [6 x double]*** %ptr2.addr, align 8 + // CHECK-NEXT: %arrayidx1 = getelementptr inbounds [6 x double]*, [6 x double]** %3, i64 1 + // CHECK-NEXT: %4 = load [6 x double]*, [6 x double]** %arrayidx1, align 8 + // CHECK-NEXT: %arrayidx2 = getelementptr inbounds [6 x double], [6 x double]* %4, 
i64 2 + // CHECK-NEXT: %5 = bitcast [6 x double]* %arrayidx2 to <6 x double>* + // CHECK-NEXT: %6 = load <6 x double>, <6 x double>* %5, align 8 + // CHECK-NEXT: %matext3 = extractelement <6 x double> %6, i32 3 + // CHECK-NEXT: %add = fadd double %matext, %matext3 + // CHECK-NEXT: %7 = load [6 x double]**, [6 x double]*** %ptr2.addr, align 8 + // CHECK-NEXT: %add.ptr = getelementptr inbounds [6 x double]*, [6 x double]** %7, i64 4 + // CHECK-NEXT: %8 = load [6 x double]*, [6 x double]** %add.ptr, align 8 + // CHECK-NEXT: %add.ptr4 = getelementptr inbounds [6 x double], [6 x double]* %8, i64 6 + // CHECK-NEXT: %9 = bitcast [6 x double]* %add.ptr4 to <6 x double>* + // CHECK-NEXT: %10 = load <6 x double>, <6 x double>* %9, align 8 + // CHECK-NEXT: %matext5 = extractelement <6 x double> %10, i32 1 + // CHECK-NEXT: %add6 = fadd double %add, %matext5 + // CHECK-NEXT: ret double %add6 + + return (ptr[0])[2][3] + ptr2[1][2][0][1] + (*(*(ptr2 + 4) + 6))[1][0]; +} +void insert_extract(dx5x5_t a, fx3x3_t b, unsigned long j, short k) { + // CHECK-LABEL: define void @insert_extract(<25 x double> %a, <9 x float> %b, i64 %j, i16 signext %k) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %j.addr = alloca i64, align 8 + // CHECK-NEXT: %k.addr = alloca i16, align 2 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %b.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %b, <9 x float>* %1, align 4 + // CHECK-NEXT: store i64 %j, i64* %j.addr, align 8 + // CHECK-NEXT: store i16 %k, i16* %k.addr, align 2 + // CHECK-NEXT: %2 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %3 = load i16, i16* %k.addr, align 2 + // CHECK-NEXT: %4 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %5 = zext i16 %3 to i64 + // CHECK-NEXT: %6 = mul i64 %5, 
3 + // CHECK-NEXT: %7 = add i64 %6, %2 + // CHECK-NEXT: %matext = extractelement <9 x float> %4, i64 %7 + // CHECK-NEXT: %conv = fpext float %matext to double + // CHECK-NEXT: %8 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %matins = insertelement <25 x double> %8, double %conv, i32 17 + // CHECK-NEXT: store <25 x double> %matins, <25 x double>* %0, align 8 + a[2][3] = b[j][k]; + + // CHECK-NEXT: %9 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matext1 = extractelement <9 x float> %9, i32 3 + // CHECK-NEXT: %10 = load i16, i16* %k.addr, align 2 + // CHECK-NEXT: %11 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %12 = zext i16 %10 to i64 + // CHECK-NEXT: %13 = mul i64 %11, 3 + // CHECK-NEXT: %14 = add i64 %13, %12 + // CHECK-NEXT: %15 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matins2 = insertelement <9 x float> %15, float %matext1, i64 %14 + // CHECK-NEXT: store <9 x float> %matins2, <9 x float>* %1, align 4 + b[k][j] = b[0][1]; + + // CHECK-NEXT: %16 = load i16, i16* %k.addr, align 2 + // CHECK-NEXT: %17 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %18 = zext i16 %16 to i32 + // CHECK-NEXT: %19 = mul i32 %18, 3 + // CHECK-NEXT: %20 = add i32 %19, 0 + // CHECK-NEXT: %matext3 = extractelement <9 x float> %17, i32 %20 + // CHECK-NEXT: %21 = load i64, i64* %j.addr, align 8 + // CHECK-NEXT: %22 = mul i64 %21, 3 + // CHECK-NEXT: %23 = add i64 %22, 2 + // CHECK-NEXT: %24 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matins4 = insertelement <9 x float> %24, float %matext3, i64 %23 + // CHECK-NEXT: store <9 x float> %matins4, <9 x float>* %1, align 4 + // CHECK-NEXT: ret void + b[2][j] = b[0][k]; +} diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -0,0 +1,211 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s 
-emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +using fx2x3_t = float __attribute__((matrix_type(2, 3))); + +void insert_fp(dx5x5_t *a, double d, fx2x3_t *b, float e) { + (*a)[0u][1u] = d; + (*b)[1u][0u] = e; + + // CHECK-LABEL: @_Z9insert_fpPU11matrix_typeLm5ELm5EddPU11matrix_typeLm2ELm3Eff( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %d.addr = alloca double, align 8 + // CHECK-NEXT: %b.addr = alloca [6 x float]*, align 8 + // CHECK-NEXT: %e.addr = alloca float, align 4 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store double %d, double* %d.addr, align 8 + // CHECK-NEXT: store [6 x float]* %b, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: store float %e, float* %e.addr, align 4 + // CHECK-NEXT: %0 = load double, double* %d.addr, align 8 + // CHECK-NEXT: %1 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %2 = bitcast [25 x double]* %1 to <25 x double>* + // CHECK-NEXT: %3 = load <25 x double>, <25 x double>* %2, align 8 + // CHECK-NEXT: %matins = insertelement <25 x double> %3, double %0, i32 5 + // CHECK-NEXT: store <25 x double> %matins, <25 x double>* %2, align 8 + // CHECK-NEXT: %4 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %5 = load [6 x float]*, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: %6 = bitcast [6 x float]* %5 to <6 x float>* + // CHECK-NEXT: %7 = load <6 x float>, <6 x float>* %6, align 4 + // CHECK-NEXT: %matins1 = insertelement <6 x float> %7, float %4, i32 1 + // CHECK-NEXT: store <6 x float> %matins1, <6 x float>* %6, align 4 + // CHECK-NEXT: ret void +} + +typedef int ix9x3_t __attribute__((matrix_type(9, 3))); + +void insert_int(ix9x3_t *a, int i) { + (*a)[4u][1u] = i; + + // CHECK-LABEL: @_Z10insert_intPU11matrix_typeLm9ELm3Eii( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32]*, align 8 + // CHECK-NEXT: 
%i.addr = alloca i32, align 4 + // CHECK-NEXT: store [27 x i32]* %a, [27 x i32]** %a.addr, align 8 + // CHECK-NEXT: store i32 %i, i32* %i.addr, align 4 + // CHECK-NEXT: %0 = load i32, i32* %i.addr, align 4 + // CHECK-NEXT: %1 = load [27 x i32]*, [27 x i32]** %a.addr, align 8 + // CHECK-NEXT: %2 = bitcast [27 x i32]* %1 to <27 x i32>* + // CHECK-NEXT: %3 = load <27 x i32>, <27 x i32>* %2, align 4 + // CHECK-NEXT: %matins = insertelement <27 x i32> %3, i32 %0, i32 13 + // CHECK-NEXT: store <27 x i32> %matins, <27 x i32>* %2, align 4 + // CHECK-NEXT: ret void +} + +template <typename EltTy, unsigned Rows, unsigned Columns> +struct MyMatrix { + using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns))); + + matrix_t value; +}; + +template <typename EltTy, unsigned Rows, unsigned Columns> +void insert(MyMatrix<EltTy, Rows, Columns> &Mat, EltTy e) { + Mat.value[1u][0u] = e; +} + +void test_template(unsigned *Ptr1, unsigned E1, float *Ptr2, float E2) { + + // CHECK-LABEL: define void @_Z13test_templatePjjPff(i32* %Ptr1, i32 %E1, float* %Ptr2, float %E2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %Ptr1.addr = alloca i32*, align 8 + // CHECK-NEXT: %E1.addr = alloca i32, align 4 + // CHECK-NEXT: %Ptr2.addr = alloca float*, align 8 + // CHECK-NEXT: %E2.addr = alloca float, align 4 + // CHECK-NEXT: %Mat1 = alloca %struct.MyMatrix, align 4 + // CHECK-NEXT: %Mat2 = alloca %struct.MyMatrix.0, align 4 + // CHECK-NEXT: store i32* %Ptr1, i32** %Ptr1.addr, align 8 + // CHECK-NEXT: store i32 %E1, i32* %E1.addr, align 4 + // CHECK-NEXT: store float* %Ptr2, float** %Ptr2.addr, align 8 + // CHECK-NEXT: store float %E2, float* %E2.addr, align 4 + // CHECK-NEXT: %0 = load i32*, i32** %Ptr1.addr, align 8 + // CHECK-NEXT: %1 = bitcast i32* %0 to [4 x i32]* + // CHECK-NEXT: %2 = bitcast [4 x i32]* %1 to <4 x i32>* + // CHECK-NEXT: %3 = load <4 x i32>, <4 x i32>* %2, align 4 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix, %struct.MyMatrix* %Mat1, i32 0, i32 0 + // CHECK-NEXT: %4 = bitcast [4 x i32]* %value to <4 x i32>* + // CHECK-NEXT: store <4 x i32> %3, <4 x i32>* %4, align 4 + //
CHECK-NEXT: %5 = load i32, i32* %E1.addr, align 4 + // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_(%struct.MyMatrix* nonnull align 4 dereferenceable(16) %Mat1, i32 %5) + // CHECK-NEXT: %6 = load float*, float** %Ptr2.addr, align 8 + // CHECK-NEXT: %7 = bitcast float* %6 to [24 x float]* + // CHECK-NEXT: %8 = bitcast [24 x float]* %7 to <24 x float>* + // CHECK-NEXT: %9 = load <24 x float>, <24 x float>* %8, align 4 + // CHECK-NEXT: %value1 = getelementptr inbounds %struct.MyMatrix.0, %struct.MyMatrix.0* %Mat2, i32 0, i32 0 + // CHECK-NEXT: %10 = bitcast [24 x float]* %value1 to <24 x float>* + // CHECK-NEXT: store <24 x float> %9, <24 x float>* %10, align 4 + // CHECK-NEXT: %11 = load float, float* %E2.addr, align 4 + // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_(%struct.MyMatrix.0* nonnull align 4 dereferenceable(96) %Mat2, float %11) + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_(%struct.MyMatrix* nonnull align 4 dereferenceable(16) %Mat, i32 %e) + // CHECK-NEXT: entry: + // CHECK-NEXT: %Mat.addr = alloca %struct.MyMatrix*, align 8 + // CHECK-NEXT: %e.addr = alloca i32, align 4 + // CHECK-NEXT: store %struct.MyMatrix* %Mat, %struct.MyMatrix** %Mat.addr, align 8 + // CHECK-NEXT: store i32 %e, i32* %e.addr, align 4 + // CHECK-NEXT: %0 = load i32, i32* %e.addr, align 4 + // CHECK-NEXT: %1 = load %struct.MyMatrix*, %struct.MyMatrix** %Mat.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix, %struct.MyMatrix* %1, i32 0, i32 0 + // CHECK-NEXT: %2 = bitcast [4 x i32]* %value to <4 x i32>* + // CHECK-NEXT: %3 = load <4 x i32>, <4 x i32>* %2, align 4 + // CHECK-NEXT: %matins = insertelement <4 x i32> %3, i32 %0, i32 1 + // CHECK-NEXT: store <4 x i32> %matins, <4 x i32>* %2, align 4 + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr void 
@_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_(%struct.MyMatrix.0* nonnull align 4 dereferenceable(96) %Mat, float %e) + // CHECK-NEXT: entry: + // CHECK-NEXT: %Mat.addr = alloca %struct.MyMatrix.0*, align 8 + // CHECK-NEXT: %e.addr = alloca float, align 4 + // CHECK-NEXT: store %struct.MyMatrix.0* %Mat, %struct.MyMatrix.0** %Mat.addr, align 8 + // CHECK-NEXT: store float %e, float* %e.addr, align 4 + // CHECK-NEXT: %0 = load float, float* %e.addr, align 4 + // CHECK-NEXT: %1 = load %struct.MyMatrix.0*, %struct.MyMatrix.0** %Mat.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.0, %struct.MyMatrix.0* %1, i32 0, i32 0 + // CHECK-NEXT: %2 = bitcast [24 x float]* %value to <24 x float>* + // CHECK-NEXT: %3 = load <24 x float>, <24 x float>* %2, align 4 + // CHECK-NEXT: %matins = insertelement <24 x float> %3, float %0, i32 1 + // CHECK-NEXT: store <24 x float> %matins, <24 x float>* %2, align 4 + // CHECK-NEXT: ret void + + MyMatrix<unsigned, 2, 2> Mat1; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + insert(Mat1, E1); + + MyMatrix<float, 3, 8> Mat2; + Mat2.value = *((decltype(Mat2)::matrix_t *)Ptr2); + insert(Mat2, E2); +} + +typedef float fx3x3_t __attribute__((matrix_type(3, 3))); +void extract1(dx5x5_t a, fx3x3_t b, ix9x3_t c) { + // CHECK-LABEL: @_Z8extract1U11matrix_typeLm5ELm5EdU11matrix_typeLm3ELm3EfU11matrix_typeLm9ELm3Ei( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [9 x float], align 4 + // CHECK-NEXT: %c.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %v1 = alloca double, align 8 + // CHECK-NEXT: %v2 = alloca float, align 4 + // CHECK-NEXT: %v3 = alloca i32, align 4 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [9 x float]* %b.addr to <9 x float>* + // CHECK-NEXT: store <9 x float> %b, <9 x float>* %1, align 4 + // CHECK-NEXT: %2 = bitcast [27 x i32]*
%c.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %c, <27 x i32>* %2, align 4 + // CHECK-NEXT: %3 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %matext = extractelement <25 x double> %3, i32 17 + // CHECK-NEXT: store double %matext, double* %v1, align 8 + // CHECK-NEXT: %4 = load <9 x float>, <9 x float>* %1, align 4 + // CHECK-NEXT: %matext1 = extractelement <9 x float> %4, i32 5 + // CHECK-NEXT: store float %matext1, float* %v2, align 4 + // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %2, align 4 + // CHECK-NEXT: %matext2 = extractelement <27 x i32> %5, i32 10 + // CHECK-NEXT: store i32 %matext2, i32* %v3, align 4 + // CHECK-NEXT: ret void + + double v1 = a[2][3]; + float v2 = b[2][1]; + int v3 = c[1][1]; +} + +template <typename EltTy, unsigned Rows, unsigned Columns> +EltTy extract(MyMatrix<EltTy, Rows, Columns> &Mat) { + return Mat.value[1u][0u]; +} + +void test_extract_template(unsigned *Ptr1, float *Ptr2) { + // CHECK-LABEL: define void @_Z21test_extract_templatePjPf(i32* %Ptr1, float* %Ptr2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %Ptr1.addr = alloca i32*, align 8 + // CHECK-NEXT: %Ptr2.addr = alloca float*, align 8 + // CHECK-NEXT: %Mat1 = alloca %struct.MyMatrix, align 4 + // CHECK-NEXT: %v1 = alloca i32, align 4 + // CHECK-NEXT: store i32* %Ptr1, i32** %Ptr1.addr, align 8 + // CHECK-NEXT: store float* %Ptr2, float** %Ptr2.addr, align 8 + // CHECK-NEXT: %0 = load i32*, i32** %Ptr1.addr, align 8 + // CHECK-NEXT: %1 = bitcast i32* %0 to [4 x i32]* + // CHECK-NEXT: %2 = bitcast [4 x i32]* %1 to <4 x i32>* + // CHECK-NEXT: %3 = load <4 x i32>, <4 x i32>* %2, align 4 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix, %struct.MyMatrix* %Mat1, i32 0, i32 0 + // CHECK-NEXT: %4 = bitcast [4 x i32]* %value to <4 x i32>* + // CHECK-NEXT: store <4 x i32> %3, <4 x i32>* %4, align 4 + // CHECK-NEXT: %call = call i32 @_Z7extractIjLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(16) %Mat1) + // CHECK-NEXT: store i32 %call, i32* %v1, align 4 + // CHECK-NEXT:
ret void + + // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIjLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(16) %Mat) + // CHECK-NEXT: entry: + // CHECK-NEXT: %Mat.addr = alloca %struct.MyMatrix*, align 8 + // CHECK-NEXT: store %struct.MyMatrix* %Mat, %struct.MyMatrix** %Mat.addr, align 8 + // CHECK-NEXT: %0 = load %struct.MyMatrix*, %struct.MyMatrix** %Mat.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix, %struct.MyMatrix* %0, i32 0, i32 0 + // CHECK-NEXT: %1 = bitcast [4 x i32]* %value to <4 x i32>* + // CHECK-NEXT: %2 = load <4 x i32>, <4 x i32>* %1, align 4 + // CHECK-NEXT: %matext = extractelement <4 x i32> %2, i32 1 + // CHECK-NEXT: ret i32 %matext + + MyMatrix<unsigned, 2, 2> Mat1; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = extract(Mat1); +} diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c new file mode 100644 --- /dev/null +++ b/clang/test/Sema/matrix-type-operators.c @@ -0,0 +1,78 @@ +// RUN: %clang_cc1 %s -fenable-matrix -pedantic -verify -triple=x86_64-apple-darwin9 + +typedef float sx5x10_t __attribute__((matrix_type(5, 10))); + +sx5x10_t get_matrix(); + +void insert(sx5x10_t a, float f) { + // Non integer indexes. + a[3][f] = 0; + // expected-error@-1 {{matrix column index is not an integer}} + a[f][9] = 0; + // expected-error@-1 {{matrix row index is not an integer}} + a[f][f] = 0; + // expected-error@-1 {{matrix row index is not an integer}} + a[0][f] = 0; + // expected-error@-1 {{matrix column index is not an integer}} + + // Invalid element type. + a[3][4] = &f; + // expected-error@-1 {{assigning to 'float' from incompatible type 'float *'; remove &}} + + // Indexes outside allowed dimensions.
+ a[-1][3] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[3][-1] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[3][-1u] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[-1u][3] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[5][2] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[4][10] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[5][10.0] = f; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + + a[3] = 5.0; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} + + get_matrix()[0][0] = f; + // expected-error@-1 {{expression is not assignable}} + get_matrix()[5][10.0] = f; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + get_matrix()[3] = 5.0; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} +} + +void extract(sx5x10_t a, float f) { + // Non integer indexes. + float v1 = a[3][f]; + // expected-error@-1 {{matrix column index is not an integer}} + float v2 = a[f][9]; + // expected-error@-1 {{matrix row index is not an integer}} + float v3 = a[f][f]; + // expected-error@-1 {{matrix row index is not an integer}} + + // Invalid element type. + char *v4 = a[3][4]; + // expected-error@-1 {{initializing 'char *' with an expression of incompatible type 'float'}} + + // Indexes outside allowed dimensions. 
+ float v5 = a[-1][3]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v6 = a[3][-1]; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + float v8 = a[-1u][3]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v9 = a[5][2]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v10 = a[4][10]; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + float v11 = a[5][10.0]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + + float v12 = a[3]; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} +} diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -0,0 +1,86 @@ +// RUN: %clang_cc1 %s -fenable-matrix -pedantic -std=c++11 -verify -triple=x86_64-apple-darwin9 + +typedef float sx5x10_t __attribute__((matrix_type(5, 10))); + +sx5x10_t get_matrix(); + +void insert(sx5x10_t a, float f) { + // Non integer indexes. + a[3][f] = 0; + // expected-error@-1 {{matrix column index is not an integer}} + a[f][9] = 0; + // expected-error@-1 {{matrix row index is not an integer}} + a[f][f] = 0; + // expected-error@-1 {{matrix row index is not an integer}} + a[0][f] = 0; + // expected-error@-1 {{matrix column index is not an integer}} + + // Invalid element type. + a[3][4] = &f; + // expected-error@-1 {{assigning to 'float' from incompatible type 'float *'; remove &}} + + // Indexes outside allowed dimensions. 
+ a[-1][3] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[3][-1] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[3][-1u] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[-1u][3] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[5][2] = 10.0; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + a[4][10] = 10.0; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + a[5][10.0] = f; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + + get_matrix()[0][0] = f; + // expected-error@-1 {{expression is not assignable}} + get_matrix()[5][10.0] = f; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + get_matrix()[3] = 5.0; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} + + float &x = reinterpret_cast<float &>(a[3][3]); + // expected-error@-1 {{reinterpret_cast of a matrix element to 'float &' needs its address, which is not allowed}} +} + +void extract(sx5x10_t a, float f) { + // Non integer indexes. + float v1 = a[3][f]; + // expected-error@-1 {{matrix column index is not an integer}} + float v2 = a[f][9]; + // expected-error@-1 {{matrix row index is not an integer}} + float v3 = a[f][f]; + // expected-error@-1 {{matrix row index is not an integer}} + + // Invalid element type. + char *v4 = a[3][4]; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an lvalue of type 'float'}} + + // Indexes outside allowed dimensions.
+ float v5 = a[-1][3]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v6 = a[3][-1]; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + float v8 = a[-1u][3]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v9 = a[5][2]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + float v10 = a[4][10]; + // expected-error@-1 {{matrix column index is outside the allowed range [0, 10)}} + float v11 = a[5][10.0]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} + + float v12 = get_matrix()[0][0]; + float v13 = get_matrix()[5][10.0]; + // expected-error@-1 {{matrix row index is outside the allowed range [0, 5)}} +} + +void incomplete_matrix_index_expr(sx5x10_t a, float f) { + float x = a[3]; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} + a[2] = f; + // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} +} diff --git a/clang/tools/libclang/CXCursor.cpp b/clang/tools/libclang/CXCursor.cpp --- a/clang/tools/libclang/CXCursor.cpp +++ b/clang/tools/libclang/CXCursor.cpp @@ -419,6 +419,11 @@ K = CXCursor_ArraySubscriptExpr; break; + case Stmt::MatrixSubscriptExprClass: + // TODO: add support for MatrixSubscriptExpr. + K = CXCursor_UnexposedExpr; + break; + case Stmt::OMPArraySectionExprClass: K = CXCursor_OMPArraySectionExpr; break; diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h --- a/llvm/include/llvm/IR/MatrixBuilder.h +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -155,15 +155,19 @@ return B.CreateMul(LHS, ScalarVector); } - /// Extracts the element at (\p Row, \p Column) from \p Matrix. - Value *CreateExtractMatrix(Value *Matrix, Value *Row, Value *Column, - unsigned NumRows, Twine const &Name = "") { - + /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix. 
+ Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx, + unsigned NumRows, Twine const &Name = "") { + + unsigned MaxWidth = std::max(RowIdx->getType()->getScalarSizeInBits(), + ColumnIdx->getType()->getScalarSizeInBits()); + Type *IntTy = IntegerType::get(RowIdx->getType()->getContext(), MaxWidth); + RowIdx = B.CreateZExt(RowIdx, IntTy); + ColumnIdx = B.CreateZExt(ColumnIdx, IntTy); + Value *NumRowsV = B.getIntN(MaxWidth, NumRows); return B.CreateExtractElement( - Matrix, - B.CreateAdd( - B.CreateMul(Column, ConstantInt::get(Column->getType(), NumRows)), - Row)); + Matrix, B.CreateAdd(B.CreateMul(ColumnIdx, NumRowsV), RowIdx), + "matext"); } };