diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -577,6 +577,7 @@
 BUILTIN(__builtin_matrix_extract, "v.", "nt")
 BUILTIN(__builtin_matrix_subtract, "v.", "nt")
 BUILTIN(__builtin_matrix_add, "v.", "nt")
+BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10287,18 +10287,27 @@
 def err_builtin_matrix_disabled: Error<
   "Builtin matrix support is disabled. Pass -fenable-matrix to enable it.">;
 
+def err_builtin_matrix_element_type: Error<
+  "Element types of input matrixes do not match (%0 != %1)">;
+
 def err_builtin_matrix_arg: Error<
   "%select{First|Second}0 argument must be a matrix">;
 
+def err_builtin_matrix_pointer_arg: Error<
+  "%select{First|Second}0 argument must be a %select{pointer|pointer to integers or floats}1">;
+
 def err_builtin_matrix_scalar_int_arg: Error<
   "%select{Row|Column|Offset|Stride}0 argument must be %select{an unsigned integer|a constant unsigned integer expression}1">;
 
 def err_builtin_matrix_implicit_cast_error: Error<
   "Implicit cast to from %0 to %1 failed">;
 
 def err_builtin_matrix_type_match: Error<
   "Matrix types must match">;
 
+def err_builtin_matrix_dimension_error: Error<
+  "Matrix dimensions do not match operation">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
diff --git a/clang/include/clang/Sema/Sema.h
b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11618,6 +11618,8 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixEltwiseOverload(CallExpr *TheCall,
                                               ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
+                                               ExprResult CallResult);
 
 public:
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2380,6 +2380,20 @@
     Value *Result = MB.CreateSub(Matrix1, Matrix2);
     return RValue::get(Result);
   }
+
+  case Builtin::BI__builtin_matrix_multiply: {
+    MatrixBuilder MB(Builder);
+    Value *Matrix1 = EmitScalarExpr(E->getArg(0));
+    Value *Matrix2 = EmitScalarExpr(E->getArg(1));
+
+    const MatrixType *Matrix1Ty = getMatrixTy(E->getArg(0)->getType());
+    const MatrixType *Matrix2Ty = getMatrixTy(E->getArg(1)->getType());
+    Value *Result = MB.CreateMatrixMultiply(
+        Matrix1, Matrix2, Matrix1Ty->getNumRows(), Matrix1Ty->getNumColumns(),
+        Matrix2Ty->getNumColumns());
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1617,6 +1617,7 @@
   case Builtin::BI__builtin_matrix_extract:
   case Builtin::BI__builtin_matrix_add:
   case Builtin::BI__builtin_matrix_subtract:
+  case Builtin::BI__builtin_matrix_multiply:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1630,6 +1631,8 @@
     case Builtin::BI__builtin_matrix_add:
     case Builtin::BI__builtin_matrix_subtract:
       return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_multiply:
+      return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins
should be handled here!");
     }
@@ -15372,3 +15375,94 @@
 
   return CallResult;
 }
+
+/// Semantic checking for __builtin_matrix_multiply(A, B).
+///
+/// Both arguments must be matrices with the same element type, and the
+/// number of columns of A must equal the number of rows of B. Lvalue
+/// operands are loaded, the callee is rewritten as a function pointer to the
+/// builtin, and the call's type becomes the product type: a matrix with A's
+/// row count, B's column count and the shared element type.
+ExprResult Sema::SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
+                                                   ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 2))
+    return ExprError();
+
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  Expr *AArg = TheCall->getArg(0);
+  Expr *BArg = TheCall->getArg(1);
+
+  // Both operands must be matrices. Check both before bailing out so that
+  // each offending argument gets its own diagnostic.
+  bool ArgError = false;
+  if (!AArg->getType()->isMatrixType()) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  if (!BArg->getType()->isMatrixType()) {
+    Diag(BArg->getBeginLoc(), diag::err_builtin_matrix_arg) << 1;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  const MatrixType *AMType =
+      cast<MatrixType>(AArg->getType().getCanonicalType());
+  const MatrixType *BMType =
+      cast<MatrixType>(BArg->getType().getCanonicalType());
+
+  // A (ARows x AColumns) * B (BRows x BColumns) requires AColumns == BRows
+  // and yields an ARows x BColumns matrix.
+  unsigned ARows = AMType->getNumRows();
+  unsigned AColumns = AMType->getNumColumns();
+  unsigned BRows = BMType->getNumRows();
+  unsigned BColumns = BMType->getNumColumns();
+  if (AColumns != BRows) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_dimension_error);
+    return ExprError();
+  }
+
+  // Both operands must share one element type; no element conversion is
+  // attempted.
+  QualType ElementType = AMType->getElementType();
+  if (ElementType != BMType->getElementType()) {
+    Diag(AArg->getBeginLoc(), diag::err_builtin_matrix_element_type)
+        << ElementType << BMType->getElementType();
+    return ExprError();
+  }
+
+  // Load lvalue operands. DefaultLvalueConversion returns rvalues unchanged
+  // and gives a loaded value the cv-unqualified type (C 6.3.2.1p2,
+  // C++ [conv.lval]); a hand-built CK_LValueToRValue cast carrying the
+  // lvalue's own type would keep qualifiers such as 'const'.
+  ExprResult AConv = DefaultLvalueConversion(AArg);
+  if (AConv.isInvalid())
+    return ExprError();
+  TheCall->setArg(0, AConv.get());
+
+  ExprResult BConv = DefaultLvalueConversion(BArg);
+  if (BConv.isInvalid())
+    return ExprError();
+  TheCall->setArg(1, BConv.get());
+
+  // Rewrite the callee as a CK_BuiltinFnToFnPtr cast of a fresh DeclRefExpr
+  // to the builtin.
+  // FIXME: This loses syntactic information.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall =
+      ImpCastExprToType(NewDRE, CalleePtrTy, CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // The call yields the matrix product: ARows x BColumns of the shared
+  // element type.
+  TheCall->setType(Context.getMatrixType(ElementType, ARows, BColumns));
+
+  return CallResult;
+}
diff --git a/clang/test/CodeGen/builtin-matrix.c b/clang/test/CodeGen/builtin-matrix.c
--- a/clang/test/CodeGen/builtin-matrix.c
+++ b/clang/test/CodeGen/builtin-matrix.c
@@ -225,3 +225,30 @@
   // CHECK-NEXT: store <27 x i32> %11, <27 x i32>* %3, align 4
   // CHECK-NEXT: ret void
 }
+
+void multiply1(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) {
+  *a = __builtin_matrix_multiply(*b, *c);
+
+  // CHECK-LABEL: @multiply1(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8
+  // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5)
+  // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8
+  // CHECK-NEXT: ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]]
+
+// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }