diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -578,6 +578,7 @@
 BUILTIN(__builtin_matrix_subtract, "v.", "nt")
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
+BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11620,6 +11620,9 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixMultiplyOverload(CallExpr *TheCall,
                                                ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+                                                ExprResult CallResult);
+
 
 public:
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2366,6 +2366,15 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_transpose: {
+    const MatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+    Value *MatValue = EmitScalarExpr(E->getArg(0));
+    MatrixBuilder MB(Builder);
+    Value *Result = MB.CreateMatrixTranspose(
+        MatValue, MatrixTy->getNumRows(), MatrixTy->getNumColumns());
+    return RValue::get(Result);
+  }
+
   case Builtin::BI__builtin_matrix_add: {
     MatrixBuilder MB(Builder);
     Value *Matrix1 = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1618,6 +1618,7 @@
   case Builtin::BI__builtin_matrix_add:
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
+  case Builtin::BI__builtin_matrix_transpose:
     if (!getLangOpts().EnableMatrix) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1633,6 +1634,8 @@
       return SemaBuiltinMatrixEltwiseOverload(TheCall, TheCallResult);
     case Builtin::BI__builtin_matrix_multiply:
       return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_transpose:
+      return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -15470,3 +15473,60 @@
   return CallResult;
 }
+
+ExprResult Sema::SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
+                                                    ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 1))
+    return ExprError();
+
+  Expr *Arg = TheCall->getArg(0);
+
+  // Some very basic type checking; the argument must be a matrix.
+  if (!Arg->getType()->isMatrixType()) {
+    Diag(Arg->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    return ExprError();
+  }
+
+  const MatrixType *MType =
+      cast<MatrixType>(Arg->getType().getCanonicalType());
+
+  unsigned R = MType->getNumRows();
+  unsigned C = MType->getNumColumns();
+  // Full type checking.
+
+  // Set up the function prototype.
+
+  if (!Arg->isRValue()) {
+    ExprResult Res = ImplicitCastExpr::Create(
+        Context, Arg->getType(), CK_LValueToRValue, Arg, nullptr, VK_RValue);
+    assert(!Res.isInvalid() && "Matrix cast failed");
+    TheCall->setArg(0, Res.get());
+  }
+
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Function return type: transposing an R x C matrix yields a C x R matrix.
+  QualType ReturnElementType = MType->getElementType();
+  QualType ResultType = Context.getMatrixType(ReturnElementType, C, R);
+
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall = ImpCastExprToType(NewDRE, CalleePtrTy,
+                                              CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // Change the result type of the call to match the original value type. This
+  // is arbitrary, but the codegen for these builtins is designed to handle it
+  // gracefully.
+  TheCall->setType(ResultType);
+  return CallResult;
+}
diff --git a/clang/test/CodeGen/builtin-matrix.c b/clang/test/CodeGen/builtin-matrix.c
--- a/clang/test/CodeGen/builtin-matrix.c
+++ b/clang/test/CodeGen/builtin-matrix.c
@@ -251,4 +251,24 @@
 }
 // CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]]
 
+void transpose1(dx5x5_t *a, dx5x5_t *b) {
+  *a = __builtin_matrix_transpose(*b);
+
+  // CHECK-LABEL: @transpose1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT:    %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT:    %3 = call <25 x double> @llvm.matrix.transpose.v25f64(<25 x double> %2, i32 5, i32 5)
+  // CHECK-NEXT:    %4 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %5 = bitcast [25 x double]* %4 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %3, <25 x double>* %5, align 8
+  // CHECK-NEXT:    ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.transpose.v25f64(<25 x double>, i32 immarg, i32 immarg) [[READNONE]]
+
 // CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
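
For reference, a minimal usage sketch of the new builtin, assuming the patch above is applied and the matrix extension is enabled with -fenable-matrix. The dx2x3_t/dx3x2_t typedefs and the transpose_it function are hypothetical names used only for illustration; the existing test file defines dx5x5_t with the same matrix_type attribute.

typedef double dx2x3_t __attribute__((matrix_type(2, 3)));
typedef double dx3x2_t __attribute__((matrix_type(3, 2)));

dx3x2_t transpose_it(dx2x3_t m) {
  // SemaBuiltinMatrixTransposeOverload gives the call a C x R result type, so
  // transposing a 2x3 argument yields a 3x2 value; CodeGen lowers the call to
  // the @llvm.matrix.transpose intrinsic.
  return __builtin_matrix_transpose(m);
}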