diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -3476,6 +3476,11 @@ NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension; } + /// Returns the maximum number of elements per dimension. + static unsigned getMaxElementsPerDimension() { + return ConstantMatrixTypeBitfields::MaxElementsPerDimension; + } + void Profile(llvm::FoldingSetNodeID &ID) { Profile(ID, getElementType(), getNumRows(), getNumColumns(), getTypeClass()); diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def --- a/clang/include/clang/Basic/Builtins.def +++ b/clang/include/clang/Basic/Builtins.def @@ -578,6 +578,7 @@ BUILTIN(__builtin_call_with_static_chain, "v.", "nt") BUILTIN(__builtin_matrix_transpose, "v.", "nFt") +BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt") // "Overloaded" Atomic operator builtins. These are overloaded to support data // types of i8, i16, i32, i64, and i128. 
The front-end sees calls to the diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -10781,6 +10781,18 @@ def err_builtin_matrix_arg: Error< "%select{first|second}0 argument must be a matrix">; +def err_builtin_matrix_scalar_int_arg: Error< + "%select{row|column|stride}0 argument must be %select{an unsigned integer|a constant unsigned integer expression}1">; + +def err_builtin_matrix_pointer_arg: Error< + "%select{first|second}0 argument must be a pointer to a valid matrix element type">; + +def err_builtin_matrix_stride_too_small: Error< + "stride must be greater or equal to the number of rows">; + +def err_builtin_matrix_invalid_dimension: Error< + "%select{row|column}0 dimension is outside the allowed range [1, %1]">; + def err_preserve_field_info_not_field : Error< "__builtin_preserve_field_info argument %0 not a field access">; def err_preserve_field_info_not_const: Error< diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -4703,6 +4703,10 @@ bool tryExprAsCall(Expr &E, QualType &ZeroArgCallReturnTy, UnresolvedSetImpl &NonTemplateOverloads); + /// Try to convert an expression \p E to type \p Ty. Returns the result of the + /// conversion. + ExprResult tryConvertExprToTy(Expr *E, QualType Ty); + /// Conditionally issue a diagnostic based on the current /// evaluation context. /// @@ -12118,6 +12122,8 @@ // Matrix builtin handling. 
ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall, ExprResult CallResult); + ExprResult SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall, + ExprResult CallResult); public: enum FormatStringType { diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -2383,6 +2383,31 @@ return RValue::get(Result); } + case Builtin::BI__builtin_matrix_column_major_load: { + MatrixBuilder MB(Builder); + // Emit everything that isn't dependent on the first parameter type + Value *Stride = EmitScalarExpr(E->getArg(3)); + const auto *ResultTy = E->getType()->getAs(); + + QualType PtrTy = E->getArg(0)->getType(); + // If it's an address we need to emit the pointer + // otherwise, emit the array + Address Src = Address::invalid(); + if (isa(PtrTy)) + Src = EmitPointerWithAlignment(E->getArg(0)); + else if (isa(PtrTy)) + Src = EmitArrayToPointerDecay(E->getArg(0)); + else + llvm_unreachable("first argument must either be a pointer or an array"); + + EmitNonNullArgCheck(RValue::get(Src.getPointer()), PtrTy, + E->getArg(0)->getExprLoc(), FD, 0); + Value *Result = MB.CreateMatrixColumnwiseLoad( + Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(), + Stride, "matrix"); + return RValue::get(Result); + } + case Builtin::BIfinite: case Builtin::BI__finite: case Builtin::BIfinitef: diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -1915,6 +1915,9 @@ case Builtin::BI__builtin_matrix_transpose: return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult); + + case Builtin::BI__builtin_matrix_column_major_load: + return SemaBuiltinMatrixColumnMajorLoadOverload(TheCall, TheCallResult); } // Since the target specific builtins for each arch overlap, only check those @@ -15066,3 +15069,135 @@ TheCall->setArg(0, Matrix); return CallResult; } + +// Get 
and verify the matrix dimensions. +static llvm::Optional +getAndVerifyMatrixDimension(Expr *Expr, unsigned ErrIdx, Sema &S) { + llvm::APSInt Value(64); + SourceLocation ErrorPos; + if (!Expr->isIntegerConstantExpr(Value, S.Context, &ErrorPos)) { + S.Diag(Expr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg) + << ErrIdx << 1; + return {}; + } + uint64_t Dim = Value.getZExtValue(); + if (!ConstantMatrixType::isDimensionValid(Dim)) { + S.Diag(Expr->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension) + << ErrIdx << ConstantMatrixType::getMaxElementsPerDimension(); + return {}; + } + return Dim; +} + +ExprResult +Sema::SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall, + ExprResult CallResult) { + if (checkArgCount(*this, TheCall, 4)) + return ExprError(); + + Expr *PtrExpr = TheCall->getArg(0); + Expr *RowsExpr = TheCall->getArg(1); + Expr *ColumnsExpr = TheCall->getArg(2); + Expr *StrideExpr = TheCall->getArg(3); + + bool ArgError = false; + + // Check pointer argument. + { + ExprResult PtrConv = DefaultLvalueConversion(PtrExpr); + if (PtrConv.isInvalid()) + return PtrConv; + PtrExpr = PtrConv.get(); + } + + QualType PtrTy = PtrExpr->getType(); + // TODO: We need to loop through template substitutions properly somewhere. 
+ if (auto *SubstTy = PtrTy->getAs()) + PtrTy = SubstTy->getReplacementType(); + + QualType ElementTy; + if (!(PtrTy->isPointerType() || PtrTy->isArrayType())) { + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0; + ArgError = true; + } else { + if (const PointerType *PTy = dyn_cast(PtrTy)) + ElementTy = PTy->getPointeeType(); + else if (const ArrayType *ATy = dyn_cast(PtrTy)) + ElementTy = ATy->getElementType(); + else + llvm_unreachable("Pointer Expression must be a pointer or an array"); + + ElementTy.removeLocalConst(); + if (!ConstantMatrixType::isValidElementType(ElementTy)) { + Diag(PtrExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0; + ArgError = true; + } + } + + if (RowsExpr->isValueDependent() || RowsExpr->isTypeDependent() || + ColumnsExpr->isValueDependent() || ColumnsExpr->isTypeDependent()) { + QualType ReturnType = Context.getDependentSizedMatrixType( + ElementTy, RowsExpr, ColumnsExpr, {}); + TheCall->setType(ReturnType); + return CallResult; + } + + // Apply default Lvalue conversions and convert the expression to size_t. + auto ApplyArgumentConversions = [this](Expr *E) { + ExprResult Conv = DefaultLvalueConversion(E); + if (Conv.isInvalid()) + return Conv; + + return tryConvertExprToTy(Conv.get(), Context.getSizeType()); + }; + + // Check rows argument. + llvm::Optional MaybeRows; + ExprResult RowsConv = ApplyArgumentConversions(RowsExpr); + if (!RowsConv.isInvalid()) { + RowsExpr = RowsConv.get(); + MaybeRows = getAndVerifyMatrixDimension(RowsExpr, 0, *this); + } + + // Check columns argument. + llvm::Optional MaybeColumns; + ExprResult ColumnsConv = ApplyArgumentConversions(ColumnsExpr); + if (!ColumnsConv.isInvalid()) { + ColumnsExpr = ColumnsConv.get(); + MaybeColumns = getAndVerifyMatrixDimension(ColumnsExpr, 1, *this); + } + + // Check stride argument. 
+ ExprResult StrideConv = ApplyArgumentConversions(StrideExpr); + if (StrideConv.isInvalid()) + return ExprError(); + StrideExpr = StrideConv.get(); + + if (!StrideExpr->getType()->isIntegralType(Context)) { + Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg) + << 2 << 1; + ArgError = true; + } else { + llvm::APSInt Value(64); + if (StrideExpr->isIntegerConstantExpr(Value, Context)) { + uint64_t Stride = Value.getZExtValue(); + if (MaybeRows && Stride < *MaybeRows) { + Diag(StrideExpr->getBeginLoc(), + diag::err_builtin_matrix_stride_too_small); + ArgError = true; + } + } + } + + if (ArgError || !MaybeRows || !MaybeColumns) + return ExprError(); + + QualType ReturnType = + Context.getConstantMatrixType(ElementTy, *MaybeRows, *MaybeColumns); + TheCall->setType(ReturnType); + TheCall->setArg(0, PtrExpr); + TheCall->setArg(1, RowsExpr); + TheCall->setArg(2, ColumnsExpr); + TheCall->setArg(3, StrideExpr); + return CallResult; +} diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4674,15 +4674,12 @@ return Res; } -static bool tryConvertToTy(Sema &S, QualType ElementType, ExprResult *Scalar) { - InitializedEntity Entity = - InitializedEntity::InitializeTemporary(ElementType); - InitializationKind Kind = InitializationKind::CreateCopy( - Scalar->get()->getBeginLoc(), SourceLocation()); - Expr *Arg = Scalar->get(); - InitializationSequence InitSeq(S, Entity, Kind, Arg); - *Scalar = InitSeq.Perform(S, Entity, Kind, Arg); - return !Scalar->isInvalid(); +ExprResult Sema::tryConvertExprToTy(Expr *E, QualType Ty) { + InitializedEntity Entity = InitializedEntity::InitializeTemporary(Ty); + InitializationKind Kind = + InitializationKind::CreateCopy(E->getBeginLoc(), SourceLocation()); + InitializationSequence InitSeq(*this, Entity, Kind, E); + return InitSeq.Perform(*this, Entity, Kind, E); } ExprResult Sema::CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx, 
@@ -4733,11 +4730,9 @@ return nullptr; } - ExprResult ConvExpr = IndexExpr; - bool ConversionOk = tryConvertToTy(*this, Context.getSizeType(), &ConvExpr); - assert(ConversionOk && + ExprResult ConvExpr = tryConvertExprToTy(IndexExpr, Context.getSizeType()); + assert(!ConvExpr.isInvalid() && "should be able to convert any integer type to size type"); - (void)ConversionOk; return ConvExpr.get(); }; @@ -12109,13 +12104,16 @@ ExprResult OriginalLHS = LHS; ExprResult OriginalRHS = RHS; if (LHSMatType && !RHSMatType) { - if (tryConvertToTy(*this, LHSMatType->getElementType(), &RHS)) + RHS = tryConvertExprToTy(RHS.get(), LHSMatType->getElementType()); + if (!RHS.isInvalid()) return LHSType; + return InvalidOperands(Loc, OriginalLHS, OriginalRHS); } if (!LHSMatType && RHSMatType) { - if (tryConvertToTy(*this, RHSMatType->getElementType(), &LHS)) + LHS = tryConvertExprToTy(LHS.get(), RHSMatType->getElementType()); + if (!LHS.isInvalid()) return RHSType; return InvalidOperands(Loc, OriginalLHS, OriginalRHS); } diff --git a/clang/test/CodeGen/matrix-type-builtins.c b/clang/test/CodeGen/matrix-type-builtins.c --- a/clang/test/CodeGen/matrix-type-builtins.c +++ b/clang/test/CodeGen/matrix-type-builtins.c @@ -96,3 +96,50 @@ dx5x5_t m_t = __builtin_matrix_transpose(global_matrix); } + +void column_major_load_with_const_stride_double(double *Ptr) { + // CHECK-LABEL: define void @column_major_load_with_const_stride_double(double* %Ptr) + // CHECK: [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8 + // CHECK-NEXT: call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* [[PTR]], i64 5, i64 5, i64 5) + + dx5x5_t m_a1 = __builtin_matrix_column_major_load(Ptr, 5, 5, 5); +} + +void column_major_load_with_const_stride2_double(double *Ptr) { + // CHECK-LABEL: define void @column_major_load_with_const_stride2_double(double* %Ptr) + // CHECK: [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8 + // CHECK-NEXT: call <25 x double> 
@llvm.matrix.columnwise.load.v25f64.p0f64(double* [[PTR]], i64 15, i64 5, i64 5) + + dx5x5_t m_a2 = __builtin_matrix_column_major_load(Ptr, 5, 5, 2 * 3 + 9); +} + +void column_major_load_with_variable_stride_ull_float(float *Ptr, unsigned long long S) { + // CHECK-LABEL: define void @column_major_load_with_variable_stride_ull_float(float* %Ptr, i64 %S) + // CHECK: [[S:%.*]] = load i64, i64* %S.addr, align 8 + // CHECK-NEXT: [[PTR:%.*]] = load float*, float** %Ptr.addr, align 8 + // CHECK-NEXT: call <6 x float> @llvm.matrix.columnwise.load.v6f32.p0f32(float* [[PTR]], i64 [[S]], i64 2, i64 3) + + fx2x3_t m_b = __builtin_matrix_column_major_load(Ptr, 2, 3, S); +} + +void column_major_load_with_stride_math_int(int *Ptr, int S) { + // CHECK-LABEL: define void @column_major_load_with_stride_math_int(i32* %Ptr, i32 %S) + // CHECK: [[S:%.*]] = load i32, i32* %S.addr, align 4 + // CHECK-NEXT: [[STRIDE:%.*]] = add nsw i32 [[S]], 32 + // CHECK-NEXT: [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64 + // CHECK-NEXT: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20) + + ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32); +} + +void column_major_load_with_stride_math_s_int(int *Ptr, short S) { + // CHECK-LABEL: define void @column_major_load_with_stride_math_s_int(i32* %Ptr, i16 signext %S) + // CHECK: [[S:%.*]] = load i16, i16* %S.addr, align 2 + // CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32 + // CHECK-NEXT: [[STRIDE:%.*]] = add nsw i32 [[S_EXT]], 32 + // CHECK-NEXT: [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64 + // CHECK-NEXT: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: %matrix = call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 4, i64 20) + ix4x20_t m_c = __builtin_matrix_column_major_load(Ptr, 4, 20, S + 32); +} diff --git 
a/clang/test/CodeGenCXX/matrix-type-builtins.cpp b/clang/test/CodeGenCXX/matrix-type-builtins.cpp --- a/clang/test/CodeGenCXX/matrix-type-builtins.cpp +++ b/clang/test/CodeGenCXX/matrix-type-builtins.cpp @@ -74,3 +74,107 @@ // CHECK-NEXT: store <9 x float> [[M_T]], <9 x float>* [[M_T_ADDR]], align 4 matrix_t m_t = __builtin_matrix_transpose(m); } + +template +matrix_t column_major_load_with_stride(T *Ptr) { + return __builtin_matrix_column_major_load(Ptr, R, C, S); +} + +void test_column_major_load_with_stride_template_double(double *Ptr) { + // CHECK-LABEL: define void @_Z50test_column_major_load_with_stride_template_doublePd(double* %Ptr) + // CHECK: [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8 + // CHECK-NEXT: call <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* [[PTR]]) + + // CHECK-LABEL: define linkonce_odr <40 x double> @_Z29column_major_load_with_strideIdLj10ELj4ELj15EEU11matrix_typeXT0_EXT1_ET_PS0_(double* %Ptr) + // CHECK: [[PTR:%.*]] = load double*, double** %Ptr.addr, align 8 + // CHECK-NEXT: call <40 x double> @llvm.matrix.columnwise.load.v40f64.p0f64(double* [[PTR]], i64 15, i64 10, i64 4) + + matrix_t M1 = column_major_load_with_stride(Ptr); +} + +void test_column_major_load_with_stride_template_int(int *Ptr) { + // CHECK-LABEL: define void @_Z47test_column_major_load_with_stride_template_intPi(i32* %Ptr) #5 { + // CHECK: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* [[PTR]]) + + // CHECK-LABEL: define linkonce_odr <6 x i32> @_Z29column_major_load_with_strideIiLj3ELj2ELj12EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %Ptr) + // CHECK: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* [[PTR]], i64 12, i64 3, i64 2) + + matrix_t M1 = column_major_load_with_stride(Ptr); +} + +struct UnsignedWrapper 
{ + char x; + operator unsigned() { + return x; + } +}; + +void test_column_major_load_stride_wrapper(int *Ptr, UnsignedWrapper &W) { + // CHECK-LABEL: define void @_Z37test_column_major_load_stride_wrapperPiR15UnsignedWrapper(i32* %Ptr, %struct.UnsignedWrapper* nonnull align 1 dereferenceable(1) %W) + // CHECK: [[W:%.*]] = load %struct.UnsignedWrapper*, %struct.UnsignedWrapper** %W.addr, align 8 + // CHECK-NEXT: [[STRIDE:%.*]] = call i32 @_ZN15UnsignedWrappercvjEv(%struct.UnsignedWrapper* [[W]]) + // CHECK-NEXT: [[STRIDE_EXT:%.*]] = zext i32 [[STRIDE]] to i64 + // CHECK-NEXT: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2) + matrix_t M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, W); +} + +constexpr int constexpr3() { return 3; } + +void test_column_major_load_constexpr_num_rows(int *Ptr) { + // CHECK-LABEL: define void @_Z41test_column_major_load_constexpr_num_rowsPi(i32* %Ptr) + // CHECK: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <6 x i32> @llvm.matrix.columnwise.load.v6i32.p0i32(i32* [[PTR]], i64 3, i64 3, i64 2) + + matrix_t M1 = __builtin_matrix_column_major_load(Ptr, constexpr3(), 2, 3); +} + +constexpr int constexpr1() { return 1; } + +void test_column_major_load_constexpr_num_columns(int *Ptr) { + // CHECK-LABEL: define void @_Z44test_column_major_load_constexpr_num_columnsPi(i32* %Ptr) + // CHECK: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <2 x i32> @llvm.matrix.columnwise.load.v2i32.p0i32(i32* [[PTR]], i64 3, i64 2, i64 1) + matrix_t M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr1(), 3); +} + +template +constexpr int constexpr_plus1() { return N + 1; } + +void test_column_major_load_constexpr_num_columns_temp(int *Ptr) { + // CHECK-LABEL: define void @_Z49test_column_major_load_constexpr_num_columns_tempPi(i32* %Ptr) + // CHECK: [[PTR:%.*]] = load i32*, i32** 
%Ptr.addr, align 8 + // CHECK-NEXT: call <10 x i32> @llvm.matrix.columnwise.load.v10i32.p0i32(i32* [[PTR]], i64 3, i64 2, i64 5) + matrix_t M1 = __builtin_matrix_column_major_load(Ptr, 2, constexpr_plus1<4>(), 3); +} + +void test_column_major_load_constexpr_stride_constexpr(int *Ptr) { + // CHECK-LABEL: define void @_Z49test_column_major_load_constexpr_stride_constexprPi(i32* %Ptr) + // CHECK: [[STRIDE:%.*]] = call i32 @_Z10constexpr3v() + // CHECK-NEXT: [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64 + // CHECK-NEXT: [[PTR:%.*]] = load i32*, i32** %Ptr.addr, align 8 + // CHECK-NEXT: call <4 x i32> @llvm.matrix.columnwise.load.v4i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 2, i64 2) + + matrix_t M1 = __builtin_matrix_column_major_load(Ptr, 2, 2, constexpr3()); +} + +// TODO: +/*template struct remove_pointer {*/ +//typedef T type; +//}; + +//template struct remove_pointer{ +//typedef typename remove_pointer::type type; +//}; + +//// Same as column_major_load_with_stride, but with the PtrT argument itself being a pointer type. 
+//template +//matrix_t::type, R, C> column_major_load_with_stride2(PtrT Ptr) { +//return __builtin_matrix_column_major_load(Ptr, R, C, S); +//} + +//void call_column_major_load_with_stride2(float *Ptr) { +//matrix_t m = column_major_load_with_stride2(Ptr); +//} diff --git a/clang/test/CodeGenObjC/matrix-type-builtins.m b/clang/test/CodeGenObjC/matrix-type-builtins.m --- a/clang/test/CodeGenObjC/matrix-type-builtins.m +++ b/clang/test/CodeGenObjC/matrix-type-builtins.m @@ -40,3 +40,23 @@ m.value = __builtin_matrix_transpose(*r); } + +__attribute__((objc_root_class)) +@interface PtrValue +@property unsigned *value; +@end + +__attribute__((objc_root_class)) +@interface IntValue +@property int value; +@end + +void test_column_major_load(PtrValue *Ptr, IntValue *Stride) { + // CHECK-LABEL: define void @test_column_major_load(%2* %Ptr, %3* %Stride) #4 { + // CHECK: [[STRIDE:%.*]] = call i32 bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32 (i8*, i8*)*) + // CHECK-NEXT: [[STRIDE_EXT:%.*]] = sext i32 [[STRIDE]] to i64 + // CHECK: [[PTR:%.*]] = call i32* bitcast (i8* (i8*, i8*, ...)* @objc_msgSend to i32* (i8*, i8*)*) + // CHECK-NEXT: call <12 x i32> @llvm.matrix.columnwise.load.v12i32.p0i32(i32* [[PTR]], i64 [[STRIDE_EXT]], i64 3, i64 4) + + u3x4 m = __builtin_matrix_column_major_load(Ptr.value, 3, 4, Stride.value); +} diff --git a/clang/test/Sema/matrix-type-builtins.c b/clang/test/Sema/matrix-type-builtins.c --- a/clang/test/Sema/matrix-type-builtins.c +++ b/clang/test/Sema/matrix-type-builtins.c @@ -20,3 +20,49 @@ ix3x3 m = __builtin_matrix_transpose(c); // expected-error@-1 {{initializing 'ix3x3' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') with an expression of incompatible type 'double __attribute__((matrix_type(3, 3)))'}} } + +struct Foo { + unsigned x; +}; + +void column_major_load(float *p1, int *p2, _Bool *p3, struct Foo *p4) { + sx5x10_t a1 = __builtin_matrix_column_major_load(p1, 5, 11, 5); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 
'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 11)))'}} + sx5x10_t a2 = __builtin_matrix_column_major_load(p1, 5, 9, 5); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 9)))'}} + sx5x10_t a3 = __builtin_matrix_column_major_load(p1, 6, 10, 6); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 10)))'}} + sx5x10_t a4 = __builtin_matrix_column_major_load(p1, 4, 10, 4); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(4, 10)))'}} + sx5x10_t a5 = __builtin_matrix_column_major_load(p1, 6, 9, 6); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 9)))'}} + sx5x10_t a6 = __builtin_matrix_column_major_load(p2, 5, 10, 6); + // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'int __attribute__((matrix_type(5, 10)))'}} + + sx5x10_t a7 = __builtin_matrix_column_major_load(p1, 5, 10, 3); + // expected-error@-1 {{stride must be greater or equal to the number of rows}} + + sx5x10_t a8 = __builtin_matrix_column_major_load(p3, 5, 10, 6); + // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}} + + sx5x10_t a9 = __builtin_matrix_column_major_load(p4, 5, 10, 6); + // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}} + + sx5x10_t a10 = __builtin_matrix_column_major_load(p1, 1ull << 21, 10, 6); + // expected-error@-1 {{row dimension is outside the allowed range [1, 
1048575]}} + sx5x10_t a11 = __builtin_matrix_column_major_load(p1, 10, 1ull << 21, 10); + // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575]}} + + sx5x10_t a12 = __builtin_matrix_column_major_load( + 10, // expected-error {{first argument must be a pointer to a valid matrix element type}} + 1ull << 21, // expected-error {{row dimension is outside the allowed range [1, 1048575]}} + 1ull << 21, // expected-error {{column dimension is outside the allowed range [1, 1048575]}} + ""); // expected-warning {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}} + + sx5x10_t a13 = __builtin_matrix_column_major_load( + 10, // expected-error {{first argument must be a pointer to a valid matrix element type}} + *p4, // expected-error {{casting 'struct Foo' to incompatible type 'unsigned long'}} + "", // expected-error {{column argument must be a constant unsigned integer expression}} + // expected-warning@-1 {{incompatible pointer to integer conversion casting 'char [1]' to type 'unsigned long'}} + 10); +} diff --git a/clang/test/SemaCXX/matrix-type-builtins.cpp b/clang/test/SemaCXX/matrix-type-builtins.cpp --- a/clang/test/SemaCXX/matrix-type-builtins.cpp +++ b/clang/test/SemaCXX/matrix-type-builtins.cpp @@ -39,3 +39,65 @@ Mat3.value = transpose(Mat2); // expected-note@-1 {{in instantiation of function template specialization 'transpose' requested here}} } + +template +typename MyMatrix::matrix_t column_major_load(MyMatrix &A, EltTy0 *Ptr) { + char *v1 = __builtin_matrix_column_major_load(Ptr, 9, 4, 10); + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}} + // expected-error@-2 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}} + // expected-error@-3 {{cannot initialize a variable of type 'char *' with an rvalue of type 'float 
__attribute__((matrix_type(9, 4)))'}} + + return __builtin_matrix_column_major_load(Ptr, R0, C0, R0); + // expected-error@-1 {{cannot initialize return object of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(5, 5)))') with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 3)))'}} + // expected-error@-2 {{cannot initialize return object of type 'typename MyMatrix::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 3)))') with an rvalue of type 'float __attribute__((matrix_type(2, 3)))'}} +} + +void test_column_major_loads_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix Mat1; + Mat1.value = column_major_load(Mat1, Ptr1); + // expected-note@-1 {{in instantiation of function template specialization 'column_major_load' requested here}} + column_major_load(Mat1, Ptr1); + // expected-note@-1 {{in instantiation of function template specialization 'column_major_load' requested here}} + + MyMatrix Mat2; + Mat1.value = column_major_load(Mat2, Ptr2); + // expected-note@-1 {{in instantiation of function template specialization 'column_major_load' requested here}} +} + +constexpr int constexpr1() { return 1; } +constexpr int constexpr_neg1() { return -1; } + +void test_column_major_load_constexpr(unsigned *Ptr) { + (void)__builtin_matrix_column_major_load(Ptr, 2, 2, constexpr1()); + // expected-error@-1 {{stride must be greater or equal to the number of rows}} + (void)__builtin_matrix_column_major_load(Ptr, constexpr_neg1(), 2, 4); + // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575]}} + (void)__builtin_matrix_column_major_load(Ptr, 2, constexpr_neg1(), 4); + // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575]}} +} + +struct IntWrapper { + operator int() { + return 1; + } +}; + +void test_column_major_load_wrapper(unsigned *Ptr, IntWrapper &W) { + (void)__builtin_matrix_column_major_load(Ptr, W, 2, 2); + // expected-error@-1 {{row argument must be a constant 
unsigned integer expression}} + (void)__builtin_matrix_column_major_load(Ptr, 2, W, 2); + // expected-error@-1 {{column argument must be a constant unsigned integer expression}} +} + +template +void test_column_major_load_temp(T Ptr) { + (void)__builtin_matrix_column_major_load(Ptr, R, C, S); +} + +void call_column_major_load_temp(unsigned *Ptr, unsigned X) { + (void)__builtin_matrix_column_major_load(Ptr, X, X, X); + // expected-error@-1 {{row argument must be a constant unsigned integer expression}} + // expected-error@-2 {{column argument must be a constant unsigned integer expression}} + (void)__builtin_matrix_column_major_load(X, 2, 2, 2); + // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}} +}