diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11143,6 +11143,8 @@ QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, bool IsCompAssign); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, bool IsCompAssign); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -738,6 +738,22 @@ } } + if (Ops.Ty->isMatrixType()) { + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. + auto *BO = cast<BinaryOperator>(Ops.E); + auto *LHSMatTy = + dyn_cast<MatrixType>(BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = + dyn_cast<MatrixType>(BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -9639,6 +9639,9 @@ return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isMatrixType() || + RHS.get()->getType()->isMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? 
ACK_CompAssign : ACK_Arithmetic); @@ -11720,6 +11723,34 @@ return InvalidOperands(Loc, LHS, RHS); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc, + bool IsCompAssign) { + + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent. + QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + const MatrixType *LHSMatType = LHSType->getAs<MatrixType>(); + const MatrixType *RHSMatType = RHSType->getAs<MatrixType>(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && RHSMatType) { + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); + } + return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c --- a/clang/test/CodeGen/matrix-type-operators.c +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -463,3 +463,210 @@ // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4 // CHECK-NEXT: ret void } + +void multiply_matrix_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a = *b * *c; + + // CHECK-LABEL: @multiply_matrix_matrix( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x 
double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]] +// +void multiply_matrix_scalar_floats(dx5x5_t *a, fx2x3_t *b, double vf, float vd) { + *a = *a * vf; + *a = *a * vd; + + // CHECK-LABEL: define void @multiply_matrix_scalar_floats([25 x double]* %a, [6 x float]* %b, double %vf, float %vd) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [6 x float]*, align 8 + // CHECK-NEXT: %vf.addr = alloca double, align 8 + // CHECK-NEXT: %vd.addr = alloca float, align 4 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [6 x float]* %b, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: store double %vf, double* %vf.addr, align 8 + // CHECK-NEXT: store float %vd, float* %vd.addr, align 4 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = 
load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %vf.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: %9 = load <25 x double>, <25 x double>* %8, align 8 + // CHECK-NEXT: %10 = load float, float* %vd.addr, align 4 + // CHECK-NEXT: %conv = fpext float %10 to double + // CHECK-NEXT: %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %11 = fmul <25 x double> %9, %scalar.splat.splat2 + // CHECK-NEXT: %12 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %13 = bitcast [25 x double]* %12 to <25 x double>* + // CHECK-NEXT: store <25 x double> %11, <25 x double>* %13, align 8 + + *b = vf * *b; + *b = vd * *b; + + // CHECK-NEXT: %14 = load double, double* %vf.addr, align 8 + // CHECK-NEXT: %conv3 = fptrunc double %14 to float + // CHECK-NEXT: %15 = load [6 x float]*, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: %16 = bitcast [6 x float]* %15 to <6 x float>* + // CHECK-NEXT: %17 = load <6 x float>, <6 x float>* %16, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert4 = insertelement <6 x float> undef, float %conv3, i32 0 + // CHECK-NEXT: %scalar.splat.splat5 = shufflevector <6 x float> 
%scalar.splat.splatinsert4, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %18 = fmul <6 x float> %scalar.splat.splat5, %17 + // CHECK-NEXT: %19 = load [6 x float]*, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: %20 = bitcast [6 x float]* %19 to <6 x float>* + // CHECK-NEXT: store <6 x float> %18, <6 x float>* %20, align 4 + // CHECK-NEXT: %21 = load float, float* %vd.addr, align 4 + // CHECK-NEXT: %22 = load [6 x float]*, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: %23 = bitcast [6 x float]* %22 to <6 x float>* + // CHECK-NEXT: %24 = load <6 x float>, <6 x float>* %23, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %21, i32 0 + // CHECK-NEXT: %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer + // CHECK-NEXT: %25 = fmul <6 x float> %scalar.splat.splat7, %24 + // CHECK-NEXT: %26 = load [6 x float]*, [6 x float]** %b.addr, align 8 + // CHECK-NEXT: %27 = bitcast [6 x float]* %26 to <6 x float>* + // CHECK-NEXT: store <6 x float> %25, <6 x float>* %27, align 4 + // CHECK-NEXT: ret void +} + +void multiply_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) { + a = a * vs; + a = a * vli; + a = a * vulli; + + // CHECK-LABEL: define void @multiply_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %b.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %vs.addr = alloca i16, align 2 + // CHECK-NEXT: %vli.addr = alloca i64, align 8 + // CHECK-NEXT: %vulli.addr = alloca i64, align 8 + // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4 + // CHECK-NEXT: %1 = bitcast [27 x i32]* %b.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %b, <27 x i32>* %1, align 4 + // CHECK-NEXT: store i16 %vs, 
i16* %vs.addr, align 2 + // CHECK-NEXT: store i64 %vli, i64* %vli.addr, align 8 + // CHECK-NEXT: store i64 %vulli, i64* %vulli.addr, align 8 + // CHECK-NEXT: %2 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %3 = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: %conv = sext i16 %3 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %4 = mul <27 x i32> %2, %scalar.splat.splat + // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4 + // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %6 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv1 = trunc i64 %6 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0 + // CHECK-NEXT: %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %7 = mul <27 x i32> %5, %scalar.splat.splat3 + // CHECK-NEXT: store <27 x i32> %7, <27 x i32>* %0, align 4 + // CHECK-NEXT: %8 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %9 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv4 = trunc i64 %9 to i32 + // CHECK-NEXT: %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0 + // CHECK-NEXT: %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %10 = mul <27 x i32> %8, %scalar.splat.splat6 + // CHECK-NEXT: store <27 x i32> %10, <27 x i32>* %0, align 4 + + b = vs * b; + b = vli * b; + b = vulli * b; + // CHECK-NEXT: %11 = load i16, i16* %vs.addr, align 2 + // CHECK-NEXT: %conv7 = sext i16 %11 to i32 + // CHECK-NEXT: %12 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, 
i32 %conv7, i32 0 + // CHECK-NEXT: %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %13 = mul <27 x i32> %scalar.splat.splat9, %12 + // CHECK-NEXT: store <27 x i32> %13, <27 x i32>* %1, align 4 + // CHECK-NEXT: %14 = load i64, i64* %vli.addr, align 8 + // CHECK-NEXT: %conv10 = trunc i64 %14 to i32 + // CHECK-NEXT: %15 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0 + // CHECK-NEXT: %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %16 = mul <27 x i32> %scalar.splat.splat12, %15 + // CHECK-NEXT: store <27 x i32> %16, <27 x i32>* %1, align 4 + // CHECK-NEXT: %17 = load i64, i64* %vulli.addr, align 8 + // CHECK-NEXT: %conv13 = trunc i64 %17 to i32 + // CHECK-NEXT: %18 = load <27 x i32>, <27 x i32>* %1, align 4 + // CHECK-NEXT: %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0 + // CHECK-NEXT: %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer + // CHECK-NEXT: %19 = mul <27 x i32> %scalar.splat.splat15, %18 + // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4 + // CHECK-NEXT: ret void +} + +void multiply_matrix_scalar_constants(ix9x3_t a, fx2x3_t b, dx5x5_t c) { + a = a * 10; + a = a * 20ull; + a = a * 30ll; + + // CHECK-LABEL: define void @multiply_matrix_scalar_constants(<27 x i32> %a, <6 x float> %b, <25 x double> %c) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4 + // CHECK-NEXT: %b.addr = alloca [6 x float], align 4 + // CHECK-NEXT: %c.addr = alloca [25 x double], align 8 + // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>* + // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4 + // CHECK-NEXT: %1 = bitcast [6 x float]* %b.addr to <6 x 
float>* + // CHECK-NEXT: store <6 x float> %b, <6 x float>* %1, align 4 + // CHECK-NEXT: %2 = bitcast [25 x double]* %c.addr to <25 x double>* + // CHECK-NEXT: store <25 x double> %c, <25 x double>* %2, align 8 + // CHECK-NEXT: %3 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %4 = mul <27 x i32> %3, + // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4 + // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %6 = mul <27 x i32> %5, + // CHECK-NEXT: store <27 x i32> %6, <27 x i32>* %0, align 4 + // CHECK-NEXT: %7 = load <27 x i32>, <27 x i32>* %0, align 4 + // CHECK-NEXT: %8 = mul <27 x i32> %7, + // CHECK-NEXT: store <27 x i32> %8, <27 x i32>* %0, align 4 + + b = 10.0 * b; + b = ((float)20.0) * b; + + // CHECK-NEXT: %9 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %10 = fmul <6 x float> , %9 + // CHECK-NEXT: store <6 x float> %10, <6 x float>* %1, align 4 + // CHECK-NEXT: %11 = load <6 x float>, <6 x float>* %1, align 4 + // CHECK-NEXT: %12 = fmul <6 x float> , %11 + // CHECK-NEXT: store <6 x float> %12, <6 x float>* %1, align 4 + + c = 10.0 * c; + c = ((float)20.0) * c; + + // CHECK-NEXT: %13 = load <25 x double>, <25 x double>* %2, align 8 + // CHECK-NEXT: %14 = fmul <25 x double> , %13 + // CHECK-NEXT: store <25 x double> %14, <25 x double>* %2, align 8 + // CHECK-NEXT: %15 = load <25 x double>, <25 x double>* %2, align 8 + // CHECK-NEXT: %16 = fmul <25 x double> , %15 + // CHECK-NEXT: store <25 x double> %16, <25 x double>* %2, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn } diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp --- a/clang/test/CodeGenCXX/matrix-type-operators.cpp +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -285,3 +285,53 @@ MyMatrix Mat2; Mat1.value = subtract(Mat1, Mat2); } + +void multiply1_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a 
= *b * *c; + + // CHECK-LABEL: @_Z16multiply1_matrixPDm5_5_dS0_S0_([25 x double]* %a, [25 x double]* %b, [25 x double]* %c) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) + +void multiply1_scalar(dx5x5_t *a, dx5x5_t *b, double c) { + *a = *b * c; + + // CHECK-LABEL:@_Z16multiply1_scalarPDm5_5_dS0_d([25 x double]* %a, [25 x double]* %b, double %c + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca double, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** 
%b.addr, align 8 + // CHECK-NEXT: store double %c, double* %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %c.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: ret void +} diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c --- a/clang/test/Sema/matrix-type-operators.c +++ b/clang/test/Sema/matrix-type-operators.c @@ -94,3 +94,22 @@ a = b - &c; // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} } + +void mat_mat_multiply(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + // Invalid dimensions for operands. + a = c * c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))') and 'sx10x5_t')}} + + // Shape of multiplication result does not match the type of b. 
+ b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + b = a * &c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float scalar) { + // Shape of multiplication result does not match the type of b. + b = a * scalar; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp --- a/clang/test/SemaCXX/matrix-type-operators.cpp +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -129,3 +129,26 @@ Mat1.value = subtract(Mat2, Mat3); // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} } + +template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2> +typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix<unsigned, 2, 2> Mat1; + MyMatrix<unsigned, 3, 3> Mat2; + MyMatrix<float, 2, 2> Mat3; + Mat1.value = 
*((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + + Mat1.value = multiply<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}} +} diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h --- a/llvm/include/llvm/IR/MatrixBuilder.h +++ b/llvm/include/llvm/IR/MatrixBuilder.h @@ -33,6 +33,21 @@ IRBuilderTy &B; Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); } + std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, + Value *RHS) { + assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && + "One of the operands must be a matrix (embedded in a vector)"); + if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) + RHS = B.CreateVectorSplat( + cast<VectorType>(LHS->getType())->getNumElements(), RHS, + "scalar.splat"); + else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) + LHS = B.CreateVectorSplat( + cast<VectorType>(RHS->getType())->getNumElements(), LHS, + "scalar.splat"); + return {LHS, RHS}; + } + public: MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {} @@ -127,16 +142,7 @@ /// Add matrixes \p LHS and \p RHS. Support both integer and floating point /// matrixes. 
Value *CreateAdd(Value *LHS, Value *RHS) { - assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); - if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) - RHS = B.CreateVectorSplat( - cast<VectorType>(LHS->getType())->getNumElements(), RHS, - "scalar.splat"); - else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) - LHS = B.CreateVectorSplat( - cast<VectorType>(RHS->getType())->getNumElements(), LHS, - "scalar.splat"); - + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); return cast<VectorType>(LHS->getType()) ->getElementType() ->isFloatingPointTy() @@ -147,16 +153,7 @@ /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating /// point matrixes. Value *CreateSub(Value *LHS, Value *RHS) { - assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()); - if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy()) - RHS = B.CreateVectorSplat( - cast<VectorType>(LHS->getType())->getNumElements(), RHS, - "scalar.splat"); - else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy()) - LHS = B.CreateVectorSplat( - cast<VectorType>(RHS->getType())->getNumElements(), LHS, - "scalar.splat"); - + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); return cast<VectorType>(LHS->getType()) ->getElementType() ->isFloatingPointTy() @@ -164,15 +161,13 @@ : B.CreateSub(LHS, RHS); } - /// Multiply matrix \p LHS with scalar \p RHS. + /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p + /// RHS. 
Value *CreateScalarMultiply(Value *LHS, Value *RHS) { - Value *ScalarVector = - B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(), - RHS, "scalar.splat"); - if (RHS->getType()->isFloatingPointTy()) - return B.CreateFMul(LHS, ScalarVector); - - return B.CreateMul(LHS, ScalarVector); + std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS); + if (LHS->getType()->getScalarType()->isFloatingPointTy()) + return B.CreateFMul(LHS, RHS); + return B.CreateMul(LHS, RHS); } /// Extracts the element at (\p Row, \p Column) from \p Matrix.