diff --git a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
--- a/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1359,10 +1359,12 @@
     return;
 
   auto CanBeFlattened = [](Value *Op) {
-    return match(Op, m_OneUse(m_CombineOr(
-                         m_Load(m_Value()),
-                         m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                             m_Value(), m_SpecificInt(1)))));
+    return match(
+        Op, m_OneUse(m_CombineOr(
+                m_Load(m_Value()),
+                m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
+                            m_Intrinsic<Intrinsic::matrix_column_major_load>(
+                                m_Value(), m_SpecificInt(1))))));
   };
   // Returns the cost benefit of using \p Op with the dot product lowering. If
   // the returned cost is < 0, the argument is cheaper to use in the
@@ -1374,21 +1376,34 @@
 
     FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
     Type *EltTy = VecTy->getElementType();
-    if (CanBeFlattened(Op)) {
-      if (N == 1)
-        return InstructionCost(0);
-
+    if (!CanBeFlattened(Op)) {
+      InstructionCost EmbedCost(0);
+      // Roughly estimate the cost for embedding the columns into a vector.
+      for (unsigned I = 1; I < N; ++I)
+        EmbedCost -=
+            TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                               std::nullopt, TTI::TCK_RecipThroughput);
+      return EmbedCost;
+    }
+
-      return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
-             N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
+    if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
+      // The transpose can be skipped for the dot product lowering, roughly
+      // estimate the savings as the cost of embedding the columns in a
+      // vector.
+      InstructionCost EmbedCost(0);
+      for (unsigned I = 1; I < N; ++I)
+        EmbedCost +=
+            TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
+                               std::nullopt, TTI::TCK_RecipThroughput);
+      return EmbedCost;
     }
-    InstructionCost EmbedCost(0);
-    // Roughly estimate the cost for embedding the columns into a vector.
-    for (unsigned I = 1; I < N; ++I)
-      EmbedCost +=
-          TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
-                             std::nullopt, TTI::TCK_RecipThroughput);
-    return EmbedCost;
+
+    // Costs for loads.
+    if (N == 1)
+      return InstructionCost(0);
+
+    return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
+           N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
   };
 
   auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
@@ -1410,8 +1425,8 @@
 
   FusedInsts.insert(MatMul);
   IRBuilder<> Builder(MatMul);
-  auto FlattenArg = [&Builder, &FusedInsts,
-                     &CanBeFlattened](Value *Op) -> Value * {
+  auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
+                     this](Value *Op) -> Value * {
     // Matmul must be the only user of loads because we don't use LowerLoad
     // for row vectors (LowerLoad results in scalar loads and shufflevectors
     // instead of single vector load).
@@ -1419,15 +1434,21 @@
       return Op;
 
     FusedInsts.insert(cast<Instruction>(Op));
+    // If vector uses the builtin load, lower to a LoadInst
-    Value *Ptr;
+    Value *Arg;
     if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
-                      m_Value(Ptr)))) {
-      auto *NewLoad = Builder.CreateLoad(Op->getType(), Ptr);
+                      m_Value(Arg)))) {
+      auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
       Op->replaceAllUsesWith(NewLoad);
       cast<Instruction>(Op)->eraseFromParent();
       return NewLoad;
+    } else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
+                             m_Value(Arg)))) {
+      ToRemove.push_back(cast<Instruction>(Op));
+      return Arg;
     }
+
     return Op;
   };
 
   LHS = FlattenArg(LHS);
diff --git a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
--- a/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
+++ b/llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll
@@ -5,21 +5,9 @@
 define void @transposed_multiply_feeding_dot_product_v4i322(<4 x i32> %a, <4 x i32> %b) {
 ; CHECK-LABEL: @transposed_multiply_feeding_dot_product_v4i322(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 0
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[TMP0]], i64 0
-; CHECK-NEXT:    [[TMP2:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 1
-; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0
-; CHECK-NEXT:    [[TMP4:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 2
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <1 x i32> poison, i32 [[TMP4]], i64 0
-; CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 3
-; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <1 x i32> poison, i32 [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP1]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP5]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP11:%.*]] = mul <4 x i32> [[TMP10]], [[B:%.*]]
-; CHECK-NEXT:    [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP11]])
-; CHECK-NEXT:    [[TMP13:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i64 0
+; CHECK-NEXT:    [[TMP0:%.*]] = mul <4 x i32> [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <1 x i32> poison, i32 [[TMP1]], i64 0
 ; CHECK-NEXT:    ret void
 ;
 entry:
@@ -61,18 +49,10 @@
 ; CHECK-NEXT:    [[TMP11:%.*]] = add <2 x i32> [[TMP8]], [[TMP10]]
 ; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP11]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
 ; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP12]], <2 x i32> <i32 2, i32 3>
-; CHECK-NEXT:    [[TMP14:%.*]] = extractelement <2 x i32> [[TMP6]], i64 0
-; CHECK-NEXT:    [[TMP15:%.*]] = insertelement <2 x i32> poison, i32 [[TMP14]], i64 0
-; CHECK-NEXT:    [[TMP16:%.*]] = extractelement <2 x i32> [[TMP13]], i64 0
-; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <2 x i32> [[TMP15]], i32 [[TMP16]], i64 1
-; CHECK-NEXT:    [[TMP18:%.*]] = extractelement <2 x i32> [[TMP6]], i64 1
-; CHECK-NEXT:    [[TMP19:%.*]] = insertelement <2 x i32> poison, i32 [[TMP18]], i64 0
-; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <2 x i32> [[TMP13]], i64 1
-; CHECK-NEXT:    [[TMP21:%.*]] = insertelement <2 x i32> [[TMP19]], i32 [[TMP20]], i64 1
-; CHECK-NEXT:    [[TMP22:%.*]] = shufflevector <2 x i32> [[TMP17]], <2 x i32> [[TMP21]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[TMP23:%.*]] = mul <4 x i32> [[TMP22]], [[C:%.*]]
-; CHECK-NEXT:    [[TMP24:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP23]])
-; CHECK-NEXT:    [[TMP25:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i64 0
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[C:%.*]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0
 ; CHECK-NEXT:    ret void
 ;
 entry: