diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2050,7 +2050,8 @@ bool isComplexIntegerType() const; // GCC _Complex integer type. bool isVectorType() const; // GCC vector type. bool isExtVectorType() const; // Extended vector type. - bool isConstantMatrixType() const; // Matrix type. + bool isMatrixType() const; // Matrix type. + bool isConstantMatrixType() const; // Constant matrix type. bool isDependentAddressSpaceType() const; // value-dependent address space qualifier bool isObjCObjectPointerType() const; // pointer to ObjC object bool isObjCRetainableType() const; // ObjC object or block pointer @@ -6744,6 +6745,10 @@ return isa(CanonicalType); } +inline bool Type::isMatrixType() const { + return isa(CanonicalType); +} + inline bool Type::isConstantMatrixType() const { return isa(CanonicalType); } diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11208,6 +11208,11 @@ QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc); + /// Type checking for matrix binary operators. 
+ QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign); + bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsPowerPC.h" +#include "llvm/IR/MatrixBuilder.h" #include "llvm/IR/Module.h" #include @@ -3536,6 +3537,11 @@ } } + if (op.Ty->isConstantMatrixType()) { + llvm::MatrixBuilder MB(Builder); + return MB.CreateAdd(op.LHS, op.RHS); + } + if (op.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), op)) @@ -3720,6 +3726,11 @@ } } + if (op.Ty->isConstantMatrixType()) { + llvm::MatrixBuilder MB(Builder); + return MB.CreateSub(op.LHS, op.RHS); + } + if (op.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), op)) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -10235,6 +10235,11 @@ return compType; } + if (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType()) { + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy); + } + QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, CompLHSTy ? 
ACK_CompAssign : ACK_Arithmetic); if (LHS.isInvalid() || RHS.isInvalid()) @@ -10330,6 +10335,11 @@ return compType; } + if (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType()) { + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy); + } + QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic); if (LHS.isInvalid() || RHS.isInvalid()) @@ -11925,6 +11935,63 @@ return GetSignedVectorType(LHS.get()->getType()); } +static bool tryConvertScalarToMatrixElementTy(Sema &S, QualType ElementType, + ExprResult *Scalar) { + InitializedEntity Entity = + InitializedEntity::InitializeTemporary(ElementType); + InitializationKind Kind = InitializationKind::CreateCopy( + Scalar->get()->getBeginLoc(), SourceLocation()); + Expr *Arg = Scalar->get(); + InitializationSequence InitSeq(S, Entity, Kind, Arg); + *Scalar = InitSeq.Perform(S, Entity, Kind, Arg); + return !Scalar->isInvalid(); +} + +QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + if (!IsCompAssign) { + LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); + if (LHS.isInvalid()) + return QualType(); + } + RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); + if (RHS.isInvalid()) + return QualType(); + + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent. + QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + const MatrixType *LHSMatType = LHSType->getAs(); + const MatrixType *RHSMatType = RHSType->getAs(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (Context.hasSameType(LHSType, RHSType)) + return LHSType; + + // Type conversion may change LHS/RHS. Keep copies to the original results, in + // case we have to return InvalidOperands. 
+ ExprResult OriginalLHS = LHS; + ExprResult OriginalRHS = RHS; + if (LHSMatType && !RHSMatType) { + if (tryConvertScalarToMatrixElementTy(*this, LHSMatType->getElementType(), + &RHS)) + return LHSType; + return InvalidOperands(Loc, OriginalLHS, OriginalRHS); + } + + if (!LHSMatType && RHSMatType) { + if (tryConvertScalarToMatrixElementTy(*this, RHSMatType->getElementType(), + &LHS)) + return RHSType; + return InvalidOperands(Loc, OriginalLHS, OriginalRHS); + } + + return InvalidOperands(Loc, LHS, RHS); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -7687,6 +7687,10 @@ /// candidates. TypeSet VectorTypes; + /// The set of matrix types that will be used in the built-in + /// candidates. + TypeSet MatrixTypes; + /// A flag indicating non-record types are viable candidates bool HasNonRecordTypes; @@ -7747,6 +7751,11 @@ iterator vector_begin() { return VectorTypes.begin(); } iterator vector_end() { return VectorTypes.end(); } + llvm::iterator_range matrix_types() { return MatrixTypes; } + iterator matrix_begin() { return MatrixTypes.begin(); } + iterator matrix_end() { return MatrixTypes.end(); } + + bool containsMatrixType(QualType Ty) const { return MatrixTypes.count(Ty); } bool hasNonRecordTypes() { return HasNonRecordTypes; } bool hasArithmeticOrEnumeralTypes() { return HasArithmeticOrEnumeralTypes; } bool hasNullPtrType() const { return HasNullPtrType; } @@ -7921,6 +7930,11 @@ // extension. HasArithmeticOrEnumeralTypes = true; VectorTypes.insert(Ty); + } else if (Ty->isMatrixType()) { + // Similar to vector types, we treat matrix types as arithmetic types in + // many contexts as an extension. 
+ HasArithmeticOrEnumeralTypes = true; + MatrixTypes.insert(Ty); } else if (Ty->isNullPtrType()) { HasNullPtrType = true; } else if (AllowUserConversions && TyRec) { @@ -8149,6 +8163,13 @@ } + /// Helper to add an overload candidate for a binary builtin with types \p L + /// and \p R. + void AddCandidate(QualType L, QualType R) { + QualType LandR[2] = {L, R}; + S.AddBuiltinCandidate(LandR, Args, CandidateSet); + } + public: BuiltinOperatorOverloadBuilder( Sema &S, ArrayRef Args, @@ -8567,6 +8588,27 @@ } } + /// Add binary operator overloads for each candidate matrix type M1, M2: + /// * (M1, M1) -> M1 + /// * (M1, M1.getElementType()) -> M1 + /// * (M2.getElementType(), M2) -> M2 + /// * (M2, M2) -> M2 // Only if M2 is not part of CandidateTypes[0]. + void addMatrixBinaryArithmeticOverloads() { + if (!HasArithmeticOrEnumeralCandidateType) + return; + + for (QualType M1 : CandidateTypes[0].matrix_types()) { + AddCandidate(M1, cast(M1)->getElementType()); + AddCandidate(M1, M1); + } + + for (QualType M2 : CandidateTypes[1].matrix_types()) { + AddCandidate(cast(M2)->getElementType(), M2); + if (!CandidateTypes[0].containsMatrixType(M2)) + AddCandidate(M2, M2); + } + } + // C++2a [over.built]p14: // // For every integral type T there exists a candidate operator function @@ -9140,6 +9182,7 @@ } else { OpBuilder.addBinaryPlusOrMinusPointerOverloads(Op); OpBuilder.addGenericBinaryArithmeticOverloads(); + OpBuilder.addMatrixBinaryArithmeticOverloads(); } break; diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -0,0 +1,174 @@ +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s + +typedef double dx5x5_t __attribute__((matrix_type(5, 5))); +typedef float fx2x3_t __attribute__((matrix_type(2, 3))); +typedef int ix9x3_t __attribute__((matrix_type(9, 3))); +typedef unsigned 
long long ullx4x2_t __attribute__((matrix_type(4, 2))); + +// Floating point matrix/scalar additions. + +void add_matrix_matrix_double(dx5x5_t a, dx5x5_t b, dx5x5_t c) { + // CHECK-LABEL: define void @add_matrix_matrix_double(<25 x double> %a, <25 x double> %b, <25 x double> %c) + // CHECK: [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = fadd <25 x double> [[B]], [[C]] + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 + + a = b + c; +} + +void add_matrix_matrix_float(fx2x3_t a, fx2x3_t b, fx2x3_t c) { + // CHECK-LABEL: define void @add_matrix_matrix_float(<6 x float> %a, <6 x float> %b, <6 x float> %c) + // CHECK: [[B:%.*]] = load <6 x float>, <6 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[C:%.*]] = load <6 x float>, <6 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = fadd <6 x float> [[B]], [[C]] + // CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* {{.*}}, align 4 + + a = b + c; +} + +void add_matrix_scalar_double_float(dx5x5_t a, float vf) { + // CHECK-LABEL: define void @add_matrix_scalar_double_float(<25 x double> %a, float %vf) + // CHECK: [[MATRIX:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR:%.*]] = load float, float* %vf.addr, align 4 + // CHECK-NEXT: [[SCALAR_EXT:%.*]] = fpext float [[SCALAR]] to double + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <25 x double> undef, double [[SCALAR_EXT]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <25 x double> [[SCALAR_EMBED]], <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <25 x double> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 + + a = a + vf; +} + +void add_matrix_scalar_double_double(dx5x5_t a, double vd) { + // CHECK-LABEL: define void @add_matrix_scalar_double_double(<25 x double> %a, 
double %vd) + // CHECK: [[MATRIX:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR:%.*]] = load double, double* %vd.addr, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <25 x double> undef, double [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <25 x double> [[SCALAR_EMBED]], <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <25 x double> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 + + a = a + vd; +} + +void add_matrix_scalar_float_float(fx2x3_t b, float vf) { + // CHECK-LABEL: define void @add_matrix_scalar_float_float(<6 x float> %b, float %vf) + // CHECK: [[MATRIX:%.*]] = load <6 x float>, <6 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[SCALAR:%.*]] = load float, float* %vf.addr, align 4 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <6 x float> undef, float [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <6 x float> [[SCALAR_EMBED]], <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <6 x float> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* {{.*}}, align 4 + + b = b + vf; +} + +void add_matrix_scalar_float_double(fx2x3_t b, double vd) { + // CHECK-LABEL: define void @add_matrix_scalar_float_double(<6 x float> %b, double %vd) + // CHECK: [[MATRIX:%.*]] = load <6 x float>, <6 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[SCALAR:%.*]] = load double, double* %vd.addr, align 8 + // CHECK-NEXT: [[SCALAR_TRUNC:%.*]] = fptrunc double [[SCALAR]] to float + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <6 x float> undef, float [[SCALAR_TRUNC]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <6 x float> [[SCALAR_EMBED]], <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <6 x float> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <6 x float> 
[[RES]], <6 x float>* {{.*}}, align 4 + + b = b + vd; +} + +// Integer matrix/scalar additions + +void add_matrix_matrix_int(ix9x3_t a, ix9x3_t b, ix9x3_t c) { + // CHECK-LABEL: define void @add_matrix_matrix_int(<27 x i32> %a, <27 x i32> %b, <27 x i32> %c) + // CHECK: [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 + // CHECK-NEXT: [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = add <27 x i32> [[B]], [[C]] + // CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* {{.*}}, align 4 + a = b + c; +} + +void add_matrix_matrix_unsigned_long_long(ullx4x2_t a, ullx4x2_t b, ullx4x2_t c) { + // CHECK-LABEL: define void @add_matrix_matrix_unsigned_long_long(<8 x i64> %a, <8 x i64> %b, <8 x i64> %c) + // CHECK: [[B:%.*]] = load <8 x i64>, <8 x i64>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <8 x i64>, <8 x i64>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = add <8 x i64> [[B]], [[C]] + // CHECK-NEXT: store <8 x i64> [[RES]], <8 x i64>* {{.*}}, align 8 + + a = b + c; +} + +void add_matrix_scalar_int_short(ix9x3_t a, short vs) { + // CHECK-LABEL: define void @add_matrix_scalar_int_short(<27 x i32> %a, i16 signext %vs) + // CHECK: [[MATRIX:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 + // CHECK-NEXT: [[SCALAR:%.*]] = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: [[SCALAR_EXT:%.*]] = sext i16 [[SCALAR]] to i32 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <27 x i32> undef, i32 [[SCALAR_EXT]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <27 x i32> [[SCALAR_EMBED]], <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = add <27 x i32> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 + + a = a + vs; +} + +void add_matrix_scalar_int_long_int(ix9x3_t a, long int vli) { + // CHECK-LABEL: define void @add_matrix_scalar_int_long_int(<27 x i32> %a, i64 %vli) + // CHECK: [[MATRIX:%.*]] = load <27 x i32>, <27 x 
i32>* [[MAT_ADDR:%.*]], align 4 + // CHECK-NEXT: [[SCALAR:%.*]] = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: [[SCALAR_TRUNC:%.*]] = trunc i64 [[SCALAR]] to i32 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <27 x i32> undef, i32 [[SCALAR_TRUNC]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <27 x i32> [[SCALAR_EMBED]], <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = add <27 x i32> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 + + a = a + vli; +} + +void add_matrix_scalar_int_unsigned_long_long(ix9x3_t a, unsigned long long int vulli) { + // CHECK-LABEL: define void @add_matrix_scalar_int_unsigned_long_long(<27 x i32> %a, i64 %vulli) + // CHECK: [[MATRIX:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 + // CHECK-NEXT: [[SCALAR:%.*]] = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: [[SCALAR_TRUNC:%.*]] = trunc i64 [[SCALAR]] to i32 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <27 x i32> undef, i32 [[SCALAR_TRUNC]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <27 x i32> [[SCALAR_EMBED]], <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = add <27 x i32> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 + + a = a + vulli; +} + +void add_matrix_scalar_long_long_int_short(ullx4x2_t b, short vs) { + // CHECK-LABEL: define void @add_matrix_scalar_long_long_int_short(<8 x i64> %b, i16 signext %vs) + // CHECK: [[SCALAR:%.*]] = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: [[SCALAR_EXT:%.*]] = sext i16 [[SCALAR]] to i64 + // CHECK-NEXT: [[MATRIX:%.*]] = load <8 x i64>, <8 x i64>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <8 x i64> undef, i64 [[SCALAR_EXT]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <8 x i64> [[SCALAR_EMBED]], <8 x i64> undef, <8 x i32> zeroinitializer + // 
CHECK-NEXT: [[RES:%.*]] = add <8 x i64> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK-NEXT: store <8 x i64> [[RES]], <8 x i64>* {{.*}}, align 8 + + b = vs + b; +} + +void add_matrix_scalar_long_long_int_int(ullx4x2_t b, long int vli) { + // CHECK-LABEL: define void @add_matrix_scalar_long_long_int_int(<8 x i64> %b, i64 %vli) + // CHECK: [[SCALAR:%.*]] = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: [[MATRIX:%.*]] = load <8 x i64>, <8 x i64>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <8 x i64> undef, i64 [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <8 x i64> [[SCALAR_EMBED]], <8 x i64> undef, <8 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = add <8 x i64> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK-NEXT: store <8 x i64> [[RES]], <8 x i64>* {{.*}}, align 8 + + b = vli + b; +} + +void add_matrix_scalar_long_long_int_unsigned_long_long(ullx4x2_t b, unsigned long long int vulli) { + // CHECK-LABEL: define void @add_matrix_scalar_long_long_int_unsigned_long_long + // CHECK: [[SCALAR:%.*]] = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: [[MATRIX:%.*]] = load <8 x i64>, <8 x i64>* %0, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <8 x i64> undef, i64 [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <8 x i64> [[SCALAR_EMBED]], <8 x i64> undef, <8 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = add <8 x i64> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK-NEXT: store <8 x i64> [[RES]], <8 x i64>* {{.*}}, align 8 + b = vulli + b; +} diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp new file mode 100644 --- /dev/null +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -0,0 +1,156 @@ +// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py +// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - -std=c++11 | FileCheck %s + +template +struct 
MyMatrix { + using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns))); + + matrix_t value; +}; + +template +typename MyMatrix::matrix_t add(MyMatrix &A, MyMatrix &B) { + return A.value + B.value; +} + +void test_add_template() { + // CHECK-LABEL: define void @_Z17test_add_templatev() + // CHECK: %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat2) + + // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = fadd <10 x float> [[MAT1]], [[MAT2]] + // CHECK-NEXT: ret <10 x float> [[RES]] + + MyMatrix Mat1; + MyMatrix Mat2; + Mat1.value = add(Mat1, Mat2); +} + +template +typename MyMatrix::matrix_t subtract(MyMatrix &A, MyMatrix &B) { + return A.value - B.value; +} + +void test_subtract_template() { + // CHECK-LABEL: define void @_Z22test_subtract_templatev() + // CHECK: %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat2) + + // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = fsub <10 x float> [[MAT1]], [[MAT2]] + // CHECK-NEXT: ret <10 x float> [[RES]] + + MyMatrix Mat1; + MyMatrix Mat2; + Mat1.value = subtract(Mat1, Mat2); +} + +struct DoubleWrapper1 { + int x; + operator double() { + return x; + } +}; + +void test_DoubleWrapper1_Sub1(MyMatrix &m) { + // CHECK-LABEL: define 
void @_Z24test_DoubleWrapper1_Sub1R8MyMatrixIdLj10ELj9EE( + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK: [[SCALAR:%.*]] = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1) + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fsub <90 x double> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + DoubleWrapper1 w1; + w1.x = 10; + m.value = m.value - w1; +} + +void test_DoubleWrapper1_Sub2(MyMatrix &m) { + // CHECK-LABEL: define void @_Z24test_DoubleWrapper1_Sub2R8MyMatrixIdLj10ELj9EE( + // CHECK: [[SCALAR:%.*]] = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1) + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fsub <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + DoubleWrapper1 w1; + w1.x = 10; + m.value = w1 - m.value; +} + +struct DoubleWrapper2 { + int x; + operator double() { + return x; + } +}; + +void test_DoubleWrapper2_Add1(MyMatrix &m) { + // CHECK-LABEL: define void @_Z24test_DoubleWrapper2_Add1R8MyMatrixIdLj10ELj9EE( + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* %1, align 8 + // CHECK: [[SCALAR:%.*]] = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2) + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> 
[[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <90 x double> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + DoubleWrapper2 w2; + w2.x = 20; + m.value = m.value + w2; +} + +void test_DoubleWrapper2_Add2(MyMatrix &m) { + // CHECK-LABEL: define void @_Z24test_DoubleWrapper2_Add2R8MyMatrixIdLj10ELj9EE( + // CHECK: [[SCALAR:%.*]] = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2) + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* %1, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + DoubleWrapper2 w2; + w2.x = 20; + m.value = w2 + m.value; +} + +struct IntWrapper { + char x; + operator int() { + return x; + } +}; + +void test_IntWrapper_Add(MyMatrix &m) { + // CHECK-LABEL: define void @_Z19test_IntWrapper_AddR8MyMatrixIdLj10ELj9EE( + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // CHECK: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fadd <90 x double> [[MATRIX]], [[SCALAR_EMBED1]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + IntWrapper w3; + w3.x = 'c'; + m.value = m.value + w3; +} + +void test_IntWrapper_Sub(MyMatrix &m) { + // CHECK-LABEL: define void 
@_Z19test_IntWrapper_SubR8MyMatrixIdLj10ELj9EE( + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // CHECK-NEXT: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fsub <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + + IntWrapper w3; + w3.x = 'c'; + m.value = w3 - m.value; +} diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c new file mode 100644 --- /dev/null +++ b/clang/test/Sema/matrix-type-operators.c @@ -0,0 +1,33 @@ +// RUN: %clang_cc1 %s -fenable-matrix -pedantic -verify -triple=x86_64-apple-darwin9 + +typedef float sx5x10_t __attribute__((matrix_type(5, 10))); +typedef float sx10x5_t __attribute__((matrix_type(10, 5))); +typedef float sx10x10_t __attribute__((matrix_type(10, 10))); + +void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + a = b + c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}} + + a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = 10 + b; + // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = b + &c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 
'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} + // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} +} + +void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + a = b - c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}} + + a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = 10 - b; + // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = b - &c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} + // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} +} diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp new file mode 100644 --- /dev/null +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -0,0 +1,93 @@ +// RUN: %clang_cc1 %s -fenable-matrix -pedantic -std=c++11 -verify -triple=x86_64-apple-darwin9 + +typedef float sx5x10_t __attribute__((matrix_type(5, 10))); + +template +struct MyMatrix { + using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns))); + + matrix_t value; +}; + +template +typename MyMatrix::matrix_t add(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value + B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 
{{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + + return A.value + B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_add_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = add(Mat1, Mat1); + // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-note@-2 {{in instantiation of function template specialization 'add' requested here}} + + Mat1.value = add(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'add' requested here}} + + Mat1.value = add(Mat2, Mat3); + // expected-note@-1 {{in instantiation of function template specialization 'add' requested here}} +} + +template +typename MyMatrix::matrix_t subtract(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value - B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 {{invalid operands to binary expression 
('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}} + + return A.value - B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}} +} + +void test_subtract_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = subtract(Mat1, Mat1); + // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-note@-2 {{in instantiation of function template specialization 'subtract' requested here}} + + Mat1.value = subtract(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} + + Mat1.value = subtract(Mat2, Mat3); + // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} +} + +struct UserT {}; + +struct StructWithC { + operator UserT() { + // expected-note@-1 4 {{candidate function}} + return {}; + } +}; + +void test_DoubleWrapper(MyMatrix &m, StructWithC &c) { + m.value = m.value + c; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression 
('MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}} + + m.value = c + m.value; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}} + + m.value = m.value - c; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}} + + m.value = c - m.value; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}} +} diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h --- a/llvm/include/llvm/IR/MatrixBuilder.h +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -127,6 +127,16 @@ /// Add matrixes \p LHS and \p RHS. Support both integer and floating point /// matrixes. Value *CreateAdd(Value *LHS, Value *RHS) { + assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return cast(LHS->getType()) ->getElementType() ->isFloatingPointTy() @@ -137,6 +147,16 @@ /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating /// point matrixes. 
Value *CreateSub(Value *LHS, Value *RHS) { + assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return cast(LHS->getType()) ->getElementType() ->isFloatingPointTy()