diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -580,6 +580,7 @@
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_store, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11625,6 +11625,8 @@
   ExprResult SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
                                                  ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixColumnStoreOverload(CallExpr *TheCall,
+                                                  ExprResult CallResult);
 
 public:
 
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2373,6 +2373,21 @@
     }
     return RValue::get(Result);
   }
+  case Builtin::BI__builtin_matrix_column_store: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Matrix = EmitScalarExpr(E->getArg(0));
+    const MatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+    Address Dst = EmitPointerWithAlignment(E->getArg(1));
+    EmitNonNullArgCheck(RValue::get(Dst.getPointer()), E->getArg(1)->getType(),
+                        E->getArg(1)->getExprLoc(), FD, 1);
+    Value *Stride = EmitScalarExpr(E->getArg(2));
+
+    // TODO: Pass Dst alignment to intrinsic
+    MB.CreateMatrixColumnwiseStore(Matrix, Dst.getPointer(), Stride,
+                                   MatrixTy->getNumRows(),
+                                   MatrixTy->getNumColumns());
+    return RValue::get(Dst.getPointer());
+  }
   case Builtin::BI__builtin_matrix_insert: {
     MatrixBuilder<CGBuilderTy> MB(Builder);
     Value *MatValue = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1620,6 +1620,7 @@
   case Builtin::BI__builtin_matrix_multiply:
   case Builtin::BI__builtin_matrix_transpose:
   case Builtin::BI__builtin_matrix_column_load:
+  case Builtin::BI__builtin_matrix_column_store:
    if (!getLangOpts().EnableMatrix) {
      Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
      return ExprError();
@@ -1639,6 +1640,8 @@
     return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
   case Builtin::BI__builtin_matrix_column_load:
     return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
+  case Builtin::BI__builtin_matrix_column_store:
+    return SemaBuiltinMatrixColumnStoreOverload(TheCall, TheCallResult);
   default:
     llvm_unreachable("All matrix builtins should be handled here!");
   }
@@ -15651,3 +15654,81 @@
 
   return CallResult;
 }
+
+ExprResult Sema::SemaBuiltinMatrixColumnStoreOverload(CallExpr *TheCall,
+                                                      ExprResult CallResult) {
+  // Must have
+  // 1: Matrix to store
+  // 2: Pointer to store to
+  // 3: Stride (unsigned)
+
+  if (checkArgCount(*this, TheCall, 3))
+    return ExprError();
+
+  Expr *MatrixExpr = TheCall->getArg(0);
+  Expr *DataExpr = TheCall->getArg(1);
+  Expr *StrideExpr = TheCall->getArg(2);
+
+  bool ArgError = false;
+  if (!MatrixExpr->getType()->isMatrixType()) {
+    Diag(MatrixExpr->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  if (!DataExpr->getType()->isPointerType()) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 1 << 0;
+    ArgError = true;
+  }
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 3 << 0;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  // TODO: Check element type compatibility, and possibly up/down cast element
+  // types
+
+  // Cast matrix to an rvalue
+  if (!MatrixExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, MatrixExpr->getType(), CK_LValueToRValue, MatrixExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() && "Matrix cast to an R-value failed");
+    MatrixExpr = CastExprResult.get();
+    TheCall->setArg(0, MatrixExpr);
+  }
+
+  if (!DataExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, DataExpr->getType(), CK_LValueToRValue, DataExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() && "Pointer cast to R-value failed");
+    DataExpr = CastExprResult.get();
+    TheCall->setArg(1, DataExpr);
+  }
+
+  llvm::SmallVector<QualType, 3> ParameterTypes = {
+      MatrixExpr->getType().withConst(), DataExpr->getType(),
+      StrideExpr->getType().withConst()};
+
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  return CallResult;
+}
diff --git a/clang/test/CodeGen/builtin-matrix.c b/clang/test/CodeGen/builtin-matrix.c
--- a/clang/test/CodeGen/builtin-matrix.c
+++ b/clang/test/CodeGen/builtin-matrix.c
@@ -289,5 +289,24 @@
 }
 // CHECK: declare <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double*, i32, i32 immarg, i32 immarg) [[READONLY:#[0-9]]]
 
+void column_store1(dx5x5_t *a, double *b) {
+  __builtin_matrix_column_store(*a, b, 10);
+
+  // CHECK-LABEL: @column_store1(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT:   %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:   %b.addr = alloca double*, align 8
+  // CHECK-NEXT:   store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:   store double* %b, double** %b.addr, align 8
+  // CHECK-NEXT:   %0 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:   %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:   %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:   %3 = load double*, double** %b.addr, align 8
+  // CHECK-NEXT:   call void @llvm.matrix.columnwise.store.v25f64.p0f64(<25 x double> %2, double* %3, i32 10, i32 5, i32 5)
+  // CHECK-NEXT:   ret void
+}
+// CHECK: declare void @llvm.matrix.columnwise.store.v25f64.p0f64(<25 x double>, double* writeonly, i32, i32 immarg, i32 immarg) [[WRITEONLY:#[0-9]]]
+
 // CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
 // CHECK: attributes [[READONLY]] = { nounwind readonly willreturn }
+// CHECK: attributes [[WRITEONLY]] = { nounwind willreturn }
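For reference, a minimal usage sketch of the builtin this patch adds (not part of the diff), mirroring the column_store1 test above. The third argument is the column stride in elements, i.e. the distance between the starts of consecutive columns in the destination buffer, so it must be >= the number of rows. The matrix_type attribute spelling and the -fenable-matrix flag are assumptions carried over from the surrounding matrix-extension patches:

    // Assumed typedef, as in clang/test/CodeGen/builtin-matrix.c;
    // compile with the matrix extension enabled (e.g. -fenable-matrix).
    typedef double dx5x5_t __attribute__((matrix_type(5, 5)));

    // Store the five 5-element columns of m into buf, with consecutive
    // columns starting 10 elements apart; the 5 trailing elements of each
    // column slot are left untouched. buf must therefore hold at least
    // 10 * 4 + 5 = 45 doubles.
    void store_padded(dx5x5_t m, double *buf) {
      __builtin_matrix_column_store(m, buf, 10);
    }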