diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -2253,6 +2253,20 @@ :ref:`stackmap entry `. See the intrinsic description for further details. +.. _ob_matrix: + +Matrix Operand Bundles +^^^^^^^^^^^^^^^^^^^^^^ + +Matrix operand bundles are characterized by the ``matrix_`` prefix. Currently there +are the ``"matrix_nuw"`` and ``"matrix_nsw"`` operand bundles, to indicate whether +``nuw`` and ``nsw`` should be used when lowering calls to +``@llvm.matrix.multiply.*`` with integer matrix arguments. + +Matrix operand bundles must be attached to calls of the +``@llvm.matrix.multiply.*`` intrinsic and at most one of each matrix operand +bundle can be attached to a call. Each bundle takes a single constant ``i1`` operand; the corresponding flag is only used when that operand is ``true``. + .. _moduleasm: Module-Level Inline Assembly @@ -15477,6 +15491,10 @@ columns and multiplies them. The result matrix is returned embedded in the result vector. +:ref:`"Matrix" bundles <ob_matrix>` can be used to indicate whether ``nuw`` +and ``nsw`` should be used when lowering '``llvm.matrix.multiply.*``' calls with integer +matrix arguments. + Arguments: """""""""" diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h --- a/llvm/include/llvm/IR/LLVMContext.h +++ b/llvm/include/llvm/IR/LLVMContext.h @@ -93,6 +93,8 @@ OB_cfguardtarget = 3, // "cfguardtarget" OB_preallocated = 4, // "preallocated" OB_gc_live = 5, // "gc-live" + OB_matrix_nuw = 6, // "matrix_nuw" + OB_matrix_nsw = 7, // "matrix_nsw" }; /// getMDKindID - Return a unique non-zero ID for the specified metadata kind. 
diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -78,6 +78,16 @@ "gc-transition operand bundle id drifted!"); (void)GCLiveEntry; + auto *MatrixNUWEntry = pImpl->getOrInsertBundleTag("matrix_nuw"); + assert(MatrixNUWEntry->second == LLVMContext::OB_matrix_nuw && + "matrix_nuw operand bundle id drifted!"); + (void)MatrixNUWEntry; + + auto *MatrixNSWEntry = pImpl->getOrInsertBundleTag("matrix_nsw"); + assert(MatrixNSWEntry->second == LLVMContext::OB_matrix_nsw && + "matrix_nsw operand bundle id drifted!"); + (void)MatrixNSWEntry; + SyncScope::ID SingleThreadSSID = pImpl->getOrInsertSyncScopeID("singlethread"); assert(SingleThreadSSID == SyncScope::SingleThread && diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -5010,11 +5010,39 @@ ConstantInt *NumColumns; VectorType *TypeToCheck; switch (ID) { - case Intrinsic::matrix_multiply: + case Intrinsic::matrix_multiply: { NumRows = cast(Call.getArgOperand(2)); NumColumns = cast(Call.getArgOperand(4)); TypeToCheck = cast(Call.getType()); + + auto NUWBundle = Call.getOperandBundle(LLVMContext::OB_matrix_nuw); + if (NUWBundle) { + Value *Arg1 = Call.getArgOperand(0); + Assert(Arg1->getType()->getScalarType()->isIntegerTy(), + "matrix_nuw bundle only supported for integer multiplications"); + Assert(NUWBundle->Inputs.size() == 1, + "matrix_nuw bundle must have a single argument"); + auto OpTy = NUWBundle->Inputs[0]->getType(); + Assert(OpTy->isIntegerTy() && OpTy->getIntegerBitWidth() == 1, + "matrix_nuw bundle operand must be an integer with bitwidth 1"); + Assert(isa(NUWBundle->Inputs[0]), + "matrix_nuw bundle operand must be a constant integer"); + } + auto NSWBundle = Call.getOperandBundle(LLVMContext::OB_matrix_nsw); + if (NSWBundle) { + Value *Arg1 = Call.getArgOperand(0); + Assert(Arg1->getType()->getScalarType()->isIntegerTy(), + "matrix_nsw bundle 
only supported for integer multiplications"); + Assert(NSWBundle->Inputs.size() == 1, + "matrix_nsw bundle must have a single argument"); + auto OpTy = NSWBundle->Inputs[0]->getType(); + Assert(OpTy->isIntegerTy() && OpTy->getIntegerBitWidth() == 1, + "matrix_nsw bundle operand must be an integer with bitwidth 1"); + Assert(isa(NSWBundle->Inputs[0]), + "matrix_nsw bundle operand must be a constant integer"); + } break; + } case Intrinsic::matrix_transpose: NumRows = cast(Call.getArgOperand(1)); NumColumns = cast(Call.getArgOperand(2)); diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp --- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp +++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp @@ -187,6 +187,29 @@ LoopInfo &LI; OptimizationRemarkEmitter &ORE; + /// Contains a set of options for generating code for llvm.matrix.multiply. + struct MultiplyOptions { + /// Is Floating-point contraction allowed? + bool AllowContract; + /// Should the NUW flag be added? + bool HasNUW; + /// Should the NSW flag be added? + bool HasNSW; + + MultiplyOptions(CallInst *MatMul) { + AllowContract = AllowContractEnabled || (isa(MatMul) && + MatMul->hasAllowContract()); + auto NUWBundle = MatMul->getOperandBundle(LLVMContext::OB_matrix_nuw); + HasNUW = NUWBundle + ? cast(*NUWBundle->Inputs.begin())->isOne() + : false; + auto NSWBundle = MatMul->getOperandBundle(LLVMContext::OB_matrix_nsw); + HasNSW = NSWBundle + ? cast(*NSWBundle->Inputs.begin())->isOne() + : false; + } + }; + /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation. struct OpInfoTy { /// Number of stores emitted to generate this matrix. 
@@ -955,14 +978,15 @@ } Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp, - IRBuilder<> &Builder, bool AllowContraction, + IRBuilder<> &Builder, MultiplyOptions Opts, unsigned &NumComputeOps) { NumComputeOps += getNumOps(A->getType()); if (!Sum) - return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B); + return UseFPOp ? Builder.CreateFMul(A, B) + : Builder.CreateMul(A, B, "", Opts.HasNUW, Opts.HasNSW); if (UseFPOp) { - if (AllowContraction) { + if (Opts.AllowContract) { // Use fmuladd for floating point operations and let the backend decide // if that's profitable. Function *FMulAdd = Intrinsic::getDeclaration( @@ -975,8 +999,8 @@ } NumComputeOps += getNumOps(A->getType()); - Value *Mul = Builder.CreateMul(A, B); - return Builder.CreateAdd(Sum, Mul); + Value *Mul = Builder.CreateMul(A, B, "", Opts.HasNUW, Opts.HasNSW); + return Builder.CreateAdd(Sum, Mul, "", Opts.HasNUW, Opts.HasNSW); } /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For @@ -1003,7 +1027,7 @@ /// Compute \p Result += \p A * \p B for input matrices with left-associating /// addition. void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A, - const MatrixTy &B, bool AllowContraction, + const MatrixTy &B, MultiplyOptions Opts, IRBuilder<> &Builder, bool isTiled) { const unsigned VF = std::max( TTI.getRegisterBitWidth(true) / @@ -1040,7 +1064,7 @@ Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat"); Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat, Result.getElementType()->isFloatingPointTy(), - Builder, AllowContraction, NumComputeOps); + Builder, Opts, NumComputeOps); } Result.setVector(J, insertVector(Result.getVector(J), I, Sum, Builder)); @@ -1064,7 +1088,7 @@ Value *LH = Builder.CreateExtractElement(A.getVector(I), K); Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat"); Sum = createMulAdd(isSumZero && K == 0 ? 
nullptr : Sum, Splat, R, - IsFP, Builder, AllowContraction, NumComputeOps); + IsFP, Builder, Opts, NumComputeOps); } Result.setVector(I, insertVector(Result.getVector(I), J, Sum, Builder)); @@ -1219,8 +1243,6 @@ Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul); Value *CPtr = Store->getPointerOperand(); - bool AllowContract = AllowContractEnabled || (isa(MatMul) && - MatMul->hasAllowContract()); IRBuilder<> Builder(Store); for (unsigned J = 0; J < C; J += TileSize) for (unsigned I = 0; I < R; I += TileSize) { @@ -1238,7 +1260,7 @@ loadMatrix(BPtr, LoadOp1->getAlign(), LoadOp1->isVolatile(), RShape, Builder.getInt64(K), Builder.getInt64(J), {TileM, TileC}, EltType, Builder); - emitMatrixMultiply(Res, A, B, AllowContract, Builder, true); + emitMatrixMultiply(Res, A, B, MultiplyOptions(MatMul), Builder, true); } storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M}, Builder.getInt64(I), Builder.getInt64(J), EltType, Builder); @@ -1305,9 +1327,8 @@ // Initialize the output MatrixTy Result(R, C, EltType); - bool AllowContract = AllowContractEnabled || (isa(MatMul) && - MatMul->hasAllowContract()); - emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false); + emitMatrixMultiply(Result, Lhs, Rhs, MultiplyOptions(MatMul), Builder, + false); finalizeLowering(MatMul, Result, Builder); } diff --git a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll --- a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll +++ b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll @@ -9,6 +9,8 @@ ; CHECK-NEXT: @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32>, <4 x i32>, i32, i32, i32) + +define <4 x i32> @multiply_2x2_nuw(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: @multiply_2x2_nuw( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: 
[[SPLIT2:%.*]] = shufflevector <4 x i32> [[B:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <4 x i32> [[B]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> undef, i32 [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = mul nuw <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x i32> undef, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT5]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = mul nuw <1 x i32> [[BLOCK4]], [[SPLAT_SPLAT6]] +; CHECK-NEXT: [[TMP4:%.*]] = add nuw <1 x i32> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x i32> undef, i32 [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT8]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = mul nuw <1 x i32> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> 
[[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x i32> undef, i32 [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT11]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = mul nuw <1 x i32> [[BLOCK10]], [[SPLAT_SPLAT12]] +; CHECK-NEXT: [[TMP11:%.*]] = add nuw <1 x i32> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x i32> [[TMP11]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP12]], <2 x i32> +; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x i32> undef, i32 [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT14]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = mul nuw <1 x i32> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x i32> undef, i32 [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT17]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = mul nuw <1 x i32> [[BLOCK16]], [[SPLAT_SPLAT18]] +; CHECK-NEXT: [[TMP18:%.*]] = add nuw <1 x i32> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP18]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP19]], <2 x i32> +; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; 
CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x i32> undef, i32 [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT20]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = mul nuw <1 x i32> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x i32> undef, i32 [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT23]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = mul nuw <1 x i32> [[BLOCK22]], [[SPLAT_SPLAT24]] +; CHECK-NEXT: [[TMP25:%.*]] = add nuw <1 x i32> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x i32> [[TMP25]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x i32> [[TMP20]], <2 x i32> [[TMP26]], <2 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <2 x i32> [[TMP13]], <2 x i32> [[TMP27]], <4 x i32> +; CHECK-NEXT: ret <4 x i32> [[TMP28]] +; +entry: + %c = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2) [ "matrix_nuw"(i1 true), "matrix_nsw"(i1 false) ] + ret <4 x i32> %c +} + + +define <4 x i32> @multiply_2x2_nsw(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: @multiply_2x2_nsw( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <4 x i32> [[B:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <4 x i32> [[B]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: 
[[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> undef, i32 [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = mul nsw <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x i32> undef, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT5]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = mul nsw <1 x i32> [[BLOCK4]], [[SPLAT_SPLAT6]] +; CHECK-NEXT: [[TMP4:%.*]] = add nsw <1 x i32> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x i32> undef, i32 [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT8]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = mul nsw <1 x i32> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x i32> undef, i32 [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT11]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = mul nsw <1 x 
i32> [[BLOCK10]], [[SPLAT_SPLAT12]] +; CHECK-NEXT: [[TMP11:%.*]] = add nsw <1 x i32> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x i32> [[TMP11]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP12]], <2 x i32> +; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x i32> undef, i32 [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT14]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = mul nsw <1 x i32> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x i32> undef, i32 [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT17]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = mul nsw <1 x i32> [[BLOCK16]], [[SPLAT_SPLAT18]] +; CHECK-NEXT: [[TMP18:%.*]] = add nsw <1 x i32> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP18]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP19]], <2 x i32> +; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x i32> undef, i32 [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT20]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = mul nsw <1 x i32> [[BLOCK19]], 
[[SPLAT_SPLAT21]] +; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = insertelement <1 x i32> undef, i32 [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT23]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = mul nsw <1 x i32> [[BLOCK22]], [[SPLAT_SPLAT24]] +; CHECK-NEXT: [[TMP25:%.*]] = add nsw <1 x i32> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x i32> [[TMP25]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x i32> [[TMP20]], <2 x i32> [[TMP26]], <2 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <2 x i32> [[TMP13]], <2 x i32> [[TMP27]], <4 x i32> +; CHECK-NEXT: ret <4 x i32> [[TMP28]] +; +entry: + %c = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2) [ "matrix_nuw"(i1 false), "matrix_nsw"(i1 true) ] + ret <4 x i32> %c +} + +define <4 x i32> @multiply_2x2_nuw_nsw(<4 x i32> %a, <4 x i32> %b) { +; CHECK-LABEL: @multiply_2x2_nuw_nsw( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT1:%.*]] = shufflevector <4 x i32> [[A]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT2:%.*]] = shufflevector <4 x i32> [[B:%.*]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[SPLIT3:%.*]] = shufflevector <4 x i32> [[B]], <4 x i32> undef, <2 x i32> +; CHECK-NEXT: [[BLOCK:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP0:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i32> undef, i32 [[TMP0]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT]], <1 x i32> undef, <1 x i32> zeroinitializer +; 
CHECK-NEXT: [[TMP1:%.*]] = mul nuw nsw <1 x i32> [[BLOCK]], [[SPLAT_SPLAT]] +; CHECK-NEXT: [[BLOCK4:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT5:%.*]] = insertelement <1 x i32> undef, i32 [[TMP2]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT6:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT5]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = mul nuw nsw <1 x i32> [[BLOCK4]], [[SPLAT_SPLAT6]] +; CHECK-NEXT: [[TMP4:%.*]] = add nuw nsw <1 x i32> [[TMP1]], [[TMP3]] +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <1 x i32> [[TMP4]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP5]], <2 x i32> +; CHECK-NEXT: [[BLOCK7:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT8:%.*]] = insertelement <1 x i32> undef, i32 [[TMP7]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT9:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT8]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw <1 x i32> [[BLOCK7]], [[SPLAT_SPLAT9]] +; CHECK-NEXT: [[BLOCK10:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP9:%.*]] = extractelement <2 x i32> [[SPLIT2]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT11:%.*]] = insertelement <1 x i32> undef, i32 [[TMP9]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT12:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT11]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP10:%.*]] = mul nuw nsw <1 x i32> [[BLOCK10]], [[SPLAT_SPLAT12]] +; CHECK-NEXT: [[TMP11:%.*]] = add nuw nsw <1 x i32> [[TMP8]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <1 x i32> [[TMP11]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> 
[[TMP6]], <2 x i32> [[TMP12]], <2 x i32> +; CHECK-NEXT: [[BLOCK13:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT14:%.*]] = insertelement <1 x i32> undef, i32 [[TMP14]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT15:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT14]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP15:%.*]] = mul nuw nsw <1 x i32> [[BLOCK13]], [[SPLAT_SPLAT15]] +; CHECK-NEXT: [[BLOCK16:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT17:%.*]] = insertelement <1 x i32> undef, i32 [[TMP16]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT18:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT17]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP17:%.*]] = mul nuw nsw <1 x i32> [[BLOCK16]], [[SPLAT_SPLAT18]] +; CHECK-NEXT: [[TMP18:%.*]] = add nuw nsw <1 x i32> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = shufflevector <1 x i32> [[TMP18]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP20:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP19]], <2 x i32> +; CHECK-NEXT: [[BLOCK19:%.*]] = shufflevector <2 x i32> [[SPLIT]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP21:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 0 +; CHECK-NEXT: [[SPLAT_SPLATINSERT20:%.*]] = insertelement <1 x i32> undef, i32 [[TMP21]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT21:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT20]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP22:%.*]] = mul nuw nsw <1 x i32> [[BLOCK19]], [[SPLAT_SPLAT21]] +; CHECK-NEXT: [[BLOCK22:%.*]] = shufflevector <2 x i32> [[SPLIT1]], <2 x i32> undef, <1 x i32> +; CHECK-NEXT: [[TMP23:%.*]] = extractelement <2 x i32> [[SPLIT3]], i64 1 +; CHECK-NEXT: [[SPLAT_SPLATINSERT23:%.*]] = 
insertelement <1 x i32> undef, i32 [[TMP23]], i32 0 +; CHECK-NEXT: [[SPLAT_SPLAT24:%.*]] = shufflevector <1 x i32> [[SPLAT_SPLATINSERT23]], <1 x i32> undef, <1 x i32> zeroinitializer +; CHECK-NEXT: [[TMP24:%.*]] = mul nuw nsw <1 x i32> [[BLOCK22]], [[SPLAT_SPLAT24]] +; CHECK-NEXT: [[TMP25:%.*]] = add nuw nsw <1 x i32> [[TMP22]], [[TMP24]] +; CHECK-NEXT: [[TMP26:%.*]] = shufflevector <1 x i32> [[TMP25]], <1 x i32> undef, <2 x i32> +; CHECK-NEXT: [[TMP27:%.*]] = shufflevector <2 x i32> [[TMP20]], <2 x i32> [[TMP26]], <2 x i32> +; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <2 x i32> [[TMP13]], <2 x i32> [[TMP27]], <4 x i32> +; CHECK-NEXT: ret <4 x i32> [[TMP28]] +; +entry: + %c = call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32(<4 x i32> %a, <4 x i32> %b, i32 2, i32 2, i32 2) [ "matrix_nuw"(i1 true), "matrix_nsw"(i1 true) ] + ret <4 x i32> %c +} diff --git a/llvm/test/Verifier/matrix-intrinsics.ll b/llvm/test/Verifier/matrix-intrinsics.ll --- a/llvm/test/Verifier/matrix-intrinsics.ll +++ b/llvm/test/Verifier/matrix-intrinsics.ll @@ -21,6 +21,7 @@ declare <4 x float> @llvm.matrix.column.major.load.v4f32.p0v4f32(<4 x float>*, i64, i1, i32, i32) declare <6 x float> @llvm.matrix.column.major.load.v6f32.p0v6f32(<6 x float>*, i64, i1, i32, i32) + define <4 x float> @column.major_load(<4 x float>* %m, <6 x float>* %n) { ; CHECK-NEXT: result of a matrix operation does not fit in the returned vector ; CHECK-NEXT: result of a matrix operation does not fit in the returned vector @@ -38,3 +39,42 @@ call void @llvm.matrix.column.major.store.v6f32.p0v6f32(<6 x float> zeroinitializer, <6 x float>* %n, i64 2, i1 false, i32 3, i32 3) ret void } + +declare <6 x i32 > @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32>, <6 x i32>, i32, i32, i32) +define <4 x float> @multiply_nuw_bundle_args(<4 x float> %m, <6 x i32> %n, i1 %c, i32 %d) { +; CHECK-NEXT: matrix_nuw bundle operand must be a constant integer + %result.1 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> 
%n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nuw"(i1 %c) ] + +; CHECK-NEXT: matrix_nuw bundle operand must be an integer with bitwidth 1 + %result.2 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nuw"(i32 %d) ] + +; CHECK-NEXT: matrix_nuw bundle operand must be an integer with bitwidth 1 + %result.3 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nuw"(<6 x i32> %n) ] + +; CHECK-NEXT: matrix_nuw bundle must have a single argument + %result.4 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nuw"(i1 %c, i1 true) ] + +; CHECK-NEXT: matrix_nuw bundle only supported for integer multiplications + %result.5 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %m, i32 2, i32 2, i32 1) [ "matrix_nuw"(i1 true) ] + + ret <4 x float> %result.5 +} + +define <4 x float> @multiply_nsw_bundle_args(<4 x float> %m, <6 x i32> %n, i1 %c, i32 %d) { +; CHECK-NEXT: matrix_nsw bundle operand must be a constant integer + %result.1 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nsw"(i1 %c) ] + +; CHECK-NEXT: matrix_nsw bundle operand must be an integer with bitwidth 1 + %result.2 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nsw"(i32 %d) ] + +; CHECK-NEXT: matrix_nsw bundle operand must be an integer with bitwidth 1 + %result.3 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nsw"(<6 x i32> %n) ] + +; CHECK-NEXT: matrix_nsw bundle must have a single argument + %result.4 = call <6 x i32> @llvm.matrix.multiply.v6i32.v6i32.v6i32(<6 x i32> %n, <6 x i32> %n, i32 2, i32 3, i32 2) [ "matrix_nsw"(i1 %c, i1 true) ] + +; CHECK-NEXT: matrix_nsw bundle only 
supported for integer multiplications + %result.5 = call <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32(<4 x float> %m, <4 x float> %m, i32 2, i32 2, i32 1) [ "matrix_nsw"(i1 true) ] + + ret <4 x float> %result.5 +}