diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11218,6 +11218,10 @@
   QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                           SourceLocation Loc,
                                           bool IsCompAssign);
+  /// Type checking for binary '*' when at least one operand is a matrix:
+  /// matrix * matrix is a matrix product, otherwise element-wise scaling.
+  QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                       SourceLocation Loc, bool IsCompAssign);
 
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -741,6 +741,23 @@
       }
     }
 
+    if (Ops.Ty->isConstantMatrixType()) {
+      llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+      // We need to check the types of the operands of the operator to get the
+      // correct matrix dimensions.
+      auto *BO = cast<BinaryOperator>(Ops.E);
+      auto *LHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getLHS()->getType().getCanonicalType());
+      auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
+          BO->getRHS()->getType().getCanonicalType());
+      if (LHSMatTy && RHSMatTy)
+        return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
+                                       LHSMatTy->getNumColumns(),
+                                       RHSMatTy->getNumColumns());
+      // Exactly one operand is a matrix: scale every element by the scalar.
+      return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS);
+    }
+
     if (Ops.Ty->isUnsignedIntegerType() &&
         CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
         !CanElideOverflowCheck(CGF.getContext(), Ops))
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -10058,6 +10058,9 @@
     return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign,
                                /*AllowBothBool*/getLangOpts().AltiVec,
                                /*AllowBoolConversions*/false);
+  if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() ||
+                 RHS.get()->getType()->isConstantMatrixType()))
+    return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign);
 
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic);
@@ -12120,6 +12123,50 @@
   return InvalidOperands(Loc, LHS, RHS);
 }
 
+/// Type-check a binary '*' (or '*=') where at least one operand has a
+/// constant matrix type.
+///
+/// For matrix * matrix the number of columns of \p LHS must equal the number
+/// of rows of \p RHS and both element types must be the same; the result is a
+/// matrix of that element type with LHS-rows x RHS-columns. Every other
+/// combination is forwarded to CheckMatrixElementwiseOperands.
+///
+/// \returns the result type; a null QualType if converting an operand failed,
+/// or whatever InvalidOperands returns for mismatched matrix operands.
+QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                           SourceLocation Loc,
+                                           bool IsCompAssign) {
+  // For a compound assignment the LHS is left as-is; only the RHS is converted.
+  if (!IsCompAssign) {
+    LHS = DefaultFunctionArrayLvalueConversion(LHS.get());
+    if (LHS.isInvalid())
+      return QualType();
+  }
+  RHS = DefaultFunctionArrayLvalueConversion(RHS.get());
+  if (RHS.isInvalid())
+    return QualType();
+
+  auto *LHSMatType = LHS.get()->getType()->getAs<ConstantMatrixType>();
+  auto *RHSMatType = RHS.get()->getType()->getAs<ConstantMatrixType>();
+  assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix");
+
+  if (LHSMatType && RHSMatType) {
+    // Inner dimensions must agree: (R x N) * (N x C).
+    if (LHSMatType->getNumColumns() != RHSMatType->getNumRows())
+      return InvalidOperands(Loc, LHS, RHS);
+
+    if (!Context.hasSameType(LHSMatType->getElementType(),
+                             RHSMatType->getElementType()))
+      return InvalidOperands(Loc, LHS, RHS);
+
+    return Context.getConstantMatrixType(LHSMatType->getElementType(),
+                                         LHSMatType->getNumRows(),
+                                         RHSMatType->getNumColumns());
+  }
+  // Matrix with scalar: same checking as the element-wise operators.
+  return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
diff --git a/clang/lib/Sema/SemaOverload.cpp b/clang/lib/Sema/SemaOverload.cpp
--- a/clang/lib/Sema/SemaOverload.cpp
+++ b/clang/lib/Sema/SemaOverload.cpp
@@ -9167,8 +9167,10 @@
     case OO_Star: // '*' is either unary or binary
       if (Args.size() == 1)
         OpBuilder.addUnaryStarPointerOverloads();
-      else
+      else {
         OpBuilder.addGenericBinaryArithmeticOverloads();
+        OpBuilder.addMatrixBinaryArithmeticOverloads();
+      }
       break;
 
     case OO_Slash:
diff --git a/clang/test/CodeGen/matrix-type-operators.c b/clang/test/CodeGen/matrix-type-operators.c
--- a/clang/test/CodeGen/matrix-type-operators.c
+++ b/clang/test/CodeGen/matrix-type-operators.c
@@ -173,6 +173,134 @@
   b = vulli + b;
 }
 
+// Tests for matrix multiplication.
+ +void multiply_matrix_matrix_double(dx5x5_t b, dx5x5_t c) { + // CHECK-LABEL: @multiply_matrix_matrix_double( + // CHECK: [[B:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[C:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[RES:%.*]] = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> [[B]], <25 x double> [[C]], i32 5, i32 5, i32 5) + // CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [25 x double]* %a to <25 x double>* + // CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* [[A_ADDR]], align 8 + // CHECK-NEXT: ret void + // + + dx5x5_t a; + a = b * c; +} + +typedef int ix3x9_t __attribute__((matrix_type(3, 9))); +typedef int ix9x9_t __attribute__((matrix_type(9, 9))); +// CHECK-LABEL: @multiply_matrix_matrix_int( +// CHECK: [[B:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[C:%.*]] = load <27 x i32>, <27 x i32>* {{.*}}, align 4 +// CHECK-NEXT: [[RES:%.*]] = call <81 x i32> @llvm.matrix.multiply.v81i32.v27i32.v27i32(<27 x i32> [[B]], <27 x i32> [[C]], i32 9, i32 3, i32 9) +// CHECK-NEXT: [[A_ADDR:%.*]] = bitcast [81 x i32]* %a to <81 x i32>* +// CHECK-NEXT: store <81 x i32> [[RES]], <81 x i32>* [[A_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_matrix_matrix_int(ix9x3_t b, ix3x9_t c) { + ix9x9_t a; + a = b * c; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_float( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load float, float* %s.addr, align 4 +// CHECK-NEXT: [[S_EXT:%.*]] = fpext float [[S]] to double +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void 
+// +void multiply_double_matrix_scalar_float(dx5x5_t a, float s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_double_matrix_scalar_double( +// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> undef, double [[S]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> undef, <25 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <25 x double> [[A]], [[VECSPLAT]] +// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8 +// CHECK-NEXT: ret void +// +void multiply_double_matrix_scalar_double(dx5x5_t a, double s) { + a = a * s; +} + +// CHECK-LABEL: @multiply_float_matrix_scalar_double( +// CHECK: [[S:%.*]] = load double, double* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = fptrunc double [[S]] to float +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <6 x float> undef, float [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> undef, <6 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_scalar_double(fx2x3_t b, double s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_short( +// CHECK: [[S:%.*]] = load i16, i16* %s.addr, align 2 +// CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_EXT]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> 
[[VECSPLAT]], [[MAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_short(ix9x3_t b, short s) { + b = s * b; +} + +// CHECK-LABEL: @multiply_int_matrix_scalar_ull( +// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4 +// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8 +// CHECK-NEXT: [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32 +// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> undef, i32 [[S_TRUNC]], i32 0 +// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> undef, <27 x i32> zeroinitializer +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> [[MAT]], [[VECSPLAT]] +// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) { + b = b * s; +} + +// CHECK-LABEL: @multiply_float_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [6 x float], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>* +// CHECK-NEXT: store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = fmul <6 x float> [[MAT]], +// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_float_matrix_constant(fx2x3_t a) { + a = a * 2.5; +} + +// CHECK-LABEL: @multiply_int_matrix_constant( +// CHECK-NEXT: entry: +// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [27 x i32], align 4 +// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [27 x i32]* [[A_ADDR]] to <27 x i32>* +// CHECK-NEXT: store <27 x i32> [[A:%.*]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: [[RES:%.*]] = mul <27 x i32> , [[MAT]] +// CHECK-NEXT: store <27 x i32> 
[[RES]], <27 x i32>* [[MAT_ADDR]], align 4 +// CHECK-NEXT: ret void +// +void multiply_int_matrix_constant(ix9x3_t a) { + a = 5 * a; +} + // Tests for the matrix type operators. typedef double dx5x5_t __attribute__((matrix_type(5, 5))); diff --git a/clang/test/CodeGenCXX/matrix-type-operators.cpp b/clang/test/CodeGenCXX/matrix-type-operators.cpp --- a/clang/test/CodeGenCXX/matrix-type-operators.cpp +++ b/clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -157,6 +157,45 @@ m.value = w3 - m.value; } +template +typename MyMatrix::matrix_t multiply(MyMatrix &A, MyMatrix &B) { + return A.value * B.value; +} + +MyMatrix test_multiply_template(MyMatrix Mat1, + MyMatrix Mat2) { + // CHECK-LABEL: define void @_Z22test_multiply_template8MyMatrixIfLj2ELj5EES_IfLj5ELj2EE( + // CHECK-NEXT: entry: + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE(%struct.MyMatrix* nonnull align 4 dereferenceable(40) %Mat1, %struct.MyMatrix.2* nonnull align 4 dereferenceable(40) %Mat2) + // CHECK-NEXT: %value = getelementptr inbounds %struct.MyMatrix.1, %struct.MyMatrix.1* %agg.result, i32 0, i32 0 + // CHECK-NEXT: [[VALUE_ADDR:%.*]] = bitcast [4 x float]* %value to <4 x float>* + // CHECK-NEXT: store <4 x float> [[RES]], <4 x float>* [[VALUE_ADDR]], align 4 + // CHECK-NEXT: ret void + // + // CHECK-LABEL: define linkonce_odr <4 x float> @_Z8multiplyIfLj2ELj5ELj2EEN8MyMatrixIT_XT0_EXT2_EE8matrix_tERS0_IS1_XT0_EXT1_EERS0_IS1_XT1_EXT2_EE( + // CHECK: [[MAT1:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK: [[MAT2:%.*]] = load <10 x float>, <10 x float>* {{.*}}, align 4 + // CHECK-NEXT: [[RES:%.*]] = call <4 x float> @llvm.matrix.multiply.v4f32.v10f32.v10f32(<10 x float> [[MAT1]], <10 x float> [[MAT2]], i32 2, i32 5, i32 2) + // CHECK-NEXT: ret <4 x float> [[RES]] + + MyMatrix Res; + Res.value = multiply(Mat1, Mat2); + return Res; +} + +void test_IntWrapper_Multiply(MyMatrix &m, 
IntWrapper &w3) { + // CHECK-LABEL: define void @_Z24test_IntWrapper_MultiplyR8MyMatrixIdLj10ELj9EER10IntWrapper( + // CHECK: [[SCALAR:%.*]] = call i32 @_ZN10IntWrappercviEv(%struct.IntWrapper* {{.*}}) + // CHECK-NEXT: [[SCALAR_FP:%.*]] = sitofp i32 %call to double + // CHECK: [[MATRIX:%.*]] = load <90 x double>, <90 x double>* {{.*}}, align 8 + // CHECK-NEXT: [[SCALAR_EMBED:%.*]] = insertelement <90 x double> undef, double [[SCALAR_FP]], i32 0 + // CHECK-NEXT: [[SCALAR_EMBED1:%.*]] = shufflevector <90 x double> [[SCALAR_EMBED]], <90 x double> undef, <90 x i32> zeroinitializer + // CHECK-NEXT: [[RES:%.*]] = fmul <90 x double> [[SCALAR_EMBED1]], [[MATRIX]] + // CHECK: store <90 x double> [[RES]], <90 x double>* {{.*}}, align 8 + // CHECK: ret void + m.value = w3 * m.value; +} + template void insert(MyMatrix &Mat, EltTy e, unsigned i, unsigned j) { Mat.value[i][j] = e; @@ -164,11 +203,11 @@ void test_insert_template1(MyMatrix &Mat, unsigned e, unsigned i, unsigned j) { // CHECK-LABEL: @_Z21test_insert_template1R8MyMatrixIjLj2ELj2EEjjj( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.1*, %struct.MyMatrix.1** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.3*, %struct.MyMatrix.3** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load i32, i32* %e.addr, align 4 // CHECK-NEXT: [[I:%.*]] = load i32, i32* %i.addr, align 4 // CHECK-NEXT: [[J:%.*]] = load i32, i32* %j.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.1* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) + // CHECK-NEXT: call void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT_ADDR]], i32 [[E]], i32 [[I]], i32 [[J]]) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIjLj2ELj2EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -190,9 +229,9 @@ void test_insert_template2(MyMatrix &Mat, float e) { // CHECK-LABEL: 
@_Z21test_insert_template2R8MyMatrixIfLj3ELj8EEf( - // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.2*, %struct.MyMatrix.2** %Mat.addr, align 8 + // CHECK: [[MAT_ADDR:%.*]] = load %struct.MyMatrix.4*, %struct.MyMatrix.4** %Mat.addr, align 8 // CHECK-NEXT: [[E:%.*]] = load float, float* %e.addr, align 4 - // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.2* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) + // CHECK-NEXT: call void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj(%struct.MyMatrix.4* nonnull align 4 dereferenceable(96) [[MAT_ADDR]], float [[E]], i32 2, i32 5) // CHECK-NEXT: ret void // // CHECK-LABEL: define linkonce_odr void @_Z6insertIfLj3ELj8EEvR8MyMatrixIT_XT0_EXT1_EES1_jj( @@ -220,7 +259,7 @@ int test_extract_template(MyMatrix Mat1) { // CHECK-LABEL: @_Z21test_extract_template8MyMatrixIiLj2ELj2EE( // CHECK-NEXT: entry: - // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.3* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) + // CHECK-NEXT: [[CALL:%.*]] = call i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE(%struct.MyMatrix.5* nonnull align 4 dereferenceable(16) [[MAT1:%.*]]) // CHECK-NEXT: ret i32 [[CALL]] // // CHECK-LABEL: define linkonce_odr i32 @_Z7extractIiLj2ELj2EET_R8MyMatrixIS0_XT0_EXT1_EE( @@ -301,7 +340,7 @@ constexpr identmatrix_t identmatrix; void test_constexpr1(matrix_type &m) { - // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef([16 x float]* nonnull align 4 dereferenceable(64) %m) #3 { + // CHECK-LABEL: define void @_Z15test_constexpr1RU11matrix_typeLm4ELm4Ef( // CHECK: [[MAT:%.*]] = load <16 x float>, <16 x float>* {{.*}}, align 4 // CHECK-NEXT: [[IM:%.*]] = call <16 x float> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IfLj4EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK-NEXT: [[ADD:%.*]] = fadd <16 x float> [[MAT]], [[IM]] @@ -327,7 +366,7 @@ } void 
test_constexpr2(matrix_type &m) { - // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei([25 x i32]* nonnull align 4 dereferenceable(100) %m) #4 { + // CHECK-LABEL: define void @_Z15test_constexpr2RU11matrix_typeLm5ELm5Ei( // CHECK: [[IM:%.*]] = call <25 x i32> @_ZNK13identmatrix_tcvU11matrix_typeXT0_EXT0_ET_IiLj5EEEv(%struct.identmatrix_t* @_ZL11identmatrix) // CHECK: [[MAT:%.*]] = load <25 x i32>, <25 x i32>* {{.*}}, align 4 // CHECK-NEXT: [[SUB:%.*]] = sub <25 x i32> [[IM]], [[MAT]] diff --git a/clang/test/Sema/matrix-type-operators.c b/clang/test/Sema/matrix-type-operators.c --- a/clang/test/Sema/matrix-type-operators.c +++ b/clang/test/Sema/matrix-type-operators.c @@ -32,6 +32,44 @@ // expected-error@-2 {{casting 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*') to incompatible type 'float'}} } +typedef int ix10x5_t __attribute__((matrix_type(10, 5))); +typedef int ix10x10_t __attribute__((matrix_type(10, 10))); + +void matrix_matrix_multiply(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) { + // Check dimension mismatches. + a = a * b; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))'))}} + b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + // Check element type mismatches. 
+ a = b * c; + // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}} + d = a * a; + // expected-error@-1 {{assigning to 'ix10x10_t' (aka 'int __attribute__((matrix_type(10, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + p = a * a; + // expected-error@-1 {{assigning to 'char *' from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) { + // Shape of multiplication result does not match the type of b. + b = a * sf; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + b = sf * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} + + a = a * p; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}} + a = p * a; + // expected-error@-1 {{casting 'char *' to incompatible type 'float'}} + // expected-error@-2 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}} + + sf = a * sf; + // expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} + sx5x10_t get_matrix(); void insert(sx5x10_t a, float f) { diff --git a/clang/test/SemaCXX/matrix-type-operators.cpp b/clang/test/SemaCXX/matrix-type-operators.cpp --- a/clang/test/SemaCXX/matrix-type-operators.cpp +++ b/clang/test/SemaCXX/matrix-type-operators.cpp @@ -65,6 +65,45 @@ // expected-note@-1 {{in 
instantiation of function template specialization 'subtract' requested here}} } +template +typename MyMatrix::matrix_t multiply(MyMatrix &A, MyMatrix &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + + MyMatrix m; + B.value = m.value * A.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))'))}} + // expected-error@-3 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'int __attribute__((matrix_type(5, 6)))') and 'MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'float __attribute__((matrix_type(2, 2)))') and 'MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float 
*Ptr2) {
+  MyMatrix Mat1;
+  MyMatrix Mat2;
+  MyMatrix Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = multiply(Mat1, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}}
+  // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+
+  MyMatrix Mat4;
+  Mat1.value = multiply(Mat4, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}}
+
+  Mat1.value = multiply(Mat3, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply' requested here}}
+
+  Mat4.value = Mat4.value * Mat1;
+  // expected-error@-1 {{no viable conversion from 'MyMatrix' to 'unsigned int'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 2)))') and 'MyMatrix')}}
+}
+
 struct UserT {};
 
 struct StructWithC {
diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -33,6 +33,24 @@
   IRBuilderTy &B;
   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
 
+  /// If exactly one of \p LHS and \p RHS is a vector (i.e. an embedded
+  /// matrix), splat the other, scalar operand to a vector with the same number
+  /// of elements. Returns the (possibly updated) operands as {LHS, RHS}.
+  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS,
+                                                         Value *RHS) {
+    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) &&
+           "One of the operands must be a matrix (embedded in a vector)");
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+    return {LHS, RHS};
+  }
+
 public:
   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
 
@@ -164,15 +182,13 @@
                : B.CreateSub(LHS, RHS);
   }
 
-  /// Multiply matrix \p LHS with scalar \p RHS.
+  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
+  /// RHS.
   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
-    Value *ScalarVector =
-        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
-                            RHS, "scalar.splat");
-    if (RHS->getType()->isFloatingPointTy())
-      return B.CreateFMul(LHS, ScalarVector);
-
-    return B.CreateMul(LHS, ScalarVector);
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
+    if (LHS->getType()->getScalarType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, RHS);
+    return B.CreateMul(LHS, RHS);
   }
 
   /// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.