diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1829,6 +1829,67 @@
     break;
   }
+  case Intrinsic::matrix_multiply: {
+    // Optimize negation in matrix multiplication.
+    // fneg is element-wise, so if a negated operand has more vector elements
+    // than the second operand or the result, we can move the negation to the
+    // value with the fewest elements. This covers two cases:
+    // Case 1: the second operand has the smallest element count, i.e.
+    //   (-A) * B = A * (-B)
+    // Case 2: the result has the smallest element count, i.e.
+    //   (-A) * B = -(A * B)
+    Value *X;
+
+    Value *Op0 = II->getArgOperand(0);
+    Value *Op1 = II->getArgOperand(1);
+
+    VectorType *RetType = cast<VectorType>(II->getType());
+    Instruction *FNegOp;
+    Value *SecondOperand;
+    unsigned SecondOperandArg;
+    bool MatchOp0 = match(Op0, m_FNeg(m_Value(X)));
+    bool MatchOp1 = match(Op1, m_FNeg(m_Value(X)));
+
+    // Identify the negated operand; the other operand is the "second" one.
+    if (MatchOp0) {
+      FNegOp = cast<Instruction>(Op0);
+      SecondOperand = Op1;
+      SecondOperandArg = 1;
+    } else if (MatchOp1) {
+      FNegOp = cast<Instruction>(Op1);
+      SecondOperand = Op0;
+      SecondOperandArg = 0;
+    } else {
+      break;
+    }
+    if (!FNegOp->hasOneUse())
+      break;
+
+    Value *OpNotNeg = FNegOp->getOperand(0);
+    VectorType *FNegType = cast<VectorType>(FNegOp->getType());
+    VectorType *SecondOperandType = cast<VectorType>(SecondOperand->getType());
+    // Case 1: the second operand is strictly smaller than both the negated
+    // operand and the result, so negate it instead.
+    if (ElementCount::isKnownGT(FNegType->getElementCount(),
+                                SecondOperandType->getElementCount()) &&
+        ElementCount::isKnownLT(SecondOperandType->getElementCount(),
+                                RetType->getElementCount())) {
+      replaceInstUsesWith(*FNegOp, OpNotNeg);
+      Value *InverseSecondOp = Builder.CreateFNeg(SecondOperand);
+      Instruction *NewCall = II->clone();
+      NewCall->setOperand(SecondOperandArg, InverseSecondOp);
+      NewCall->insertAfter(II);
+      return replaceInstUsesWith(*II, NewCall);
+    }
+    // Case 2: the result is smaller than the negated operand, so negate the
+    // result of the multiply instead.
+    if (ElementCount::isKnownGT(FNegType->getElementCount(),
+                                RetType->getElementCount())) {
+      replaceInstUsesWith(*FNegOp, OpNotNeg);
+      // Insert the new fneg after the call instruction. replaceInstUsesWith
+      // also rewrites the fneg's own operand, so restore it to the call.
+      Builder.SetInsertPoint(II->getNextNode());
+      Instruction *FNegInst = cast<Instruction>(Builder.CreateFNeg(II));
+      replaceInstUsesWith(*II, FNegInst);
+      FNegInst->setOperand(0, II);
+      return II;
+    }
+    break;
+  }
   case Intrinsic::fmuladd: {
     // Canonicalize fast fmuladd to the separate fmul + fadd.
     if (II->isFast()) {
@@ -1852,6 +1913,7 @@
     [[fallthrough]];
   }
+
   case Intrinsic::fma: {
     // fma fneg(x), fneg(y), z -> fma x, y, z
     Value *Src0 = II->getArgOperand(0);
diff --git a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
new file
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
@@ -0,0 +1,209 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+; The result has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_move_to_result(
+; CHECK-NEXT:    [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %a.neg = fneg <6 x double> %a
+  %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
+; The result has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+; The fast flag should be preserved.
+define <2 x double> @test_negation_move_to_result_with_fastflags(<6 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_move_to_result_with_fastflags(
+; CHECK-NEXT:    [[RES:%.*]] = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %a.neg = fneg <6 x double> %a
+  %res = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
+; %b has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_move_negation_to_second_operand(
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    ret <9 x double> [[TMP2]]
+;
+  %a.neg = fneg <27 x double> %a
+  %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> %a.neg, <3 x double> %b, i32 9, i32 3, i32 1)
+  ret <9 x double> %res
+}
+
+; %b has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+; The fast flag should be preserved.
+define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_move_negation_to_second_operand_with_fast_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    ret <9 x double> [[TMP2]]
+;
+  %a.neg = fneg <27 x double> %a
+  %res = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> %a.neg, <3 x double> %b, i32 9, i32 3, i32 1)
+  ret <9 x double> %res
+}
+
+; The result has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x double> %a, <6 x double> %b){
+; CHECK-LABEL: @test_negation_move_to_result_from_second_operand(
+; CHECK-NEXT:    [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B:%.*]], i32 1, i32 3, i32 2)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <2 x double> [[RES]]
+; CHECK-NEXT:    ret <2 x double> [[TMP1]]
+;
+  %b.neg = fneg <6 x double> %b
+  %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> %a, <6 x double> %b.neg, i32 1, i32 3, i32 2)
+  ret <2 x double> %res
+}
+
+; %a has the fewest vector elements among the result and the two operands, so the negation can be moved there.
+define <9 x double> @test_move_negation_to_first_operand(<3 x double> %a, <27 x double> %b) {
+; CHECK-LABEL: @test_move_negation_to_first_operand(
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[TMP1]], <27 x double> [[B:%.*]], i32 1, i32 3, i32 9)
+; CHECK-NEXT:    ret <9 x double> [[TMP2]]
+;
+  %b.neg = fneg <27 x double> %b
+  %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> %a, <27 x double> %b.neg, i32 1, i32 3, i32 9)
+  ret <9 x double> %res
+}
+
+; %a already has the fewest vector elements among the result and the two operands, so the negation is not moved.
+define <15 x double> @test_negation_not_moved(<3 x double> %a, <5 x double> %b) {
+; CHECK-LABEL: @test_negation_not_moved(
+; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <3 x double> [[A:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[A_NEG]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5)
+; CHECK-NEXT:    ret <15 x double> [[RES]]
+;
+  %a.neg = fneg <3 x double> %a
+  %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a.neg, <5 x double> %b, i32 3, i32 1, i32 5)
+  ret <15 x double> %res
+}
+
+; %b already has the fewest vector elements among the result and the two operands, so the negation is not moved.
+define <15 x double> @test_negation_not_moved_second_operand(<5 x double> %a, <3 x double> %b) {
+; CHECK-LABEL: @test_negation_not_moved_second_operand(
+; CHECK-NEXT:    [[B_NEG:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v5f64.v3f64(<5 x double> [[A:%.*]], <3 x double> [[B_NEG]], i32 5, i32 1, i32 3)
+; CHECK-NEXT:    ret <15 x double> [[RES]]
+;
+  %b.neg = fneg <3 x double> %b
+  %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v5f64.v3f64(<5 x double> %a, <3 x double> %b.neg, i32 5, i32 1, i32 3)
+  ret <15 x double> %res
+}
+
+; The negation could profitably be moved to operand %a, which has the smallest
+; vector element count, but the fold only sinks negated operands, so the fneg
+; on the result is left in place.
+define <15 x double> @test_negation_on_result(<3 x double> %a, <5 x double> %b) {
+; CHECK-LABEL: @test_negation_on_result(
+; CHECK-NEXT:    [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[A:%.*]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5)
+; CHECK-NEXT:    [[RES_2:%.*]] = fneg <15 x double> [[RES]]
+; CHECK-NEXT:    ret <15 x double> [[RES_2]]
+;
+  %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a, <5 x double> %b, i32 3, i32 1, i32 5)
+  %res.2 = fneg <15 x double> %res
+  ret <15 x double> %res.2
+}
+
+; Both negations can be deleted.
+define <2 x double> @test_with_two_operands_negated1(<6 x double> %a, <3 x double> %b){
+; CHECK-LABEL: @test_with_two_operands_negated1(
+; CHECK-NEXT:    [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
+; CHECK-NEXT:    ret <2 x double> [[RES]]
+;
+  %a.neg = fneg <6 x double> %a
+  %b.neg = fneg <3 x double> %b
+  %res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b.neg, i32 2, i32 3, i32 1)
+  ret <2 x double> %res
+}
+
+; Both negations are moved onto %b, and the resulting double negation -(-b) folds to b.
+define <9 x double> @test_with_two_operands_negated2(<27 x double> %a, <3 x double> %b){
+; CHECK-LABEL: @test_with_two_operands_negated2(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    ret <9 x double> [[TMP1]]
+;
+  %a.neg = fneg <27 x double> %a
+  %b.neg = fneg <3 x double> %b
+  %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> %a.neg, <3 x double> %b.neg, i32 9, i32 3, i32 1)
+  ret <9 x double> %res
+}
+
+; The fneg has more than one use, so it is not moved.
+define <12 x double> @fneg_with_multiple_uses(<15 x double> %a, <20 x double> %b){
+; CHECK-LABEL: @fneg_with_multiple_uses(
+; CHECK-NEXT:    [[A_NEG:%.*]] = fneg <15 x double> [[A:%.*]]
+; CHECK-NEXT:    [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A_NEG]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
+; CHECK-NEXT:    [[RES_2:%.*]] = shufflevector <15 x double> [[A_NEG]], <15 x double> undef, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+; CHECK-NEXT:    [[RES_3:%.*]] = fadd <12 x double> [[RES_2]], [[RES]]
+; CHECK-NEXT:    ret <12 x double> [[RES_3]]
+;
+  %a.neg = fneg <15 x double> %a
+  %res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)
+  %res.2 = shufflevector <15 x double> %a.neg, <15 x double> undef, <12 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11>
+  %res.3 = fadd <12 x double> %res.2, %res
+  ret <12 x double> %res.3
+}
+
+; The negation should be moved to the second operand, %b, given it has the smallest element count.
+define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double> %b, <8 x double> %c) {
+; CHECK-LABEL: @chain_of_matrix_mutliplies(
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
+; CHECK-NEXT:    [[RES_2:%.*]] = tail call <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double> [[TMP2]], <8 x double> [[C:%.*]], i32 9, i32 1, i32 8)
+; CHECK-NEXT:    ret <72 x double> [[RES_2]]
+;
+  %a.neg = fneg <27 x double> %a
+  %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> %a.neg, <3 x double> %b, i32 9, i32 3, i32 1)
+  %res.2 = tail call <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double> %res, <8 x double> %c, i32 9, i32 1, i32 8)
+  ret <72 x double> %res.2
+}
+
+; The first negation should be moved to %a.
+; The second negation should be moved to the result of the second multiplication.
+define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double> %a, <5 x double> %b, <10 x double> %c) {
+; CHECK-LABEL: @chain_of_matrix_mutliplies_with_two_negations(
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[TMP1]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5)
+; CHECK-NEXT:    [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[TMP2]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2)
+; CHECK-NEXT:    [[TMP3:%.*]] = fneg <6 x double> [[RES_2]]
+; CHECK-NEXT:    ret <6 x double> [[TMP3]]
+;
+  %b.neg = fneg <5 x double> %b
+  %res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a, <5 x double> %b.neg, i32 3, i32 1, i32 5)
+  %res.neg = fneg <15 x double> %res
+  %res.2 = tail call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> %res.neg, <10 x double> %c, i32 3, i32 5, i32 2)
+  ret <6 x double> %res.2
+}
+
+; The negation should be propagated to the result of the second matrix multiplication.
+define <6 x double> @chain_of_matrix_mutliplies_propagation(<15 x double> %a, <20 x double> %b, <8 x double> %c){
+; CHECK-LABEL: @chain_of_matrix_mutliplies_propagation(
+; CHECK-NEXT:    [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A:%.*]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
+; CHECK-NEXT:    [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[RES]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2)
+; CHECK-NEXT:    [[TMP1:%.*]] = fneg <6 x double> [[RES_2]]
+; CHECK-NEXT:    ret <6 x double> [[TMP1]]
+;
+  %a.neg = fneg <15 x double> %a
+  %res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)
+  %res.2 = tail call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> %res, <8 x double> %c, i32 3, i32 4, i32 2)
+  ret <6 x double> %res.2
+}
+
+declare <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double>, <3 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double>, <6 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double>, <3 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double>, <27 x double>, i32 immarg, i32 immarg, i32 immarg)
+declare <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double>, <5 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <15 x double> @llvm.matrix.multiply.v15f64.v5f64.v3f64(<5 x double>, <3 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double>, <8 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double>, <20 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <21 x double> @llvm.matrix.multiply.v21f64.v15f64.v35f64(<15 x double>, <35 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double>, <10 x double>, i32 immarg, i32 immarg, i32 immarg) #1
+declare <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double>, <8 x double>, i32 immarg, i32 immarg, i32 immarg) #1
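+
+; A minimal illustrative sketch of why the combine is profitable (the function
+; name and shapes below are hypothetical, not part of the checked tests, and
+; carry no CHECK lines). fneg lowers to one negation per vector element, so
+; sinking it to the value with the fewest elements saves work: here %a.neg
+; costs 8 element negations, while negating %b would cost only 2. Case 1 of
+; the combine therefore rewrites this to
+;   %b.neg = fneg <2 x double> %b
+;   %res = tail call <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v2f64(<8 x double> %a, <2 x double> %b.neg, i32 4, i32 2, i32 1)
+define <4 x double> @negation_sink_sketch(<8 x double> %a, <2 x double> %b) {
+  %a.neg = fneg <8 x double> %a
+  %res = tail call <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v2f64(<8 x double> %a.neg, <2 x double> %b, i32 4, i32 2, i32 1)
+  ret <4 x double> %res
+}
+declare <4 x double> @llvm.matrix.multiply.v4f64.v8f64.v2f64(<8 x double>, <2 x double>, i32 immarg, i32 immarg, i32 immarg)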