diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -579,6 +579,7 @@
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_scalar_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_store, "v.", "nFt")
 
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10308,6 +10308,10 @@
 def err_builtin_matrix_implicit_cast_error: Error<
   "Implicit cast to from %0 to %1 failed">;
 
+def err_builtin_matrix_scalar_type_error: Error<
+  "%select{First|Scalar}0 argument must be a "
+  "%select{matrix|float or integer}0">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11618,6 +11618,8 @@
                                              ExprResult CallResult);
   ExprResult SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
                                               ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixScalarOverload(CallExpr *TheCall,
+                                             ExprResult CallResult);
   ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
                                                ExprResult CallResult);
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2451,6 +2451,15 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_scalar_multiply: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Sema has already converted the scalar to the matrix element type.
+    Value *Matrix = EmitScalarExpr(E->getArg(0));
+    Value *Scalar = EmitScalarExpr(E->getArg(1));
+    Value *Result = MB.CreateScalarMultiply(Matrix, Scalar);
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1619,6 +1619,7 @@
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
   case Builtin::BI__builtin_matrix_transpose:
+  case Builtin::BI__builtin_matrix_scalar_multiply:
   case Builtin::BI__builtin_matrix_column_load:
   case Builtin::BI__builtin_matrix_column_store:
     if (!getLangOpts().EnableMatrix) {
@@ -1638,6 +1639,8 @@
       return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_transpose:
       return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_scalar_multiply:
+      return SemaBuiltinMatrixScalarOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_column_load:
       return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_column_store:
@@ -15537,6 +15540,92 @@
   return CallResult;
 }
 
+ExprResult Sema::SemaBuiltinMatrixScalarOverload(CallExpr *TheCall,
+                                                 ExprResult CallResult) {
+  // __builtin_matrix_scalar_multiply(M, S): M must be a matrix and S a real
+  // floating point or integer scalar. S is converted to M's element type and
+  // the call is given M's (lvalue-converted) type.
+  if (checkArgCount(*this, TheCall, 2))
+    return ExprError();
+
+  // Apply lvalue-to-rvalue conversion to both operands first, so the checks
+  // and the element-type conversion below operate on the values that are
+  // actually passed to CodeGen.
+  ExprResult MatrixConv = DefaultLvalueConversion(TheCall->getArg(0));
+  if (MatrixConv.isInvalid())
+    return ExprError();
+  ExprResult ScalarConv = DefaultLvalueConversion(TheCall->getArg(1));
+  if (ScalarConv.isInvalid())
+    return ExprError();
+  Expr *MatrixArg = MatrixConv.get();
+  Expr *ScalarArg = ScalarConv.get();
+
+  // The first argument must be a matrix.
+  if (!MatrixArg->getType()->isMatrixType()) {
+    Diag(MatrixArg->getBeginLoc(), diag::err_builtin_matrix_scalar_type_error)
+        << 0;
+    return ExprError();
+  }
+
+  // The second argument must be a real floating point or integer value.
+  // This is validated before any cast is built; complex types are rejected
+  // because they have no plain scalar conversion to a matrix element type.
+  QualType ScalarTy = ScalarArg->getType();
+  if (!ScalarTy->isRealFloatingType() && !ScalarTy->isIntegralType(Context)) {
+    Diag(ScalarArg->getBeginLoc(), diag::err_builtin_matrix_scalar_type_error)
+        << 1;
+    return ExprError();
+  }
+
+  // Convert the scalar to the matrix element type. The required cast kind
+  // depends on both types (integer->float, float->float, float->integer,
+  // integer->integer), so let PrepareScalarCast choose it instead of
+  // hard-coding CK_IntegralToFloating.
+  QualType ElementTy =
+      MatrixArg->getType()->getAs<MatrixType>()->getElementType();
+  if (!Context.hasSameUnqualifiedType(ScalarTy, ElementTy)) {
+    ExprResult Converted = ScalarArg;
+    CastKind Kind = PrepareScalarCast(Converted, ElementTy);
+    Converted = ImpCastExprToType(Converted.get(), ElementTy, Kind);
+    if (Converted.isInvalid()) {
+      Diag(ScalarArg->getBeginLoc(),
+           diag::err_builtin_matrix_implicit_cast_error)
+          << ScalarTy << ElementTy;
+      return ExprError();
+    }
+    ScalarArg = Converted.get();
+  }
+
+  TheCall->setArg(0, MatrixArg);
+  TheCall->setArg(1, ScalarArg);
+
+  // Re-point the callee at the builtin through a pointer to the builtin's
+  // function type.
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // The result has the matrix operand's type; CodeGen lowers the call with
+  // MatrixBuilder::CreateScalarMultiply on the two converted arguments.
+  TheCall->setType(MatrixArg->getType());
+
+  return CallResult;
+}
+
 ExprResult Sema::SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
                                                      ExprResult CallResult) {
   // Must have exactly four operands