diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -63,6 +63,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/TypeSize.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
@@ -70,6 +71,7 @@
 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -3224,7 +3226,76 @@
     OperandBundleDef NewBundle("gc-live", NewLiveGc);
     return CallBase::Create(&Call, NewBundle);
   }
+  case Intrinsic::matrix_multiply: {
+    // Optimise negation in matrix multiplication.
+    // If the negated operand has more elements than the second operand or
+    // the result, we can reduce the number of negated elements by moving
+    // the negation to the smallest operand in the equation. This covers
+    // two cases:
+    // Case 1: the second operand has the smallest element count, i.e.
+    //   A.neg = -A
+    //   A.neg * B = C
+    // optimises to:
+    //   B.neg = -B
+    //   A * B.neg = C
+    //
+    // Case 2: the result has the smallest element count, i.e.
+    //   A.neg = -A
+    //   A.neg * B = C
+    // optimises to:
+    //   A * B = C
+    //   C.neg = -C
+    Value *X;
+    Value *Op0 = Call.getArgOperand(0);
+    Value *Op1 = Call.getArgOperand(1);
+    VectorType *RetType = dyn_cast<VectorType>(Call.getType());
+    Instruction *FNegOp;
+    Value *SecondOperand;
+    unsigned SecondOperandArg;
+    // Only sink a negation that has no other users; otherwise those users
+    // would lose the negation when the fneg is replaced below.
+    if (match(Op0, m_OneUse(m_FNeg(m_Value(X))))) {
+      FNegOp = dyn_cast<Instruction>(Op0);
+      SecondOperand = Op1;
+      SecondOperandArg = 1;
+    } else if (match(Op1, m_OneUse(m_FNeg(m_Value(X))))) {
+      FNegOp = dyn_cast<Instruction>(Op1);
+      SecondOperand = Op0;
+      SecondOperandArg = 0;
+    } else {
+      break;
+    }
+    Value *OpNotNeg = FNegOp->getOperand(0);
+    VectorType *FNegType = dyn_cast<VectorType>(FNegOp->getType());
+    VectorType *SecondOperandType =
+        dyn_cast<VectorType>(SecondOperand->getType());
+    if (ElementCount::isKnownGT(FNegType->getElementCount(),
+                                SecondOperandType->getElementCount()) &&
+        ElementCount::isKnownLT(SecondOperandType->getElementCount(),
+                                RetType->getElementCount())) {
+      // Case 1: drop the negation from the larger operand and negate the
+      // smaller second operand instead.
+      replaceInstUsesWith(*FNegOp, OpNotNeg);
+      replaceOperand(Call, SecondOperandArg,
+                     Builder.CreateFNeg(SecondOperand));
+    } else if (ElementCount::isKnownGT(FNegType->getElementCount(),
+                                       RetType->getElementCount())) {
+      // Case 2: drop the negation from the operand and negate the result.
+      replaceInstUsesWith(*FNegOp, OpNotNeg);
+      // Insert the new fneg after the call so it consumes the result.
+      Builder.SetInsertPoint(Call.getNextNode());
+      Instruction *FNegInst = dyn_cast<Instruction>(Builder.CreateFNeg(&Call));
+      replaceInstUsesWith(Call, FNegInst);
+      // replaceInstUsesWith also rewrote the new fneg's own operand, so
+      // point it back at the call.
+      FNegInst->setOperand(0, &Call);
+    } else {
+      return nullptr;
+    }
+
+    return &Call;
+  }
   default: {
     break;
   }
   }
diff --git a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
--- a/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
+++ b/llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll
@@ -3,9 +3,9 @@
 
 define <3 x double> @matrix_multiply_v9f64_v3f64(<9 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @matrix_multiply_v9f64_v3f64(
-; CHECK-NEXT: [[A_NEG:%.*]] = fneg <9 x double> [[A:%.*]]
-; CHECK-NEXT: [[RES:%.*]] = tail call <3 x double> @llvm.matrix.multiply.v3f64.v9f64.v3f64(<9 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 3, i32 3, i32 1)
-; CHECK-NEXT: ret <3 x double> [[RES]]
+; CHECK-NEXT: [[RES:%.*]] = tail call <3 x double> @llvm.matrix.multiply.v3f64.v9f64.v3f64(<9 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 3, i32 3, i32 1)
+; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[RES]]
+; CHECK-NEXT: ret <3 x double> [[TMP1]]
 ;
   %a.neg = fneg <9 x double> %a
   %res = tail call <3 x double> @llvm.matrix.multiply.v3f64.v9f64.v3f64(<9 x double> %a.neg, <3 x double> %b, i32 3, i32 3, i32 1)
@@ -15,8 +15,8 @@
 
 define <9 x double> @matrix_multiply_v27f64_v3f64(<27 x double> %a, <3 x double> %b) {
 ; CHECK-LABEL: @matrix_multiply_v27f64_v3f64(
-; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
-; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
+; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
+; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
 ; CHECK-NEXT: ret <9 x double> [[RES]]
 ;
   %a.neg = fneg <27 x double> %a
@@ -27,9 +27,9 @@
 
 define <12 x double> @matrix_multiply_v15f64_v20f64(<15 x double> %a, <20 x double> %b) {
 ; CHECK-LABEL: @matrix_multiply_v15f64_v20f64(
-; CHECK-NEXT: [[A_NEG:%.*]] = fneg <15 x double> [[A:%.*]]
-; CHECK-NEXT: [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A_NEG]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
-; CHECK-NEXT: ret <12 x double> [[RES]]
+; CHECK-NEXT: [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A:%.*]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
+; CHECK-NEXT: [[TMP1:%.*]] = fneg <12 x double> [[RES]]
+; CHECK-NEXT: ret <12 x double> [[TMP1]]
 ;
   %a.neg = fneg <15 x double> %a
   %res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)
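
For context, a minimal sketch of an input the combine intentionally leaves alone (not part of the patch; the function name is hypothetical, though the intrinsic signature matches the tests above). Here the negated operand %b is already the smallest vector in the expression (3 elements, versus 27 for %a and 9 for the result), so neither Case 1 nor Case 2 applies and the code takes the `return nullptr` path:

define <9 x double> @matrix_multiply_neg_smallest_operand(<27 x double> %a, <3 x double> %b) {
  ; %b.neg has the fewest elements of %a, %b, and the result, so moving
  ; the negation cannot reduce the number of negated lanes.
  %b.neg = fneg <3 x double> %b
  %res = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> %a, <3 x double> %b.neg, i32 9, i32 3, i32 1)
  ret <9 x double> %res
}

declare <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double>, <3 x double>, i32, i32, i32)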