diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -11077,6 +11077,8 @@ /// Type checking for matrix binary operators. QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType); diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp --- a/clang/lib/CodeGen/CGExprScalar.cpp +++ b/clang/lib/CodeGen/CGExprScalar.cpp @@ -738,6 +738,22 @@ } } + if (Ops.Ty->isMatrixType()) { + llvm::MatrixBuilder MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. + auto *BO = cast(Ops.E); + auto *LHSMatTy = + dyn_cast(BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = + dyn_cast(BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -9544,6 +9544,9 @@ return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isMatrixType() || + RHS.get()->getType()->isMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? 
ACK_CompAssign : ACK_Arithmetic); @@ -11578,6 +11581,41 @@ // assert(LHSMatType || RHSMatType); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc) { + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent. + QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + const MatrixType *LHSMatType = LHSType->getAs(); + const MatrixType *RHSMatType = RHSType->getAs(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && !RHSMatType) { + if (!Context.hasSameType(LHSMatType->getElementType(), RHSType)) + return InvalidOperands(Loc, LHS, RHS); + return LHSType; + } + + if (!LHSMatType && RHSMatType) { + if (!Context.hasSameType(LHSType, RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + return RHSType; + } + + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c --- a/clang/test/CodeGen/matrix-type-operators.c +++ b/clang/test/CodeGen/matrix-type-operators.c @@ -225,3 +225,55 @@ // CHECK-NEXT: store <27 x i32> %11, <27 x i32>* %3, align 4 // CHECK-NEXT: ret void } + +void multiply_matrix_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a = *b * *c; + + // CHECK-LABEL: @multiply_matrix_matrix( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = 
alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]] +// +void multiply_matrix_scalar(dx5x5_t *a, dx5x5_t *b, double c) { + *a = *b * c; + + // CHECK-LABEL: @multiply_matrix_scalar( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca double, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store double %c, double* %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x 
double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %c.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn } diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp --- a/clang/test/CodeGenCXX/matrix-type-operators.cpp +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -285,3 +285,53 @@ MyMatrix Mat2; Mat1.value = subtract(Mat1, Mat2); } + +void multiply1_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a = *b * *c; + + // CHECK-LABEL: @_Z16multiply1_matrixPDm5_5_dS0_S0_([25 x double]* %a, [25 x double]* %b, [25 x double]* %c) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 
x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) + +void multiply1_scalar(dx5x5_t *a, dx5x5_t *b, double c) { + *a = *b * c; + + // CHECK-LABEL:@_Z16multiply1_scalarPDm5_5_dS0_d([25 x double]* %a, [25 x double]* %b, double %c + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca double, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store double %c, double* %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %c.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: ret void +} diff --git 
a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c --- a/clang/test/Sema/matrix-type-operators.c +++ b/clang/test/Sema/matrix-type-operators.c @@ -96,3 +96,22 @@ a = b - &c; // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} } + +void mat_mat_multiply(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + // Invalid dimensions for operands. + a = c * c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))') and 'sx10x5_t')}} + + // Shape of multiplication result does not match the type of b. + b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + b = a * &c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float scalar) { + // Shape of multiplication result does not match the type of b. 
+ b = a * scalar; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp --- a/clang/test/SemaCXX/matrix-type-operators.cpp +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -122,3 +122,26 @@ Mat1.value = subtract(Mat2, Mat3); // expected-note@-1 {{in instantiation of function template specialization 'subtract' requested here}} } + +template +typename MyMatrix::matrix_t multiply(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2))) '}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + MyMatrix Mat2; + MyMatrix Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + + Mat1.value = multiply(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}} +} diff --git 
a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -144,15 +144,34 @@
                : B.CreateSub(LHS, RHS);
   }
 
-  /// Multiply matrix \p LHS with scalar \p RHS.
+  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
+  /// RHS.
+  ///
+  /// The scalar operand is splatted to the element count of the matrix operand
+  /// and the operand order is preserved. A floating-point multiply is emitted
+  /// when the scalar is floating point, an integer multiply otherwise.
   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
-    Value *ScalarVector =
-        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
-                            RHS, "scalar.splat");
-    if (RHS->getType()->isFloatingPointTy())
-      return B.CreateFMul(LHS, ScalarVector);
-
-    return B.CreateMul(LHS, ScalarVector);
+    const bool LHSIsMatrix = LHS->getType()->isVectorTy();
+    // Parenthesize the disjunction: `&&` binds tighter than `||`, so without
+    // the parentheses the message string only attaches to the RHS check.
+    assert((LHSIsMatrix || RHS->getType()->isVectorTy()) &&
+           "One of the operands must be a matrix (embedded in a vector)");
+
+    // Take the element count from the matrix operand; the scalar operand's
+    // type is not a vector type and must not be cast to one.
+    Value *MatrixOp = LHSIsMatrix ? LHS : RHS;
+    Value *ScalarOp = LHSIsMatrix ? RHS : LHS;
+    Value *ScalarVector = B.CreateVectorSplat(
+        cast<VectorType>(MatrixOp->getType())->getNumElements(), ScalarOp,
+        "scalar.splat");
+    Value *MulLHS = LHSIsMatrix ? LHS : ScalarVector;
+    Value *MulRHS = LHSIsMatrix ? ScalarVector : RHS;
+
+    // Select on the scalar operand's type so that floating-point matrices get
+    // an fmul regardless of which side the scalar is on.
+    if (ScalarOp->getType()->isFloatingPointTy())
+      return B.CreateFMul(MulLHS, MulRHS);
+    return B.CreateMul(MulLHS, MulRHS);
   }
 
   /// Extracts the element at (\p Row, \p Column) from \p Matrix.