diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -579,6 +579,7 @@
 BUILTIN(__builtin_matrix_add, "v.", "nt")
 BUILTIN(__builtin_matrix_multiply, "v.", "nt")
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -11623,6 +11623,8 @@
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
                                                 ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
+                                                 ExprResult CallResult);
 
 public:
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -2340,6 +2340,39 @@
     V = Builder.CreateFCmpUNO(V, V, "cmp");
     return RValue::get(Builder.CreateZExt(V, ConvertType(E->getType())));
   }
+  case Builtin::BI__builtin_matrix_column_load: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Emit everything that isn't dependent on the type of the first argument.
+    Value *Stride = EmitScalarExpr(E->getArg(3));
+    const MatrixType *ResultTy = getMatrixTy(E->getType());
+
+    // If the first argument is a pointer, emit the pointer; otherwise emit
+    // the array-to-pointer decay.
+    Value *Result = nullptr;
+    if (const PointerType *PTy =
+            dyn_cast<PointerType>(E->getArg(0)->getType())) {
+      Address Src = EmitPointerWithAlignment(E->getArg(0));
+      EmitNonNullArgCheck(RValue::get(Src.getPointer()),
+                          E->getArg(0)->getType(), E->getArg(0)->getExprLoc(),
+                          FD, 0);
+      Result = MB.CreateMatrixColumnwiseLoad(
+          Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+          Stride, "matrix");
+    } else if (const ArrayType *ATy =
+                   dyn_cast<ArrayType>(E->getArg(0)->getType())) {
+      Address Src = EmitArrayToPointerDecay(E->getArg(0));
+      EmitNonNullArgCheck(RValue::get(Src.getPointer()),
+                          E->getArg(0)->getType(), E->getArg(0)->getExprLoc(),
+                          FD, 0);
+      Result = MB.CreateMatrixColumnwiseLoad(
+          Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+          Stride, "matrix");
+    } else {
+      llvm_unreachable(
+          "CGBuiltin.cpp: First argument must either be a pointer or an array");
+    }
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_matrix_insert: {
     MatrixBuilder<CGBuilderTy> MB(Builder);
     Value *MatValue = EmitScalarExpr(E->getArg(0));
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1619,6 +1619,7 @@
   case Builtin::BI__builtin_matrix_subtract:
   case Builtin::BI__builtin_matrix_multiply:
   case Builtin::BI__builtin_matrix_transpose:
+  case Builtin::BI__builtin_matrix_column_load:
    if (!getLangOpts().EnableMatrix) {
      Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
      return ExprError();
@@ -1636,6 +1637,8 @@
      return SemaBuiltinMatrixMultiplyOverload(TheCall, TheCallResult);
    case Builtin::BI__builtin_matrix_transpose:
      return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_column_load:
+      return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
    default:
      llvm_unreachable("All matrix builtins should be handled here!");
    }
@@ -15530,3 +15533,121 @@
   TheCall->setType(ResultType);
   return CallResult;
 }
+
+ExprResult Sema::SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
+                                                     ExprResult CallResult) {
+  // Must have exactly four operands:
+  // 1: Pointer to data
+  // 2: Rows (constant)
+  // 3: Columns (constant)
+  // 4: Stride
+
+  // The operands have very similar semantics to glVertexAttribPointer from
+  // OpenGL. Instead of the attribute index, the first operand is a pointer to
+  // the memory that is being loaded from. Instead of a size, we need the rows
+  // and columns. Note that these must be constant to construct the matrix
+  // type.
+
+  if (checkArgCount(*this, TheCall, 4))
+    return ExprError();
+
+  Expr *DataExpr = TheCall->getArg(0);
+  Expr *RowsExpr = TheCall->getArg(1);
+  Expr *ColsExpr = TheCall->getArg(2);
+  Expr *StrideExpr = TheCall->getArg(3);
+
+  unsigned Rows = 0;
+  unsigned Cols = 0;
+
+  if (!(DataExpr->getType()->isPointerType() ||
+        DataExpr->getType()->isArrayType())) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 0 << 0;
+  }
+
+  bool ArgError = false;
+  // Get the matrix dimensions.
+  {
+    llvm::APSInt Value(32);
+    SourceLocation RowColErrorPos;
+
+    if (!RowsExpr->isIntegerConstantExpr(Value, Context, &RowColErrorPos)) {
+      Diag(RowsExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+          << 0 << 1;
+      ArgError = true;
+    } else
+      Rows = Value.getZExtValue();
+
+    if (!ColsExpr->isIntegerConstantExpr(Value, Context, &RowColErrorPos)) {
+      Diag(ColsExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+          << 1 << 1;
+      ArgError = true;
+    } else
+      Cols = Value.getZExtValue();
+  }
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 3 << 1;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  QualType ElementType;
+
+  if (const PointerType *PTy = dyn_cast<PointerType>(DataExpr->getType())) {
+    ElementType = PTy->getPointeeType();
+  } else if (const ArrayType *ATy = dyn_cast<ArrayType>(DataExpr->getType())) {
+    ElementType = ATy->getElementType();
+  } else {
+    llvm_unreachable("Pointer expression must be a pointer or an array");
+    return ExprError();
+  }
+  ElementType.removeLocalConst();
+
+  if (!ElementType->isIntegralType(Context) &&
+      !ElementType->isFloatingType()) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 0 << 1;
+    return ExprError();
+  }
+
+  // TODO: Check this; it seems odd to have to cast the pointer from an
+  // l-value. It presumably needs to be materialized as an r-value before we
+  // can work with it.
+  if (!DataExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, DataExpr->getType(), CK_LValueToRValue, DataExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() &&
+           "Pointer failed to be cast to an r-value");
+    DataExpr = CastExprResult.get();
+    TheCall->setArg(0, DataExpr);
+  }
+
+  QualType ReturnType = Context.getMatrixType(ElementType, Rows, Cols);
+
+  llvm::SmallVector<QualType, 4> ParameterTypes = {
+      DataExpr->getType().withConst(), RowsExpr->getType().withConst(),
+      ColsExpr->getType().withConst(), StrideExpr->getType().withConst()};
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Create a new DeclRefExpr to refer to the new decl.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall =
+      ImpCastExprToType(NewDRE, CalleePtrTy, CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  // Change the result type of the call to match the original value type. This
+  // is arbitrary, but the codegen for these builtins is designed to handle it
+  // gracefully.
+  TheCall->setType(ReturnType);
+
+  return CallResult;
+}
diff --git a/clang/test/CodeGen/builtin-matrix.c b/clang/test/CodeGen/builtin-matrix.c
--- a/clang/test/CodeGen/builtin-matrix.c
+++ b/clang/test/CodeGen/builtin-matrix.c
@@ -271,4 +271,23 @@
 }
 // CHECK: declare <25 x double> @llvm.matrix.transpose.v25f64(<25 x double>, i32 immarg, i32 immarg) [[READNONE]]
 
+void column_load1(dx5x5_t *a, double *b) {
+  *a = __builtin_matrix_column_load(b, 5, 5, 10);
+
+  // CHECK-LABEL: @column_load1(
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT:    %b.addr = alloca double*, align 8
+  // CHECK-NEXT:    store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    store double* %b, double** %b.addr, align 8
+  // CHECK-NEXT:    %0 = load double*, double** %b.addr, align 8
+  // CHECK-NEXT:    %matrix = call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* %0, i32 10, i32 5, i32 5)
+  // CHECK-NEXT:    %1 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT:    %2 = bitcast [25 x double]* %1 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %matrix, <25 x double>* %2, align 8
+  // CHECK-NEXT:    ret void
+}
+// CHECK: declare <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double*, i32, i32 immarg, i32 immarg) [[READONLY:#[0-9]]]
+
 // CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
+// CHECK: attributes [[READONLY]] = { nounwind readonly willreturn }
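For context, a minimal usage sketch of the new builtin, separate from the patch itself. It assumes a compiler built with this change and `-fenable-matrix`; the `dx5x5_t` typedef and the `matrix_type` attribute mirror what the CodeGen test above relies on and are not introduced by this patch.

```c
// Hypothetical standalone example; dx5x5_t mirrors the typedef exercised in
// clang/test/CodeGen/builtin-matrix.c.
typedef double dx5x5_t __attribute__((matrix_type(5, 5)));

void load_tile(dx5x5_t *out, double *data) {
  // Load a 5x5 double matrix from 'data', where consecutive columns are 10
  // elements apart in memory: read 5 doubles per column, then skip 5.
  *out = __builtin_matrix_column_load(data, 5, 5, 10);
}
```

As in the Sema overload above, the row and column arguments must be integer constant expressions so the result's matrix type can be formed; only the stride may be a runtime value.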