diff --git a/clang/docs/MatrixTypes.rst b/clang/docs/MatrixTypes.rst
--- a/clang/docs/MatrixTypes.rst
+++ b/clang/docs/MatrixTypes.rst
@@ -204,6 +204,24 @@
 * *T* - Element type
 * *row*, *col* - Row and column arguments respectively.
 
+``M3 __builtin_matrix_multiply_add(M1 matrixA, M2 matrixB, M3 matrixC)``
+
+**Returns**: A matrix ``Res`` equivalent to the code below, where ``row`` is the
+number of rows of ``M1``, ``depth`` is the number of columns of ``M1`` (which must
+equal the number of rows of ``M2``) and ``col`` is the number of columns of ``M2``.
+
+**Effects**: Equivalent to:
+
+.. code-block:: c++
+
+  M3 Res;
+  for (int C = 0; C < col; ++C)
+    for (int R = 0; R < row; ++R) {
+      T Acc = matrixC[R][C];
+      for (int K = 0; K < depth; ++K)
+        Acc += matrixA[R][K] * matrixB[K][C];
+      Res[R][C] = Acc;
+    }
 
 ``M2 __builtin_matrix_transpose(M1 matrix)``
 
diff --git a/clang/include/clang/Basic/Builtins.def b/clang/include/clang/Basic/Builtins.def
--- a/clang/include/clang/Basic/Builtins.def
+++ b/clang/include/clang/Basic/Builtins.def
@@ -642,6 +642,7 @@
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_major_store, "v.", "nFt")
+BUILTIN(__builtin_matrix_multiply_add, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins. These are overloaded to support data
 // types of i8, i16, i32, i64, and i128. The front-end sees calls to the
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11107,6 +11107,10 @@
 def err_matrix_subscript_comma: Error<
  "comma expressions are not allowed as indices in matrix subscript expressions">;
 def err_builtin_matrix_arg: Error<"1st argument must be a matrix">;
+def err_builtin_matrix_dimension_mismatch: Error<
+  "the number of columns of the 1st argument must match the number of rows of the 2nd argument, and the 3rd argument must have the same number of rows as the 1st argument and the same number of columns as the 2nd argument">;
+def err_builtin_matrix_scalar_type: Error<
+  "the element types of all arguments must match">;
 def err_builtin_matrix_scalar_unsigned_arg: Error<
  "%0 argument must be a constant unsigned integer expression">;
 def err_builtin_matrix_pointer_arg: Error<
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -12514,6 +12514,8 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall,
                                                ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixMultiplyAdd(CallExpr *TheCall,
+                                          ExprResult CallResult);
 
 public:
   enum FormatStringType {
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -3067,6 +3067,23 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_multiply_add: {
+    MatrixBuilder MB(Builder);
+    Value *MatrixA = EmitScalarExpr(E->getArg(0));
+    Value *MatrixB = EmitScalarExpr(E->getArg(1));
+    Value *MatrixC = EmitScalarExpr(E->getArg(2));
+
+    const auto *MatrixTy1 =
+        E->getArg(0)->getType()->getAs<ConstantMatrixType>();
+    const auto *MatrixTy2 =
+        E->getArg(1)->getType()->getAs<ConstantMatrixType>();
+
+    Value *Result = MB.CreateMatrixMultiplyAdd(
+        MatrixA, MatrixB, MatrixC, MatrixTy1->getNumRows(),
+        MatrixTy1->getNumColumns(), MatrixTy2->getNumColumns());
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1967,6 +1967,9 @@
   case Builtin::BI__builtin_matrix_column_major_store:
     return SemaBuiltinMatrixColumnMajorStore(TheCall, TheCallResult);
 
+  case Builtin::BI__builtin_matrix_multiply_add:
+    return SemaBuiltinMatrixMultiplyAdd(TheCall, TheCallResult);
+
   case Builtin::BI__builtin_get_device_side_mangled_name: {
     auto Check = [](CallExpr *TheCall) {
       if (TheCall->getNumArgs() != 1)
@@ -16152,6 +16155,78 @@
   return CallResult;
 }
 
+ExprResult Sema::SemaBuiltinMatrixMultiplyAdd(CallExpr *TheCall,
+                                              ExprResult CallResult) {
+  if (!getLangOpts().MatrixTypes) {
+    Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
+    return ExprError();
+  }
+
+  if (checkArgCount(*this, TheCall, 3))
+    return ExprError();
+
+  ExprResult MatrixAArg = DefaultLvalueConversion(TheCall->getArg(0));
+  if (MatrixAArg.isInvalid())
+    return MatrixAArg;
+  Expr *MatrixA = MatrixAArg.get();
+
+  auto *MTypeA = MatrixA->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeA) {
+    Diag(MatrixA->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  ExprResult MatrixBArg = DefaultLvalueConversion(TheCall->getArg(1));
+  if (MatrixBArg.isInvalid())
+    return MatrixBArg;
+  Expr *MatrixB = MatrixBArg.get();
+
+  auto *MTypeB = MatrixB->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeB) {
+    Diag(MatrixB->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  ExprResult MatrixCArg = DefaultLvalueConversion(TheCall->getArg(2));
+  if (MatrixCArg.isInvalid())
+    return MatrixCArg;
+  Expr *MatrixC = MatrixCArg.get();
+
+  auto *MTypeC = MatrixC->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeC) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  // Check whether all matrices have the same element type. We don't support
+  // mixed precision yet.
+  if (!(Context.hasSameType(MTypeC->getElementType(),
+                            MTypeA->getElementType()) &&
+        Context.hasSameType(MTypeC->getElementType(),
+                            MTypeB->getElementType()))) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_scalar_type);
+    return ExprError();
+  }
+
+  // Check that the dimensions are compatible.
+  if (MTypeA->getNumColumns() != MTypeB->getNumRows() ||
+      !(MTypeC->getNumColumns() == MTypeB->getNumColumns() &&
+        MTypeC->getNumRows() == MTypeA->getNumRows())) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_dimension_mismatch);
+    return ExprError();
+  }
+
+  // Prepare the result matrix type.
+  QualType ResultType = Context.getConstantMatrixType(
+      MTypeC->getElementType(), MTypeC->getNumRows(), MTypeC->getNumColumns());
+
+  TheCall->setType(ResultType);
+  TheCall->setArg(0, MatrixA);
+  TheCall->setArg(1, MatrixB);
+  TheCall->setArg(2, MatrixC);
+  return CallResult;
+}
+
 ExprResult Sema::SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall,
                                                    ExprResult CallResult) {
   if (checkArgCount(*this, TheCall, 3))
diff --git a/clang/test/CodeGen/matrix-type-builtins.c b/clang/test/CodeGen/matrix-type-builtins.c
--- a/clang/test/CodeGen/matrix-type-builtins.c
+++ b/clang/test/CodeGen/matrix-type-builtins.c
@@ -9,11 +9,36 @@
 typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
 typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
 typedef float fx3x2_t __attribute__((matrix_type(3, 2)));
+typedef float fx5x5_t __attribute__((matrix_type(5, 5)));
 typedef int ix20x4_t __attribute__((matrix_type(20, 4)));
 typedef int ix4x20_t __attribute__((matrix_type(4, 20)));
 typedef unsigned ux1x6_t __attribute__((matrix_type(1, 6)));
 typedef unsigned ux6x1_t __attribute__((matrix_type(6, 1)));
 
+void multiply_add_5x5(const fx5x5_t *a, const fx5x5_t *b, fx5x5_t *c) {
+  // CHECK-LABEL: define{{.*}} void @multiply_add_5x5(
+  // CHECK: [[A_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT: [[B_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT: [[C_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT: store [25 x float]* %a, [25 x float]** [[A_ADDR]], align 8
+  // CHECK-NEXT: store [25 x float]* %b, [25 x float]** [[B_ADDR]], align 8
+  // CHECK-NEXT: store [25 x float]* %c, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT: [[A_L:%.*]] = load [25 x float]*, [25 x float]** [[A_ADDR]], align 8
+  // CHECK-NEXT: [[A_B:%.*]] = bitcast [25 x float]* [[A_L]] to <25 x float>*
+  // CHECK-NEXT: [[A:%.*]] = load <25 x float>, <25 x float>* [[A_B]], align 4
+  // CHECK-NEXT: [[B_L:%.*]] = load [25 x float]*, [25 x float]** [[B_ADDR]], align 8
+  // CHECK-NEXT: [[B_B:%.*]] = bitcast [25 x float]* [[B_L]] to <25 x float>*
+  // CHECK-NEXT: [[B:%.*]] = load <25 x float>, <25 x float>* [[B_B]], align 4
+  // CHECK-NEXT: [[C_L:%.*]] = load [25 x float]*, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT: [[C_B:%.*]] = bitcast [25 x float]* [[C_L]] to <25 x float>*
+  // CHECK-NEXT: [[C:%.*]] = load <25 x float>, <25 x float>* [[C_B]], align 4
+  // CHECK-NEXT: [[MADD:%.*]] = call <25 x float> @llvm.matrix.multiply.add.v25f32.v25f32.v25f32.v25f32(<25 x float> [[A]], <25 x float> [[B]], <25 x float> [[C]], i32 5, i32 5, i32 5)
+  // CHECK-NEXT: [[CR_L:%.*]] = load [25 x float]*, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT: [[CR_B:%.*]] = bitcast [25 x float]* [[CR_L]] to <25 x float>*
+  // CHECK-NEXT: store <25 x float> [[MADD]], <25 x float>* [[CR_B]], align 4
+  *c = __builtin_matrix_multiply_add(*a, *b, *c);
+}
+
 void transpose_double_5x5(dx5x5_t *a) {
   // CHECK-LABEL: define{{.*}} void @transpose_double_5x5(
   // CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
diff --git a/clang/test/Sema/matrix-type-builtins.c b/clang/test/Sema/matrix-type-builtins.c
--- a/clang/test/Sema/matrix-type-builtins.c
+++ b/clang/test/Sema/matrix-type-builtins.c
@@ -96,3 +96,14 @@
   __builtin_matrix_column_major_store(*m1, p4, 20);
   // expected-error@-1 {{cannot store matrix to read-only pointer}}
 }
+
+void multiply_add(sx5x10_t a, sx5x10_t b, sx5x10_t c, dx3x3 d, dx3x3 e, ix3x3 f) {
+  c = __builtin_matrix_multiply_add(a, b, c);
+  // expected-error@-1 {{the number of columns of the 1st argument must match the number of rows of the 2nd argument, and the 3rd argument must have the same number of rows as the 1st argument and the same number of columns as the 2nd argument}}
+
+  f = __builtin_matrix_multiply_add(d, e, f);
+  // expected-error@-1 {{the element types of all arguments must match}}
+
+  f = __builtin_matrix_multiply_add(d, e, e);
+  // expected-error@-1 {{assigning to 'ix3x3' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') from incompatible type 'double __attribute__((matrix_type(3, 3)))'}}
+}
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1571,6 +1571,13 @@
                             [IntrNoSync, IntrWillReturn, IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>,
                              ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>]>;
 
+def int_matrix_multiply_add
+    : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                            [llvm_anyvector_ty, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i32_ty, llvm_i32_ty,
+                             llvm_i32_ty],
+                            [IntrNoSync, IntrWillReturn, IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<3>>,
+                             ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
+
 def int_matrix_column_major_load
     : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
                             [LLVMPointerToElt<0>, llvm_i64_ty, llvm_i1_ty,
diff --git a/llvm/include/llvm/IR/MatrixBuilder.h b/llvm/include/llvm/IR/MatrixBuilder.h
--- a/llvm/include/llvm/IR/MatrixBuilder.h
+++ b/llvm/include/llvm/IR/MatrixBuilder.h
@@ -125,6 +125,31 @@
     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
   }
 
+  /// Create a llvm.matrix.multiply.add call, multiplying matrixes \p LHS and
+  /// \p RHS and adding \p ACC to the product.
+  CallInst *CreateMatrixMultiplyAdd(Value *LHS, Value *RHS, Value *ACC,
+                                    unsigned LHSRows, unsigned LHSColumns,
+                                    unsigned RHSColumns,
+                                    const Twine &Name = "") {
+    auto *LHSType = cast<VectorType>(LHS->getType());
+    auto *RHSType = cast<VectorType>(RHS->getType());
+    auto *AccType = cast<VectorType>(ACC->getType());
+
+    auto *ReturnType =
+        FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
+    Value *Ops[] = {LHS,
+                    RHS,
+                    ACC,
+                    B.getInt32(LHSRows),
+                    B.getInt32(LHSColumns),
+                    B.getInt32(RHSColumns)};
+    Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType, AccType};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_multiply_add, OverloadedTypes);
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
   /// RHS.
   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -511,6 +511,7 @@
     if (II)
       switch (II->getIntrinsicID()) {
       case Intrinsic::matrix_multiply:
+      case Intrinsic::matrix_multiply_add:
       case Intrinsic::matrix_transpose:
       case Intrinsic::matrix_column_major_load:
       case Intrinsic::matrix_column_major_store:
@@ -540,6 +541,7 @@
 
     Value *MatrixA;
     Value *MatrixB;
+    Value *MatrixC;
     Value *M;
     Value *N;
     Value *K;
@@ -547,6 +549,11 @@
                         m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                         m_Value(N), m_Value(K)))) {
       Propagate = setShapeInfo(Inst, {M, K});
+    } else if (match(Inst,
+                     m_Intrinsic<Intrinsic::matrix_multiply_add>(
+                         m_Value(MatrixA), m_Value(MatrixB), m_Value(MatrixC),
+                         m_Value(M), m_Value(N), m_Value(K)))) {
+      Propagate = setShapeInfo(Inst, {M, K});
     } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                m_Value(MatrixA), m_Value(M), m_Value(N)))) {
       // Flip dimensions.
@@ -611,6 +618,7 @@
 
     Value *MatrixA;
     Value *MatrixB;
+    Value *MatrixC;
     Value *M;
     Value *N;
     Value *K;
@@ -622,7 +630,18 @@
       if (setShapeInfo(MatrixB, {N, K}))
         pushInstruction(MatrixB, WorkList);
 
+    } else if (match(V,
+                     m_Intrinsic<Intrinsic::matrix_multiply_add>(
+                         m_Value(MatrixA), m_Value(MatrixB), m_Value(MatrixC),
+                         m_Value(M), m_Value(N), m_Value(K)))) {
+      if (setShapeInfo(MatrixA, {M, N}))
+        pushInstruction(MatrixA, WorkList);
+      if (setShapeInfo(MatrixB, {N, K}))
+        pushInstruction(MatrixB, WorkList);
+
+      if (setShapeInfo(MatrixC, {M, K}))
+        pushInstruction(MatrixC, WorkList);
     } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                             m_Value(MatrixA), m_Value(M), m_Value(N)))) {
       // Flip dimensions.
@@ -673,6 +692,7 @@
 
       switch (II->getIntrinsicID()) {
       case Intrinsic::matrix_multiply:
+      case Intrinsic::matrix_multiply_add:
       case Intrinsic::matrix_transpose:
       case Intrinsic::matrix_column_major_load:
       case Intrinsic::matrix_column_major_store:
@@ -769,6 +789,9 @@
       case Intrinsic::matrix_column_major_store:
         LowerColumnMajorStore(Inst);
         break;
+      case Intrinsic::matrix_multiply_add:
+        LowerMultiplyAdd(Inst);
+        break;
       default:
         return false;
       }
@@ -1009,11 +1032,13 @@
     }
   }
 
-  /// Compute \p Result += \p A * \p B for input matrices with left-associating
-  /// addition.
+  /// Compute \p Result += \p A * \p B for input matrices with left-associating
+  /// addition, optionally seeding the accumulator with \p ACC.
+  template <bool isAccumulating = false>
   void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                           const MatrixTy &B, bool AllowContraction,
-                          IRBuilder<> &Builder, bool isTiled) {
+                          IRBuilder<> &Builder, bool isTiled,
+                          const MatrixTy *ACC = nullptr) {
     const unsigned VF = std::max(
         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
             .getFixedSize() /
@@ -1030,20 +1055,24 @@
     unsigned NumComputeOps = 0;
     if (A.isColumnMajor()) {
       // Multiply columns from the first operand with scalars from the second
-      // operand.  Then move along the K axes and accumulate the columns.  With
+      // operand. Then move along the K axes and accumulate the columns. With
       // this the adds can be vectorized without reassociation.
       for (unsigned J = 0; J < C; ++J) {
         unsigned BlockSize = VF;
         // If Result is zero, we don't need to accumulate in the K==0 iteration.
-        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+        bool isSumZero = isAccumulating
+                             ? false
+                             : isa<ConstantAggregateZero>(Result.getColumn(J));
 
         for (unsigned I = 0; I < R; I += BlockSize) {
           // Gradually lower the vectorization factor to cover the remainder.
           while (I + BlockSize > R)
             BlockSize /= 2;
 
-          Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
-                               : nullptr;
+          Value *Sum =
+              isAccumulating ? ACC->extractVector(I, J, BlockSize, Builder)
+              : isTiled      ? Result.extractVector(I, J, BlockSize, Builder)
+                             : nullptr;
           for (unsigned K = 0; K < M; ++K) {
             Value *L = A.extractVector(I, K, BlockSize, Builder);
             Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
@@ -1062,13 +1091,17 @@
       // the adds can be vectorized without reassociation.
       for (unsigned I = 0; I < R; ++I) {
         unsigned BlockSize = VF;
-        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
+        bool isSumZero = isAccumulating
+                             ? false
+                             : isa<ConstantAggregateZero>(Result.getRow(I));
         for (unsigned J = 0; J < C; J += BlockSize) {
           // Gradually lower the vectorization factor to cover the remainder.
           while (J + BlockSize > C)
             BlockSize /= 2;
 
-          Value *Sum = nullptr;
+          Value *Sum = isAccumulating
+                           ? ACC->extractVector(I, J, BlockSize, Builder)
+                           : nullptr;
           for (unsigned K = 0; K < M; ++K) {
             Value *R = B.extractVector(K, J, BlockSize, Builder);
             Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
@@ -1367,6 +1400,40 @@
       }
     }
 
+  /// Lowers llvm.matrix.multiply.add.
+  void LowerMultiplyAdd(CallInst *MatMulAdd) {
+    IRBuilder<> Builder(MatMulAdd);
+    auto *EltType = cast<VectorType>(MatMulAdd->getType())->getElementType();
+    ShapeInfo LShape(MatMulAdd->getArgOperand(3), MatMulAdd->getArgOperand(4));
+    ShapeInfo RShape(MatMulAdd->getArgOperand(4), MatMulAdd->getArgOperand(5));
+    ShapeInfo AShape(MatMulAdd->getArgOperand(3), MatMulAdd->getArgOperand(5));
+
+    const MatrixTy &Lhs =
+        getMatrix(MatMulAdd->getArgOperand(0), LShape, Builder);
+    const MatrixTy &Rhs =
+        getMatrix(MatMulAdd->getArgOperand(1), RShape, Builder);
+    const MatrixTy &Acc =
+        getMatrix(MatMulAdd->getArgOperand(2), AShape, Builder);
+    assert(Lhs.getElementType() == Rhs.getElementType() &&
+           "Matrix multiply argument element types do not match.");
+
+    const unsigned R = LShape.NumRows;
+    const unsigned C = RShape.NumColumns;
+    assert(LShape.NumColumns == RShape.NumRows);
+
+    // Initialize the output
+    MatrixTy Result(R, C, EltType);
+    assert(Lhs.getElementType() == Result.getElementType() &&
+           "Matrix multiply result element type does not match arguments.");
+
+    bool AllowContract =
+        AllowContractEnabled ||
+        (isa<FPMathOperator>(MatMulAdd) && MatMulAdd->hasAllowContract());
+    emitMatrixMultiply<true>(Result, Lhs, Rhs, AllowContract, Builder, false,
+                             &Acc);
+    finalizeLowering(MatMulAdd, Result, Builder);
+  }
+
   /// Lowers llvm.matrix.multiply.
   void LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
@@ -1648,6 +1715,14 @@
       prettyPrintMatrixType(II->getOperand(1), SS);
       SS << "." << *II->getType()->getScalarType();
       break;
+    case Intrinsic::matrix_multiply_add:
+      prettyPrintMatrixType(II->getOperand(0), SS);
+      SS << ".";
+      prettyPrintMatrixType(II->getOperand(1), SS);
+      SS << ".";
+      prettyPrintMatrixType(II->getOperand(2), SS);
+      SS << "." << *II->getType()->getScalarType();
+      break;
     case Intrinsic::matrix_transpose:
       prettyPrintMatrixType(II->getOperand(0), SS);
       SS << "." << *II->getType()->getScalarType();
@@ -1672,6 +1747,7 @@
     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
       switch (II->getIntrinsicID()) {
       case Intrinsic::matrix_multiply:
+      case Intrinsic::matrix_multiply_add:
         return 3;
       case Intrinsic::matrix_transpose:
         return 2;