diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -577,6 +577,7 @@
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_major_store, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10762,6 +10762,9 @@
 def err_builtin_matrix_pointer_arg: Error<
   "%select{first|second}0 argument must be a pointer to a valid matrix element type">;
 
+def err_builtin_matrix_pointer_arg_mismatch: Error<
+  "the pointee of the second argument must match the element type of the first argument (%0 != %1)">;
+
 def err_builtin_matrix_stride_too_small: Error<
   "stride must be greater or equal to the number of rows">;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -12088,6 +12088,8 @@
                                             ExprResult CallResult);
   ExprResult SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
                                                       ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixColumnMajorStoreOverload(CallExpr *TheCall,
+                                                       ExprResult CallResult);
 
 public:
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2409,6 +2409,21 @@
     }
     return RValue::get(Result);
   }
+  case Builtin::BI__builtin_matrix_column_major_store: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Matrix = EmitScalarExpr(E->getArg(0));
+    auto *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+    Address Dst = EmitPointerWithAlignment(E->getArg(1));
+    EmitNonNullArgCheck(RValue::get(Dst.getPointer()), E->getArg(1)->getType(),
+                        E->getArg(1)->getExprLoc(), FD, 1);
+    Value *Stride = EmitScalarExpr(E->getArg(2));
+
+    // TODO: Pass Dst alignment to intrinsic.
+    Value *Result = MB.CreateMatrixColumnwiseStore(
+        Matrix, Dst.getPointer(), Stride, MatrixTy->getNumRows(),
+        MatrixTy->getNumColumns());
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_matrix_transpose: {
     const ConstantMatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
     Value *MatValue = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1914,6 +1914,7 @@
 
   case Builtin::BI__builtin_matrix_transpose:
   case Builtin::BI__builtin_matrix_column_major_load:
+  case Builtin::BI__builtin_matrix_column_major_store:
     if (!getLangOpts().MatrixTypes) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1924,6 +1925,8 @@
       return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_column_major_load:
       return SemaBuiltinMatrixColumnMajorLoadOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_column_major_store:
+      return SemaBuiltinMatrixColumnMajorStoreOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -14896,3 +14899,60 @@
   TheCall->setType(ReturnType);
   return CallResult;
 }
+
+ExprResult
+Sema::SemaBuiltinMatrixColumnMajorStoreOverload(CallExpr *TheCall,
+                                                ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 3))
+    return ExprError();
+
+  Expr *MatrixExpr = TheCall->getArg(0);
+  Expr *DataExpr = TheCall->getArg(1);
+  Expr *StrideExpr = TheCall->getArg(2);
+
+  bool ArgError = false;
+  if (!MatrixExpr->getType()->isConstantMatrixType()) {
+    Diag(MatrixExpr->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  const ConstantMatrixType *MType = dyn_cast<ConstantMatrixType>(
+      MatrixExpr->getType().getCanonicalType());
+
+  QualType PointerTy = DataExpr->getType();
+  if (!PointerTy->isPointerType()) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 1;
+    ArgError = true;
+  } else if (MType) {
+    auto PointeeTy = PointerTy->getPointeeType().getCanonicalType();
+    auto ElementTy = MType->getElementType().getCanonicalType();
+    if (PointeeTy != ElementTy) {
+      Diag(DataExpr->getBeginLoc(),
+           diag::err_builtin_matrix_pointer_arg_mismatch)
+          << PointeeTy << ElementTy;
+      ArgError = true;
+    }
+  }
+
+  // Check stride.
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 2 << 1;
+    ArgError = true;
+  } else {
+    llvm::APSInt Value(64);
+    if (StrideExpr->isIntegerConstantExpr(Value, Context)) {
+      uint64_t Stride = Value.getZExtValue();
+      if (MType && Stride < MType->getNumRows()) {
+        Diag(StrideExpr->getBeginLoc(),
+             diag::err_builtin_matrix_stride_too_small);
+        ArgError = true;
+      }
+    }
+  }
+
+  if (ArgError)
+    return ExprError();
+
+  return CallResult;
+}
diff --git a/clang/test/CodeGen/matrix-type-builtins.c b/clang/test/CodeGen/matrix-type-builtins.c
--- a/clang/test/CodeGen/matrix-type-builtins.c
+++ b/clang/test/CodeGen/matrix-type-builtins.c
@@ -111,3 +111,54 @@
   // CHECK-NEXT:    ret void
   ix4x20_t m_c = __builtin_matrix_column_major_load(c, 4, 20, 4);
 }
+
+void column_major_store1(dx5x5_t *a_m, double *a, fx2x3_t *b_m, float *b, ix4x20_t *c_m, int *c, unsigned Stride, fx3x2_t *d_m) {
+  // CHECK-LABEL: define void @column_major_store1([25 x double]* %a_m, double* %a, [6 x float]* %b_m, float* %b, [80 x i32]* %c_m, i32* %c, i32 %Stride, [6 x float]* %d_m) #0 {
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a_m.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %a.addr = alloca double*, align 8
+  // CHECK-NEXT:    %b_m.addr = alloca [6 x float]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca float*, align 8
+  // CHECK-NEXT:    %c_m.addr = alloca [80 x i32]*, align 8
+  // CHECK-NEXT:    %c.addr = alloca i32*, align 8
+  // CHECK-NEXT:    %Stride.addr = alloca i32, align 4
+  // CHECK-NEXT:    %d_m.addr = alloca [6 x float]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a_m, [25 x double]** %a_m.addr, align 8
+  // CHECK-NEXT:    store double* %a, double** %a.addr, align 8
+  // CHECK-NEXT:    store [6 x float]* %b_m, [6 x float]** %b_m.addr, align 8
+  // CHECK-NEXT:    store float* %b, float** %b.addr, align 8
+  // CHECK-NEXT:    store [80 x i32]* %c_m, [80 x i32]** %c_m.addr, align 8
+  // CHECK-NEXT:    store i32* %c, i32** %c.addr, align 8
+  // CHECK-NEXT:    store i32 %Stride, i32* %Stride.addr, align 4
+  // CHECK-NEXT:    store [6 x float]* %d_m, [6 x float]** %d_m.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %a_m.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = load double*, double** %a.addr, align 8
+  // CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v25f64.p0f64(<25 x double> %2, double* %3, i32 5, i32 5, i32 5)
+  __builtin_matrix_column_major_store(*a_m, a, 5);
+
+  // CHECK-NEXT:    %4 = load [6 x float]*, [6 x float]** %b_m.addr, align 8
+  // CHECK-NEXT:    %5 = bitcast [6 x float]* %4 to <6 x float>*
+  // CHECK-NEXT:    %6 = load <6 x float>, <6 x float>* %5, align 4
+  // CHECK-NEXT:    %7 = load float*, float** %b.addr, align 8
+  // CHECK-NEXT:    %8 = load i32, i32* %Stride.addr, align 4
+  // CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v6f32.p0f32(<6 x float> %6, float* %7, i32 %8, i32 2, i32 3)
+  __builtin_matrix_column_major_store(*b_m, b, Stride);
+
+  // CHECK-NEXT:    %9 = load [80 x i32]*, [80 x i32]** %c_m.addr, align 8
+  // CHECK-NEXT:    %10 = bitcast [80 x i32]* %9 to <80 x i32>*
+  // CHECK-NEXT:    %11 = load <80 x i32>, <80 x i32>* %10, align 4
+  // CHECK-NEXT:    %12 = load i32*, i32** %c.addr, align 8
+  // CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v80i32.p0i32(<80 x i32> %11, i32* %12, i32 14, i32 4, i32 20)
+  __builtin_matrix_column_major_store(*c_m, c, 14);
+
+  // CHECK-NEXT:    %13 = load [6 x float]*, [6 x float]** %d_m.addr, align 8
+  // CHECK-NEXT:    %14 = bitcast [6 x float]* %13 to <6 x float>*
+  // CHECK-NEXT:    %15 = load <6 x float>, <6 x float>* %14, align 4
+  // CHECK-NEXT:    %16 = load float*, float** %b.addr, align 8
+  // CHECK-NEXT:    %17 = load i32, i32* %Stride.addr, align 4
+  // CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v6f32.p0f32(<6 x float> %15, float* %16, i32 %17, i32 3, i32 2)
+  // CHECK-NEXT:    ret void
+  __builtin_matrix_column_major_store(*d_m, b, Stride);
+}
diff --git a/clang/test/CodeGenCXX/matrix-type-builtins.cpp b/clang/test/CodeGenCXX/matrix-type-builtins.cpp
--- a/clang/test/CodeGenCXX/matrix-type-builtins.cpp
+++ b/clang/test/CodeGenCXX/matrix-type-builtins.cpp
@@ -211,3 +211,96 @@
 // CHECK-NEXT:    %0 = load double*, double** %Ptr.addr, align 8
 // CHECK-NEXT:    %matrix = call <63 x double> @llvm.matrix.columnwise.load.v63f64.p0f64(double* %0, i32 7, i32 7, i32 9)
 // CHECK-NEXT:    ret <63 x double> %matrix
+//
+
+void column_major_store(matrix_t<float, 5, 5> &M1, float *Ptr1, matrix_t<unsigned, 4, 4> &M2, unsigned *Ptr2, matrix_t<unsigned, 3, 4> &M3) {
+  __builtin_matrix_column_major_store(M1, Ptr1, 5);
+  __builtin_matrix_column_major_store(M2, Ptr2, 4);
+  __builtin_matrix_column_major_store(M3, Ptr2, 4);
+}
+
+template <typename T, typename PtrTy>
+void column_major_store(T &M, PtrTy Ptr, unsigned Stride) {
+  __builtin_matrix_column_major_store(M, Ptr, Stride);
+}
+
+void test_column_major_store_template(matrix_t<int, 10, 4> &M1, int *Ptr1, matrix_t<int, 7, 9> &M2, double *Ptr2, matrix_t<double, 7, 9> &M3, unsigned Stride) {
+  // CHECK-LABEL: define void @_Z32test_column_major_store_templateRU11matrix_typeLm10ELm4EiPiRU11matrix_typeLm7ELm9EiPdRU11matrix_typeLm7ELm9Edj([40 x i32]* dereferenceable(160) %M1, i32* %Ptr1, [63 x i32]* dereferenceable(252) %M2, double* %Ptr2, [63 x double]* dereferenceable(504) %M3, i32 %Stride)
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %M1.addr = alloca [40 x i32]*, align 8
+  // CHECK-NEXT:    %Ptr1.addr = alloca i32*, align 8
+  // CHECK-NEXT:    %M2.addr = alloca [63 x i32]*, align 8
+  // CHECK-NEXT:    %Ptr2.addr = alloca double*, align 8
+  // CHECK-NEXT:    %M3.addr = alloca [63 x double]*, align 8
+  // CHECK-NEXT:    %Stride.addr = alloca i32, align 4
+  // CHECK-NEXT:    store [40 x i32]* %M1, [40 x i32]** %M1.addr, align 8
+  // CHECK-NEXT:    store i32* %Ptr1, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    store [63 x i32]* %M2, [63 x i32]** %M2.addr, align 8
+  // CHECK-NEXT:    store double* %Ptr2, double** %Ptr2.addr, align 8
+  // CHECK-NEXT:    store [63 x double]* %M3, [63 x double]** %M3.addr, align 8
+  // CHECK-NEXT:    store i32 %Stride, i32* %Stride.addr, align 4
+  // CHECK-NEXT:    %0 = load [40 x i32]*, [40 x i32]** %M1.addr, align 8
+  // CHECK-NEXT:    %1 = load i32*, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    call void @_Z18column_major_storeIU11matrix_typeLm10ELm4EiPiEvRT_T0_j([40 x i32]* dereferenceable(160) %0, i32* %1, i32 10)
+  column_major_store(M1, Ptr1, 10);
+
+  // CHECK-NEXT:    %2 = load [63 x i32]*, [63 x i32]** %M2.addr, align 8
+  // CHECK-NEXT:    %3 = load i32*, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    %4 = load i32, i32* %Stride.addr, align 4
+  // CHECK-NEXT:    call void @_Z18column_major_storeIU11matrix_typeLm7ELm9EiPiEvRT_T0_j([63 x i32]* dereferenceable(252) %2, i32* %3, i32 %4)
+  column_major_store(M2, Ptr1, Stride);
+
+  // CHECK-NEXT:    %5 = load [63 x double]*, [63 x double]** %M3.addr, align 8
+  // CHECK-NEXT:    %6 = load double*, double** %Ptr2.addr, align 8
+  // CHECK-NEXT:    call void @_Z18column_major_storeIU11matrix_typeLm7ELm9EdPdEvRT_T0_j([63 x double]* dereferenceable(504) %5, double* %6, i32 10)
+  // CHECK-NEXT:    ret void
+  column_major_store(M3, Ptr2, 10);
+}
+
+// CHECK-LABEL: define linkonce_odr void @_Z18column_major_storeIU11matrix_typeLm10ELm4EiPiEvRT_T0_j([40 x i32]* dereferenceable(160) %M, i32* %Ptr, i32 %Stride)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %M.addr = alloca [40 x i32]*, align 8
+// CHECK-NEXT:    %Ptr.addr = alloca i32*, align 8
+// CHECK-NEXT:    %Stride.addr = alloca i32, align 4
+// CHECK-NEXT:    store [40 x i32]* %M, [40 x i32]** %M.addr, align 8
+// CHECK-NEXT:    store i32* %Ptr, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    store i32 %Stride, i32* %Stride.addr, align 4
+// CHECK-NEXT:    %0 = load [40 x i32]*, [40 x i32]** %M.addr, align 8
+// CHECK-NEXT:    %1 = bitcast [40 x i32]* %0 to <40 x i32>*
+// CHECK-NEXT:    %2 = load <40 x i32>, <40 x i32>* %1, align 4
+// CHECK-NEXT:    %3 = load i32*, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %4 = load i32, i32* %Stride.addr, align 4
+// CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v40i32.p0i32(<40 x i32> %2, i32* %3, i32 %4, i32 10, i32 4)
+// CHECK-NEXT:    ret void
+
+// CHECK-LABEL: define linkonce_odr void @_Z18column_major_storeIU11matrix_typeLm7ELm9EiPiEvRT_T0_j([63 x i32]* dereferenceable(252) %M, i32* %Ptr, i32 %Stride)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %M.addr = alloca [63 x i32]*, align 8
+// CHECK-NEXT:    %Ptr.addr = alloca i32*, align 8
+// CHECK-NEXT:    %Stride.addr = alloca i32, align 4
+// CHECK-NEXT:    store [63 x i32]* %M, [63 x i32]** %M.addr, align 8
+// CHECK-NEXT:    store i32* %Ptr, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    store i32 %Stride, i32* %Stride.addr, align 4
+// CHECK-NEXT:    %0 = load [63 x i32]*, [63 x i32]** %M.addr, align 8
+// CHECK-NEXT:    %1 = bitcast [63 x i32]* %0 to <63 x i32>*
+// CHECK-NEXT:    %2 = load <63 x i32>, <63 x i32>* %1, align 4
+// CHECK-NEXT:    %3 = load i32*, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %4 = load i32, i32* %Stride.addr, align 4
+// CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v63i32.p0i32(<63 x i32> %2, i32* %3, i32 %4, i32 7, i32 9)
+// CHECK-NEXT:    ret void
+
+// CHECK-LABEL: define linkonce_odr void @_Z18column_major_storeIU11matrix_typeLm7ELm9EdPdEvRT_T0_j([63 x double]* dereferenceable(504) %M, double* %Ptr, i32 %Stride)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %M.addr = alloca [63 x double]*, align 8
+// CHECK-NEXT:    %Ptr.addr = alloca double*, align 8
+// CHECK-NEXT:    %Stride.addr = alloca i32, align 4
+// CHECK-NEXT:    store [63 x double]* %M, [63 x double]** %M.addr, align 8
+// CHECK-NEXT:    store double* %Ptr, double** %Ptr.addr, align 8
+// CHECK-NEXT:    store i32 %Stride, i32* %Stride.addr, align 4
+// CHECK-NEXT:    %0 = load [63 x double]*, [63 x double]** %M.addr, align 8
+// CHECK-NEXT:    %1 = bitcast [63 x double]* %0 to <63 x double>*
+// CHECK-NEXT:    %2 = load <63 x double>, <63 x double>* %1, align 8
+// CHECK-NEXT:    %3 = load double*, double** %Ptr.addr, align 8
+// CHECK-NEXT:    %4 = load i32, i32* %Stride.addr, align 4
+// CHECK-NEXT:    call void @llvm.matrix.columnwise.store.v63f64.p0f64(<63 x double> %2, double* %3, i32 %4, i32 7, i32 9)
+// CHECK-NEXT:    ret void
diff --git a/clang/test/Sema/matrix-type-builtins.c b/clang/test/Sema/matrix-type-builtins.c
--- a/clang/test/Sema/matrix-type-builtins.c
+++ b/clang/test/Sema/matrix-type-builtins.c
@@ -60,3 +60,22 @@
       "", // expected-error {{column argument must be a constant unsigned integer expression}}
       10);
 }
+
+void column_major_store(sx5x10_t *m1, ix3x2_t *m2, float *p1, int *p2, struct Foo *p3) {
+  __builtin_matrix_column_major_store(*m1, p1, 1);
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+  __builtin_matrix_column_major_store(*m1, p2, 10);
+  // expected-error@-1 {{the pointee of the second argument must match the element type of the first argument ('int' != 'float')}}
+  __builtin_matrix_column_major_store(p1, p2, 10);
+  // expected-error@-1 {{first argument must be a matrix}}
+
+  __builtin_matrix_column_major_store(
+      "", // expected-error {{first argument must be a matrix}}
+      10, // expected-error {{second argument must be a pointer to a valid matrix element type}}
+      10);
+
+  *m1 = __builtin_matrix_column_major_store(*m1, p1, 10);
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'void'}}
+  int x = __builtin_matrix_column_major_store(*m1, p1, 10);
+  // expected-error@-1 {{initializing 'int' with an expression of incompatible type 'void'}}
+}
diff --git a/clang/test/SemaCXX/matrix-type-builtins.cpp b/clang/test/SemaCXX/matrix-type-builtins.cpp
--- a/clang/test/SemaCXX/matrix-type-builtins.cpp
+++ b/clang/test/SemaCXX/matrix-type-builtins.cpp
@@ -56,3 +56,23 @@
   Mat1.value = column_major_load(Mat2, Ptr2);
   // expected-note@-1 {{in instantiation of function template specialization 'column_major_load' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename PtrTy>
+void column_major_store(MyMatrix<EltTy0, R0, C0> &A, PtrTy Ptr, unsigned Stride) {
+  __builtin_matrix_column_major_store(A.value, Ptr, Stride);
+  // expected-error@-1 {{the pointee of the second argument must match the element type of the first argument ('float' != 'unsigned int')}}
+}
+
+template <typename MTy, typename PtrTy, unsigned Stride>
+void column_major_store(MTy &A, PtrTy Ptr) {
+  __builtin_matrix_column_major_store(A.value, Ptr, Stride);
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+}
+
+void test_column_major_stores_template(MyMatrix<unsigned, 2, 3> &M1, unsigned *Ptr1, MyMatrix<float, 3, 3> &M2, float *Ptr2) {
+  column_major_store(M1, Ptr2, 10);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_store<unsigned int, 2, 3, float *>' requested here}}
+
+  column_major_store<MyMatrix<float, 3, 3> &, float *, 1>(M2, Ptr2);
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_store<MyMatrix<float, 3, 3> &, float *, 1>' requested here}}
+}
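
For reviewers who want to try the builtin, here is a minimal usage sketch distilled from the tests above (not part of the patch; the typedef and function name are illustrative, and it assumes the same -fenable-matrix setup the test files use):

    // Compile with: clang -fenable-matrix -c example.c
    typedef float fx5x5_t __attribute__((matrix_type(5, 5)));

    void store_padded(fx5x5_t M, float *Dst) {
      // Writes M to Dst in column-major order. Consecutive columns start 8
      // elements apart, so the stride (8) must be >= the number of rows (5);
      // the 3 padding elements after each column are left untouched. The
      // builtin itself returns void.
      __builtin_matrix_column_major_store(M, Dst, 8);
    }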