diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11210,6 +11210,13 @@ QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc); + /// Type checking for element-wise matrix binary operators. At least one + /// operand must be a matrix; a scalar operand is converted to the matrix + /// element type. Returns a null QualType if the operands are invalid. + QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign); + bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -3554,6 +3554,12 @@ } } + if (op.Ty->isConstantMatrixType()) { + // Element-wise add; MatrixBuilder splats a scalar operand to a vector. + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + return MB.CreateAdd(op.LHS, op.RHS); + } + if (op.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), op)) @@ -3738,6 +3744,12 @@ } } + if (op.Ty->isConstantMatrixType()) { + // Element-wise subtract; MatrixBuilder splats a scalar operand to a vector. + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + return MB.CreateSub(op.LHS, op.RHS); + } + if (op.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), op)) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -10304,6 +10304,17 @@ return compType; } + if (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType()) { + QualType compType = + CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy != nullptr); + // For compound assignment (CompLHSTy non-null) the caller reads the + // computation type back through CompLHSTy, so it must be set here too. + if (CompLHSTy) + *CompLHSTy = compType; + return compType; + } + QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, CompLHSTy ?
ACK_CompAssign : ACK_Arithmetic); if (LHS.isInvalid() || RHS.isInvalid()) @@ -10399,6 +10410,17 @@ return compType; } + if (LHS.get()->getType()->isConstantMatrixType() || + RHS.get()->getType()->isConstantMatrixType()) { + QualType compType = + CheckMatrixElementwiseOperands(LHS, RHS, Loc, CompLHSTy != nullptr); + // For compound assignment (CompLHSTy non-null) the caller reads the + // computation type back through CompLHSTy, so it must be set here too. + if (CompLHSTy) + *CompLHSTy = compType; + return compType; + } + QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, CompLHSTy ? ACK_CompAssign : ACK_Arithmetic); if (LHS.isInvalid() || RHS.isInvalid()) @@ -11994,6 +12016,75 @@ return GetSignedVectorType(LHS.get()->getType()); } +/// Convert \p Scalar to the matrix element type \p ElementType in place. +/// Arithmetic scalars are converted with an implicit scalar cast; any other +/// type (e.g. a class with a conversion operator) is copy-initialized into a +/// temporary of \p ElementType. Returns false if that initialization fails. +static bool tryConvertScalarToMatrixElementTy(Sema &S, QualType ElementType, + ExprResult *Scalar) { + QualType ScalarTy = Scalar->get()->getType().getUnqualifiedType(); + if (!ScalarTy->isArithmeticType()) { + InitializedEntity Entity = + InitializedEntity::InitializeTemporary(ElementType); + InitializationKind Kind = InitializationKind::CreateCopy( + Scalar->get()->getBeginLoc(), SourceLocation()); + Expr *Arg = Scalar->get(); + InitializationSequence InitSeq(S, Entity, Kind, Arg); + *Scalar = InitSeq.Perform(S, Entity, Kind, Arg); + return !Scalar->isInvalid(); + } + + CastKind CK = S.PrepareScalarCast(*Scalar, ElementType); + *Scalar = S.ImpCastExprToType(Scalar->get(), ElementType, CK); + return true; +} + +QualType Sema::CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + // In a compound assignment the LHS is the assignee; keep it an lvalue. + if (!IsCompAssign) { + LHS = DefaultFunctionArrayLvalueConversion(LHS.get()); + if (LHS.isInvalid()) + return QualType(); + } + RHS = DefaultFunctionArrayLvalueConversion(RHS.get()); + if (RHS.isInvalid()) + return QualType(); + + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent.
+ QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + const MatrixType *LHSMatType = LHSType->getAs<MatrixType>(); + const MatrixType *RHSMatType = RHSType->getAs<MatrixType>(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (Context.hasSameType(LHSType, RHSType)) + return LHSType; + + // Type conversion may change LHS/RHS. Keep copies of the original results + // in case we have to return InvalidOperands. + ExprResult OriginalLHS = LHS; + ExprResult OriginalRHS = RHS; + if (LHSMatType && !RHSMatType) { + if (tryConvertScalarToMatrixElementTy(*this, LHSMatType->getElementType(), + &RHS)) + return LHSType; + return InvalidOperands(Loc, OriginalLHS, OriginalRHS); + } + + if (!LHSMatType && RHSMatType) { + if (tryConvertScalarToMatrixElementTy(*this, RHSMatType->getElementType(), + &LHS)) + return RHSType; + return InvalidOperands(Loc, OriginalLHS, OriginalRHS); + } + + return InvalidOperands(Loc, LHS, RHS); // Mismatched matrix types. +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp --- a/clang/lib/Sema/SemaOverload.cpp +++ b/clang/lib/Sema/SemaOverload.cpp @@ -7687,6 +7687,10 @@ /// candidates. TypeSet VectorTypes; + /// The set of matrix types that will be used in the built-in + /// candidates.
+ TypeSet MatrixTypes; + /// A flag indicating non-record types are viable candidates bool HasNonRecordTypes; @@ -7747,6 +7751,10 @@ iterator vector_begin() { return VectorTypes.begin(); } iterator vector_end() { return VectorTypes.end(); } + llvm::iterator_range<iterator> matrix_types() { return MatrixTypes; } + iterator matrix_begin() { return MatrixTypes.begin(); } + iterator matrix_end() { return MatrixTypes.end(); } + bool hasNonRecordTypes() { return HasNonRecordTypes; } bool hasArithmeticOrEnumeralTypes() { return HasArithmeticOrEnumeralTypes; } bool hasNullPtrType() const { return HasNullPtrType; } @@ -7921,6 +7929,11 @@ // extension. HasArithmeticOrEnumeralTypes = true; VectorTypes.insert(Ty); + } else if (Ty->isMatrixType()) { + // Similar to vector types, we treat matrix types as arithmetic types in + // many contexts as an extension. + HasArithmeticOrEnumeralTypes = true; + MatrixTypes.insert(Ty); } else if (Ty->isNullPtrType()) { HasNullPtrType = true; } else if (AllowUserConversions && TyRec) { @@ -8541,30 +8554,43 @@ if (!HasArithmeticOrEnumeralCandidateType) return; + // Adds one built-in candidate for the operand type pair (L, R). + auto AddCandidate = [this](QualType L, QualType R) { + QualType LandR[2] = {L, R}; + S.AddBuiltinCandidate(LandR, Args, CandidateSet); + }; for (unsigned Left = FirstPromotedArithmeticType; - Left < LastPromotedArithmeticType; ++Left) { + Left < LastPromotedArithmeticType; ++Left) for (unsigned Right = FirstPromotedArithmeticType; - Right < LastPromotedArithmeticType; ++Right) { - QualType LandR[2] = { ArithmeticTypes[Left], - ArithmeticTypes[Right] }; - S.AddBuiltinCandidate(LandR, Args, CandidateSet); - } - } + Right < LastPromotedArithmeticType; ++Right) + AddCandidate(ArithmeticTypes[Left], ArithmeticTypes[Right]); // Extension: Add the binary operators ==, !=, <, <=, >=, >, *, /, and the // conditional operator for vector types.
for (BuiltinCandidateTypeSet::iterator - Vec1 = CandidateTypes[0].vector_begin(), - Vec1End = CandidateTypes[0].vector_end(); - Vec1 != Vec1End; ++Vec1) { + Vec1 = CandidateTypes[0].vector_begin(), + Vec1End = CandidateTypes[0].vector_end(); + Vec1 != Vec1End; ++Vec1) for (BuiltinCandidateTypeSet::iterator - Vec2 = CandidateTypes[1].vector_begin(), - Vec2End = CandidateTypes[1].vector_end(); - Vec2 != Vec2End; ++Vec2) { - QualType LandR[2] = { *Vec1, *Vec2 }; - S.AddBuiltinCandidate(LandR, Args, CandidateSet); - } - } + Vec2 = CandidateTypes[1].vector_begin(), + Vec2End = CandidateTypes[1].vector_end(); + Vec2 != Vec2End; ++Vec2) + AddCandidate(*Vec1, *Vec2); + + // Extension: Add the following binary operator overloads for each pair of + // candidate matrix types M1, M2: + // * (M1, M2) -> M1, if M1 == M2 + // * (M1, M1.getElementType()) -> M1 + // * (M2.getElementType(), M2) -> M2 + for (QualType M1 : CandidateTypes[0].matrix_types()) { + for (QualType M2 : CandidateTypes[1].matrix_types()) + if (S.Context.hasSameType(M1, M2)) + AddCandidate(M1, M2); + + AddCandidate(M1, M1->castAs<MatrixType>()->getElementType()); + } + for (QualType M2 : CandidateTypes[1].matrix_types()) + AddCandidate(M2->castAs<MatrixType>()->getElementType(), M2); } // C++2a [over.built]p14: diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c --- a/clang/test/CodeGen/matrix-type-operators.c +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -312,3 +312,312 @@ // CHECK-NEXT: ret void b[2][j] = b[0][k]; } + +void add_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) { + a = b + c; + ai = bi + ci; + + // CHECK-LABEL: @add_matrix_matrix( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %ai.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %bi.addr = alloca [27 x i32], align 4 +
// CHECK-NEXT: %ci.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %b.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %b, <25 x double>* %1, align 8 + // CHECK-NEXT: %2 = bitcast [25 x double]* %c.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %c, <25 x double>* %2, align 8 + // CHECK-NEXT: %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %ai, <27 x i32>* %3, align 4 + // CHECK-NEXT: %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %bi, <27 x i32>* %4, align 4 + // CHECK-NEXT: %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %ci, <27 x i32>* %5, align 4 + // CHECK-NEXT: %6 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %7 = load <25 x double>, <25 x double>* %2, align 8 + // CHECK-NEXT: %8 = fadd <25 x double> %6, %7 + // CHECK-NEXT: store <25 x double> %8, <25 x double>* %0, align 8 + // CHECK-NEXT: %9 = load <27 x i32>, <27 x i32>* %4, align 4 + // CHECK-NEXT: %10 = load <27 x i32>, <27 x i32>* %5, align 4 + // CHECK-NEXT: %11 = add <27 x i32> %9, %10 + // CHECK-NEXT: store <27 x i32> %11, <27 x i32>* %3, align 4 + // CHECK-NEXT: ret void +} + +void add_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) { + a = a + vf; + a = a + vd; + + // CHECK-LABEL: define void @add_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [6 x float], align 4 + // CHECK-NEXT: %vf.addr = alloca float, align 4 + // CHECK-NEXT: %vd.addr = alloca double, align 8 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 =
bitcast [6 x float]* %b.addr to <6 x float>* + // CHECK-NEXT: store <6 x float> %b, <6 x float>* %1, align 4 + // CHECK-NEXT: store float %vf, float* %vf.addr, align 4 + // CHECK-NEXT: store double %vd, double* %vd.addr, align 8 + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %3 = load float, float* %vf.addr, align 4 + // CHECK-NEXT: %conv = fpext float %3 to double + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fadd <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %0, align 8 + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %6 = load double, double* %vd.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0 + // CHECK-NEXT: %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %7 = fadd <25 x double> %5, %scalar.splat.splat2 + // CHECK-NEXT: store <25 x double> %7, <25 x double>* %0, align 8 + + b = b + vf; + b = b + vd; + + // CHECK-NEXT: %8 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %9 = load float, float* %vf.addr, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0 + // CHECK-NEXT: %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %10 = fadd <6 x float> %8, %scalar.splat.splat4 + // CHECK-NEXT: store <6 x float> %10, <6 x float>* %1, align 4 + // CHECK-NEXT: %11 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %12 = load double, double* %vd.addr, align 8 + // CHECK-NEXT: %conv5 = fptrunc double %12 to float + //
CHECK-NEXT: %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0 + // CHECK-NEXT: %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %13 = fadd <6 x float> %11, %scalar.splat.splat7 + // CHECK-NEXT: store <6 x float> %13, <6 x float>* %1, align 4 + // CHECK-NEXT: ret void +} + +// Note: llix9x3_t has int (i32) elements despite its name. +typedef int llix9x3_t __attribute__((matrix_type(9, 3))); + +void add_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) { + a = a + vs; + a = a + vli; + a = a + vulli; + + // CHECK-LABEL: define void @add_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %b.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %vs.addr = alloca i16, align 2 + // CHECK-NEXT: %vli.addr = alloca i64, align 8 + // CHECK-NEXT: %vulli.addr = alloca i64, align 8 + // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4 + // CHECK-NEXT: %1 = bitcast [27 x i32]* %b.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %b, <27 x i32>* %1, align 4 + // CHECK-NEXT: store i16 %vs, i16* %vs.addr, align 2 + // CHECK-NEXT: store i64 %vli, i64* %vli.addr, align 8 + // CHECK-NEXT: store i64 %vulli, i64* %vulli.addr, align 8 + // CHECK-NEXT: %2 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %3 = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: %conv = sext i16 %3 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %4 = add <27 x i32> %2, %scalar.splat.splat + // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4 + // CHECK-NEXT: %5 = load <27 x i32>,
<27 x i32>* %0, align 4 + // CHECK-NEXT: %6 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv1 = trunc i64 %6 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0 + // CHECK-NEXT: %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %7 = add <27 x i32> %5, %scalar.splat.splat3 + // CHECK-NEXT: store <27 x i32> %7, <27 x i32>* %0, align 4 + // CHECK-NEXT: %8 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %9 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv4 = trunc i64 %9 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0 + // CHECK-NEXT: %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %10 = add <27 x i32> %8, %scalar.splat.splat6 + // CHECK-NEXT: store <27 x i32> %10, <27 x i32>* %0, align 4 + // CHECK-NEXT: %11 = load i16, i16* %vs.addr, align 2 + + b = vs + b; + b = vli + b; + b = vulli + b; + + // CHECK-NEXT: %conv7 = sext i16 %11 to i32 + // CHECK-NEXT: %12 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0 + // CHECK-NEXT: %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %13 = add <27 x i32> %scalar.splat.splat9, %12 + // CHECK-NEXT: store <27 x i32> %13, <27 x i32>* %1, align 4 + // CHECK-NEXT: %14 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv10 = trunc i64 %14 to i32 + // CHECK-NEXT: %15 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0 + // CHECK-NEXT: %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer + //
CHECK-NEXT: %16 = add <27 x i32> %scalar.splat.splat12, %15 + // CHECK-NEXT: store <27 x i32> %16, <27 x i32>* %1, align 4 + // CHECK-NEXT: %17 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv13 = trunc i64 %17 to i32 + // CHECK-NEXT: %18 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0 + // CHECK-NEXT: %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %19 = add <27 x i32> %scalar.splat.splat15, %18 + // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4 + // CHECK-NEXT: ret void +} + +void sub_matrix_matrix(dx5x5_t a, dx5x5_t b, dx5x5_t c, ix9x3_t ai, ix9x3_t bi, ix9x3_t ci) { + a = b - c; + ai = bi - ci; + + // CHECK-LABEL: @sub_matrix_matrix( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %ai.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %bi.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %ci.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %b.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %b, <25 x double>* %1, align 8 + // CHECK-NEXT: %2 = bitcast [25 x double]* %c.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %c, <25 x double>* %2, align 8 + // CHECK-NEXT: %3 = bitcast [27 x i32]* %ai.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %ai, <27 x i32>* %3, align 4 + // CHECK-NEXT: %4 = bitcast [27 x i32]* %bi.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %bi, <27 x i32>* %4, align 4 + // CHECK-NEXT: %5 = bitcast [27 x i32]* %ci.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %ci, <27
x i32>* %5, align 4 + // CHECK-NEXT: %6 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %7 = load <25 x double>, <25 x double>* %2, align 8 + // CHECK-NEXT: %8 = fsub <25 x double> %6, %7 + // CHECK-NEXT: store <25 x double> %8, <25 x double>* %0, align 8 + // CHECK-NEXT: %9 = load <27 x i32>, <27 x i32>* %4, align 4 + // CHECK-NEXT: %10 = load <27 x i32>, <27 x i32>* %5, align 4 + // CHECK-NEXT: %11 = sub <27 x i32> %9, %10 + // CHECK-NEXT: store <27 x i32> %11, <27 x i32>* %3, align 4 + // CHECK-NEXT: ret void +} + +void sub_matrix_scalar_float(dx5x5_t a, fx2x3_t b, float vf, double vd) { + a = a - vf; + a = a - vd; + + // CHECK-LABEL: define void @sub_matrix_scalar_float(<25 x double> %a, <6 x float> %b, float %vf, double %vd) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %b.addr = alloca [6 x float], align 4 + // CHECK-NEXT: %vf.addr = alloca float, align 4 + // CHECK-NEXT: %vd.addr = alloca double, align 8 + // CHECK-NEXT: %0 = bitcast [25 x double]* %a.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %a, <25 x double>* %0, align 8 + // CHECK-NEXT: %1 = bitcast [6 x float]* %b.addr to <6 x float>* + // CHECK-NEXT: store <6 x float> %b, <6 x float>* %1, align 4 + // CHECK-NEXT: store float %vf, float* %vf.addr, align 4 + // CHECK-NEXT: store double %vd, double* %vd.addr, align 8 + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %0, align 8 + // CHECK-NEXT: %3 = load float, float* %vf.addr, align 4 + // CHECK-NEXT: %conv = fpext float %3 to double + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fsub <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %0, align 8 + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %0,
align 8 + // CHECK-NEXT: %6 = load double, double* %vd.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %6, i32 0 + // CHECK-NEXT: %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %7 = fsub <25 x double> %5, %scalar.splat.splat2 + // CHECK-NEXT: store <25 x double> %7, <25 x double>* %0, align 8 + + b = b - vf; + b = b - vd; + + // CHECK-NEXT: %8 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %9 = load float, float* %vf.addr, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert3 = insertelement <6 x float> undef, float %9, i32 0 + // CHECK-NEXT: %scalar.splat.splat4 = shufflevector <6 x float> %scalar.splat.splatinsert3, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %10 = fsub <6 x float> %8, %scalar.splat.splat4 + // CHECK-NEXT: store <6 x float> %10, <6 x float>* %1, align 4 + // CHECK-NEXT: %11 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %12 = load double, double* %vd.addr, align 8 + // CHECK-NEXT: %conv5 = fptrunc double %12 to float + // CHECK-NEXT: %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %conv5, i32 0 + // CHECK-NEXT: %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %13 = fsub <6 x float> %11, %scalar.splat.splat7 + // CHECK-NEXT: store <6 x float> %13, <6 x float>* %1, align 4 + // CHECK-NEXT: ret void +} + +void sub_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) { + a = a - vs; + a = a - vli; + a = a - vulli; + + // CHECK-LABEL: define void @sub_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %b.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %vs.addr = alloca i16,
align 2 + // CHECK-NEXT: %vli.addr = alloca i64, align 8 + // CHECK-NEXT: %vulli.addr = alloca i64, align 8 + // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4 + // CHECK-NEXT: %1 = bitcast [27 x i32]* %b.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %b, <27 x i32>* %1, align 4 + // CHECK-NEXT: store i16 %vs, i16* %vs.addr, align 2 + // CHECK-NEXT: store i64 %vli, i64* %vli.addr, align 8 + // CHECK-NEXT: store i64 %vulli, i64* %vulli.addr, align 8 + // CHECK-NEXT: %2 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %3 = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: %conv = sext i16 %3 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %4 = sub <27 x i32> %2, %scalar.splat.splat + // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4 + // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %6 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv1 = trunc i64 %6 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0 + // CHECK-NEXT: %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %7 = sub <27 x i32> %5, %scalar.splat.splat3 + // CHECK-NEXT: store <27 x i32> %7, <27 x i32>* %0, align 4 + // CHECK-NEXT: %8 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %9 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv4 = trunc i64 %9 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0 + // CHECK-NEXT: %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %10 =
sub <27 x i32> %8, %scalar.splat.splat6 + // CHECK-NEXT: store <27 x i32> %10, <27 x i32>* %0, align 4 + + b = vs - b; + b = vli - b; + b = vulli - b; + + // CHECK-NEXT: %11 = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: %conv7 = sext i16 %11 to i32 + // CHECK-NEXT: %12 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0 + // CHECK-NEXT: %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %13 = sub <27 x i32> %scalar.splat.splat9, %12 + // CHECK-NEXT: store <27 x i32> %13, <27 x i32>* %1, align 4 + // CHECK-NEXT: %14 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv10 = trunc i64 %14 to i32 + // CHECK-NEXT: %15 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0 + // CHECK-NEXT: %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %16 = sub <27 x i32> %scalar.splat.splat12, %15 + // CHECK-NEXT: store <27 x i32> %16, <27 x i32>* %1, align 4 + // CHECK-NEXT: %17 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv13 = trunc i64 %17 to i32 + // CHECK-NEXT: %18 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0 + // CHECK-NEXT: %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %19 = sub <27 x i32> %scalar.splat.splat15, %18 + // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4 + // CHECK-NEXT: ret void +} diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp --- a/clang/test/CodeGenCXX/matrix-type-operators.cpp +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -242,3 +242,252 @@ return matrix_subscript(m, 1, 2); } + +template +typename MyMatrix::matrix_t add(MyMatrix &A, MyMatrix &B) { + return A.value + B.value; +} + +void test_add_template() { + // CHECK-LABEL: define void @_Z17test_add_templatev() + // CHECK-NEXT: entry: + // CHECK-NEXT: %Mat1 = alloca %struct.MyMatrix.1, align 4 + // CHECK-NEXT: %Mat2 = alloca %struct.MyMatrix.1, align 4 + // CHECK-NEXT: %call = call <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0 + // CHECK-NEXT: %0 = bitcast [10 x float]* %value to <10 x float>* + // CHECK-NEXT: store <10 x float> %call, <10 x float>* %0, align 4 + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr <10 x float> @_Z3addIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B) + // CHECK-NEXT: entry: + // CHECK-NEXT: %A.addr = alloca %struct.MyMatrix.1*, align 8 + // CHECK-NEXT: %B.addr = alloca %struct.MyMatrix.1*, align 8 + // CHECK-NEXT: store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8 + // CHECK-NEXT: store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8 + // CHECK-NEXT: %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0 + // CHECK-NEXT: %1 = bitcast [10 x float]* %value to <10 x float>* + // CHECK-NEXT: %2 = load <10 x float>, <10 x float>* %1, align 4 + // CHECK-NEXT: %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8 + // CHECK-NEXT: %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0 + // CHECK-NEXT: %4 = bitcast [10 
x float]* %value1 to <10 x float>* + // CHECK-NEXT: %5 = load <10 x float>, <10 x float>* %4, align 4 + // CHECK-NEXT: %6 = fadd <10 x float> %2, %5 + // CHECK-NEXT: ret <10 x float> %6 + + MyMatrix Mat1; + MyMatrix Mat2; + Mat1.value = add(Mat1, Mat2); +} + +template +typename MyMatrix::matrix_t subtract(MyMatrix &A, MyMatrix &B) { + return A.value - B.value; +} + +void test_subtract_template() { + // CHECK-LABEL: define void @_Z22test_subtract_templatev() + // CHECK-NEXT: entry: + // CHECK-NEXT: %Mat1 = alloca %struct.MyMatrix.1, align 4 + // CHECK-NEXT: %Mat2 = alloca %struct.MyMatrix.1, align 4 + // CHECK-NEXT: %call = call <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %Mat1, i32 0, i32 0 + // CHECK-NEXT: %0 = bitcast [10 x float]* %value to <10 x float>* + // CHECK-NEXT: store <10 x float> %call, <10 x float>* %0, align 4 + // CHECK-NEXT: ret void + + // CHECK-LABEL: define linkonce_odr <10 x float> @_Z8subtractIfLj2ELj5EEN8MyMatrixIT_XT0_EXT1_EE8matrix_tERS2_S4_(%struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %A, %struct.MyMatrix.1* nonnull align 4 dereferenceable(40) %B) + // CHECK-NEXT: entry: + // CHECK-NEXT: %A.addr = alloca %struct.MyMatrix.1*, align 8 + // CHECK-NEXT: %B.addr = alloca %struct.MyMatrix.1*, align 8 + // CHECK-NEXT: store %struct.MyMatrix.1* %A, %struct.MyMatrix.1** %A.addr, align 8 + // CHECK-NEXT: store %struct.MyMatrix.1* %B, %struct.MyMatrix.1** %B.addr, align 8 + // CHECK-NEXT: %0 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %A.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %0, i32 0, i32 0 + // CHECK-NEXT: %1 = bitcast [10 x float]* %value to <10 x float>* + // CHECK-NEXT: %2 = load <10 x float>, <10 x float>* %1, align 4 + 
// CHECK-NEXT: %3 = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %B.addr, align 8 + // CHECK-NEXT: %value1 = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %3, i32 0, i32 0 + // CHECK-NEXT: %4 = bitcast [10 x float]* %value1 to <10 x float>* + // CHECK-NEXT: %5 = load <10 x float>, <10 x float>* %4, align 4 + // CHECK-NEXT: %6 = fsub <10 x float> %2, %5 + // CHECK-NEXT: ret <10 x float> %6 + + MyMatrix Mat1; + MyMatrix Mat2; + Mat1.value = subtract(Mat1, Mat2); +} + +struct DoubleWrapper1 { + int x; + operator double() { + return x; + } +}; + +struct DoubleWrapper2 { + int x; + operator double() { + return x; + } +}; + +struct IntWrapper { + char x; + operator int() { + return x; + } +}; + +void test_DoubleWrapper(MyMatrix &m, MyMatrix &m2) { + // CHECK-LABEL: define void @_Z18test_DoubleWrapperR8MyMatrixIdLj10ELj9EERS_IiLj3ELj4EE(%struct.MyMatrix.2* nonnull align 8 dereferenceable(720) %m, %struct.MyMatrix.3* nonnull align 4 dereferenceable(48) %m2) + // CHECK-NEXT: entry: + // CHECK-NEXT: %m.addr = alloca %struct.MyMatrix.2*, align 8 + // CHECK-NEXT: %m2.addr = alloca %struct.MyMatrix.3*, align 8 + // CHECK-NEXT: %w1 = alloca %struct.DoubleWrapper1, align 4 + // CHECK-NEXT: %w2 = alloca %struct.DoubleWrapper2, align 4 + // CHECK-NEXT: %w3 = alloca %struct.IntWrapper, align 1 + // CHECK-NEXT: store %struct.MyMatrix.2* %m, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: store %struct.MyMatrix.3* %m2, %struct.MyMatrix.3** %m2.addr, align 8 + // CHECK-NEXT: %x = getelementptr inbounds %struct.DoubleWrapper1, %struct.DoubleWrapper1* %w1, i32 0, i32 0 + // CHECK-NEXT: store i32 10, i32* %x, align 4 + // CHECK-NEXT: %0 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %0, i32 0, i32 0 + // CHECK-NEXT: %1 = bitcast [90 x double]* %value to <90 x double>* + // CHECK-NEXT: %2 = load <90 x double>, <90 x double>* %1, align 8 + // CHECK-NEXT: 
%call = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1) + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <90 x double> undef, double %call, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <90 x double> %scalar.splat.splatinsert, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %3 = fadd <90 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %4 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value1 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %4, i32 0, i32 0 + // CHECK-NEXT: %5 = bitcast [90 x double]* %value1 to <90 x double>* + // CHECK-NEXT: store <90 x double> %3, <90 x double>* %5, align 8 + // CHECK-NEXT: %call2 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1) + // CHECK-NEXT: %6 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value3 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %6, i32 0, i32 0 + // CHECK-NEXT: %7 = bitcast [90 x double]* %value3 to <90 x double>* + // CHECK-NEXT: %8 = load <90 x double>, <90 x double>* %7, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert4 = insertelement <90 x double> undef, double %call2, i32 0 + // CHECK-NEXT: %scalar.splat.splat5 = shufflevector <90 x double> %scalar.splat.splatinsert4, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %9 = fadd <90 x double> %scalar.splat.splat5, %8 + // CHECK-NEXT: %10 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value6 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %10, i32 0, i32 0 + // CHECK-NEXT: %11 = bitcast [90 x double]* %value6 to <90 x double>* + // CHECK-NEXT: store <90 x double> %9, <90 x double>* %11, align 8 + // CHECK-NEXT: %call7 = call double @_ZN14DoubleWrapper1cvdEv(%struct.DoubleWrapper1* %w1) + // CHECK-NEXT: %12 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value8 = 
getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %12, i32 0, i32 0 + // CHECK-NEXT: %13 = bitcast [90 x double]* %value8 to <90 x double>* + // CHECK-NEXT: %14 = load <90 x double>, <90 x double>* %13, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert9 = insertelement <90 x double> undef, double %call7, i32 0 + // CHECK-NEXT: %scalar.splat.splat10 = shufflevector <90 x double> %scalar.splat.splatinsert9, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %15 = fsub <90 x double> %scalar.splat.splat10, %14 + // CHECK-NEXT: %16 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value11 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %16, i32 0, i32 0 + // CHECK-NEXT: %17 = bitcast [90 x double]* %value11 to <90 x double>* + // CHECK-NEXT: store <90 x double> %15, <90 x double>* %17, align 8 + // CHECK-NEXT: %x12 = getelementptr inbounds %struct.DoubleWrapper2, %struct.DoubleWrapper2* %w2, i32 0, i32 0 + // CHECK-NEXT: store i32 20, i32* %x12, align 4 + + DoubleWrapper1 w1; + w1.x = 10; + m.value = m.value + w1; + m.value = w1 + m.value; + m.value = w1 - m.value; + + // CHECK-NEXT: %18 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value13 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %18, i32 0, i32 0 + // CHECK-NEXT: %19 = bitcast [90 x double]* %value13 to <90 x double>* + // CHECK-NEXT: %20 = load <90 x double>, <90 x double>* %19, align 8 + // CHECK-NEXT: %call14 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2) + // CHECK-NEXT: %scalar.splat.splatinsert15 = insertelement <90 x double> undef, double %call14, i32 0 + // CHECK-NEXT: %scalar.splat.splat16 = shufflevector <90 x double> %scalar.splat.splatinsert15, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %21 = fadd <90 x double> %20, %scalar.splat.splat16 + // CHECK-NEXT: %22 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + 
// CHECK-NEXT: %value17 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %22, i32 0, i32 0 + // CHECK-NEXT: %23 = bitcast [90 x double]* %value17 to <90 x double>* + // CHECK-NEXT: store <90 x double> %21, <90 x double>* %23, align 8 + // CHECK-NEXT: %call18 = call double @_ZN14DoubleWrapper2cvdEv(%struct.DoubleWrapper2* %w2) + // CHECK-NEXT: %24 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value19 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %24, i32 0, i32 0 + // CHECK-NEXT: %25 = bitcast [90 x double]* %value19 to <90 x double>* + // CHECK-NEXT: %26 = load <90 x double>, <90 x double>* %25, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert20 = insertelement <90 x double> undef, double %call18, i32 0 + // CHECK-NEXT: %scalar.splat.splat21 = shufflevector <90 x double> %scalar.splat.splatinsert20, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %27 = fadd <90 x double> %scalar.splat.splat21, %26 + // CHECK-NEXT: %28 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value22 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %28, i32 0, i32 0 + // CHECK-NEXT: %29 = bitcast [90 x double]* %value22 to <90 x double>* + // CHECK-NEXT: store <90 x double> %27, <90 x double>* %29, align 8 + DoubleWrapper2 w2; + w2.x = 20; + m.value = m.value + w2; + m.value = w2 + m.value; + + // CHECK-NEXT: %x23 = getelementptr inbounds %struct.IntWrapper, %struct.IntWrapper* %w3, i32 0, i32 0 + // CHECK-NEXT: store i8 99, i8* %x23, align 1 + // CHECK-NEXT: %30 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8 + // CHECK-NEXT: %value24 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %30, i32 0, i32 0 + // CHECK-NEXT: %31 = bitcast [12 x i32]* %value24 to <12 x i32>* + // CHECK-NEXT: %32 = load <12 x i32>, <12 x i32>* %31, align 4 + // CHECK-NEXT: %call25 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // 
CHECK-NEXT: %scalar.splat.splatinsert26 = insertelement <12 x i32> undef, i32 %call25, i32 0 + // CHECK-NEXT: %scalar.splat.splat27 = shufflevector <12 x i32> %scalar.splat.splatinsert26, <12 x i32> undef, <12 x i32> zeroinitializer + // CHECK-NEXT: %33 = add <12 x i32> %32, %scalar.splat.splat27 + // CHECK-NEXT: %34 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8 + // CHECK-NEXT: %value28 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %34, i32 0, i32 0 + // CHECK-NEXT: %35 = bitcast [12 x i32]* %value28 to <12 x i32>* + // CHECK-NEXT: store <12 x i32> %33, <12 x i32>* %35, align 4 + // CHECK-NEXT: %call29 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // CHECK-NEXT: %36 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8 + // CHECK-NEXT: %value30 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %36, i32 0, i32 0 + // CHECK-NEXT: %37 = bitcast [12 x i32]* %value30 to <12 x i32>* + // CHECK-NEXT: %38 = load <12 x i32>, <12 x i32>* %37, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert31 = insertelement <12 x i32> undef, i32 %call29, i32 0 + // CHECK-NEXT: %scalar.splat.splat32 = shufflevector <12 x i32> %scalar.splat.splatinsert31, <12 x i32> undef, <12 x i32> zeroinitializer + // CHECK-NEXT: %39 = add <12 x i32> %scalar.splat.splat32, %38 + // CHECK-NEXT: %40 = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %m2.addr, align 8 + // CHECK-NEXT: %value33 = getelementptr inbounds %struct.MyMatrix.3, %struct.MyMatrix.3* %40, i32 0, i32 0 + // CHECK-NEXT: %41 = bitcast [12 x i32]* %value33 to <12 x i32>* + // CHECK-NEXT: store <12 x i32> %39, <12 x i32>* %41, align 4 + + IntWrapper w3; + w3.x = 'c'; + m2.value = m2.value + w3; + m2.value = w3 + m2.value; + + // int conversion function in struct and implicit cast to element type double. 
+ // CHECK-NEXT: %42 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value34 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %42, i32 0, i32 0 + // CHECK-NEXT: %43 = bitcast [90 x double]* %value34 to <90 x double>* + // CHECK-NEXT: %44 = load <90 x double>, <90 x double>* %43, align 8 + // CHECK-NEXT: %call35 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // CHECK-NEXT: %conv = sitofp i32 %call35 to double + // CHECK-NEXT: %scalar.splat.splatinsert36 = insertelement <90 x double> undef, double %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat37 = shufflevector <90 x double> %scalar.splat.splatinsert36, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %45 = fsub <90 x double> %44, %scalar.splat.splat37 + // CHECK-NEXT: %46 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value38 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %46, i32 0, i32 0 + // CHECK-NEXT: %47 = bitcast [90 x double]* %value38 to <90 x double>* + // CHECK-NEXT: store <90 x double> %45, <90 x double>* %47, align 8 + // CHECK-NEXT: %call39 = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* %w3) + // CHECK-NEXT: %conv40 = sitofp i32 %call39 to double + // CHECK-NEXT: %48 = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value41 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %48, i32 0, i32 0 + // CHECK-NEXT: %49 = bitcast [90 x double]* %value41 to <90 x double>* + // CHECK-NEXT: %50 = load <90 x double>, <90 x double>* %49, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert42 = insertelement <90 x double> undef, double %conv40, i32 0 + // CHECK-NEXT: %scalar.splat.splat43 = shufflevector <90 x double> %scalar.splat.splatinsert42, <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: %51 = fsub <90 x double> %scalar.splat.splat43, %50 + // CHECK-NEXT: %52 = load %struct.MyMatrix.2*, 
%struct.MyMatrix.2** %m.addr, align 8 + // CHECK-NEXT: %value44 = getelementptr inbounds %struct.MyMatrix.2, %struct.MyMatrix.2* %52, i32 0, i32 0 + // CHECK-NEXT: %53 = bitcast [90 x double]* %value44 to <90 x double>* + // CHECK-NEXT: store <90 x double> %51, <90 x double>* %53, align 8 + // CHECK-NEXT: ret void + // CHECK-NEXT: } + + m.value = m.value - w3; + m.value = w3 - m.value; +} diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c --- a/clang/test/Sema/matrix-type-operators.c +++ b/clang/test/Sema/matrix-type-operators.c @@ -91,3 +91,34 @@ float v12 = a[3]; // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} } + +typedef float sx10x5_t __attribute__((matrix_type(10, 5))); +typedef float sx10x10_t __attribute__((matrix_type(10, 10))); + +void add(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + a = b + c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}} + + a = b + b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = 10 + b; + // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = b + &c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} + // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} +} + +void sub(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + a = b - c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float 
__attribute__((matrix_type(5, 10)))') and 'sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))'))}} + + a = b - b; // expected-error {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = 10 - b; + // expected-error@-1 {{assigning to 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') from incompatible type 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))')}} + + a = b - &c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} + // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} +} diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp --- a/clang/test/SemaCXX/matrix-type-operators.cpp +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -84,3 +84,96 @@ a[2] = f; // expected-error@-1 {{single subscript expressions are not allowed for matrix values}} } + +template +struct MyMatrix { + using matrix_t = EltTy __attribute__((matrix_type(Rows, Columns))); + + matrix_t value; +}; + +template +typename MyMatrix::matrix_t add(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value + B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int 
__attribute__((matrix_type(3, 3)))'))}} + + return A.value + B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_add_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = add(Mat1, Mat1); + // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-note@-2 {{in instantiation of function template specialization 'add' requested here}} + + Mat1.value = add(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'add' requested here}} + + Mat1.value = add(Mat2, Mat3); + // expected-note@-1 {{in instantiation of function template specialization 'add' requested here}} +} + +template +typename MyMatrix::matrix_t subtract(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value - B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}} + + 
return A.value - B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))')}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))')}} +} + +void test_subtract_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = subtract(Mat1, Mat1); + // expected-error@-1 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + // expected-note@-2 {{in instantiation of function template specialization 'subtract' requested here}} + + Mat1.value = subtract(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} + + Mat1.value = subtract(Mat2, Mat3); + // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} +} + +struct UserT {}; + +struct StructWithC { + operator UserT() { + // expected-note@-1 {{candidate function}} + // expected-note@-2 {{candidate function}} + // expected-note@-3 {{candidate function}} + // expected-note@-4 {{candidate function}} + return {}; + } +}; + +void test_DoubleWrapper(MyMatrix &m, StructWithC &c) { + m.value = m.value + c; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}} + + m.value = c + m.value; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary 
expression ('StructWithC' and 'MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}} + + m.value = m.value - c; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))') and 'StructWithC')}} + + m.value = c - m.value; + // expected-error@-1 {{no viable conversion from 'StructWithC' to 'double'}} + // expected-error@-2 {{invalid operands to binary expression ('StructWithC' and 'MyMatrix::matrix_t' (aka 'double __attribute__((matrix_type(10, 9)))'))}} +} diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h --- a/llvm/include/llvm/IR/MatrixBuilder.h +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -127,6 +127,16 @@ /// Add matrixes \p LHS and \p RHS. Support both integer and floating point /// matrixes. Value *CreateAdd(Value *LHS, Value *RHS) { + assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return cast(LHS->getType()) ->getElementType() ->isFloatingPointTy() @@ -137,6 +147,16 @@ /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating /// point matrixes. 
Value *CreateSub(Value *LHS, Value *RHS) { + assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return cast(LHS->getType()) ->getElementType() ->isFloatingPointTy()