diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -748,6 +748,10 @@
   ///
   /// If the multiplication is known not to overflow then NoSignedWrap is set.
   Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap);
+
+  /// Try to match a complex intrinsic that produces the given real/imaginary
+  /// pair. Returns whether or not it was successful.
+  bool createComplexMathInstruction(Value *Real, Value *Imag);
 };
 
 class Negator final {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1405,6 +1405,33 @@
       eraseInstFromFunction(*PrevSI);
       return nullptr;
     }
+
+    // Is this potentially a complex instruction? The two stores must write
+    // through GEPs off the same base pointer that agree on every index except
+    // the last one, where the previous store writes field 0 (the real part)
+    // and this store writes field 1 (the imaginary part).
+    auto *OurGEP = dyn_cast<GetElementPtrInst>(Ptr);
+    auto *TheirGEP = dyn_cast<GetElementPtrInst>(PrevSI->getOperand(1));
+    if (PrevSI->isUnordered() && OurGEP && TheirGEP &&
+        OurGEP->getOperand(0) == TheirGEP->getOperand(0) &&
+        OurGEP->getNumIndices() == TheirGEP->getNumIndices() &&
+        OurGEP->getType() == TheirGEP->getType()) {
+      bool AllMatch = true;
+      unsigned LastIndex = OurGEP->getNumIndices();
+      for (unsigned Index = 1; Index < LastIndex; Index++) {
+        if (OurGEP->getOperand(Index) != TheirGEP->getOperand(Index)) {
+          AllMatch = false;
+          break;
+        }
+      }
+      if (!AllMatch)
+        break;
+      if (match(OurGEP->getOperand(LastIndex), m_ConstantInt<1>()) &&
+          match(TheirGEP->getOperand(LastIndex), m_ConstantInt<0>())) {
+        // Insert before the earlier store so both stored values dominate the
+        // materialized intrinsic call.
+        IRBuilderBase::InsertPointGuard Guard(Builder);
+        Builder.SetInsertPoint(PrevSI);
+        if (createComplexMathInstruction(PrevSI->getOperand(0), Val))
+          return &SI;
+      }
+    }
+    break;
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1126,6 +1126,21 @@
   if (Instruction *NewI = foldAggregateConstructionIntoAggregateReuse(I))
     return NewI;
 
+  // Check if this is potentially a complex instruction that has been manually
+  // expanded: a two-field homogeneous floating-point struct built by a pair
+  // of insertvalues.
+  ArrayRef<Type *> Fields = I.getType()->subtypes();
+  if (Fields.size() == 2 && Fields[0] == Fields[1] &&
+      Fields[0]->isFloatingPointTy()) {
+    Value *RealV, *ImgV;
+    if (match(&I, m_InsertValue<1>(m_InsertValue<0>(m_Value(), m_Value(RealV)),
+                                   m_Value(ImgV)))) {
+      IRBuilderBase::InsertPointGuard Guard(Builder);
+      Builder.SetInsertPoint(cast<Instruction>(I.getOperand(0)));
+      if (createComplexMathInstruction(RealV, ImgV))
+        return &I;
+    }
+  }
+
   return nullptr;
 }
@@ -1611,6 +1626,17 @@
   if (Instruction *Ext = narrowInsElt(IE, Builder))
     return Ext;
 
+  // Check for a potential computation of a complex instruction: a fixed
+  // two-element vector built by inserting at lane 0 (real) then lane 1 (imag).
+  ElementCount Count = IE.getType()->getElementCount();
+  Value *RealV, *ImagV;
+  if (!Count.isScalable() && Count.getFixedValue() == 2 &&
+      match(&IE, m_InsertElt(
+                     m_InsertElt(m_Value(), m_Value(RealV), m_ConstantInt<0>()),
+                     m_Value(ImagV), m_ConstantInt<1>()))) {
+    if (createComplexMathInstruction(RealV, ImagV))
+      return &IE;
+  }
+
   return nullptr;
 }
@@ -2815,3 +2841,120 @@
 
   return MadeChange ? &SVI : nullptr;
 }
+
+static cl::opt<bool> InstCombineComplex(
+    "inst-combine-complex",
+    cl::desc("Enable pattern match to llvm.complex.* intrinsics"));
+
+bool InstCombinerImpl::createComplexMathInstruction(Value *Real, Value *Imag) {
+  if (!InstCombineComplex)
+    return false;
+
+  Instruction *RealI = dyn_cast<Instruction>(Real);
+  Instruction *ImagI = dyn_cast<Instruction>(Imag);
+  if (!RealI || !ImagI)
+    return false;
+
+  // Don't try to handle vector instructions for now.
+  if (RealI->getType()->isVectorTy())
+    return false;
+
+  Value *Op0R, *Op0I, *Op1R, *Op1I, *Scale, *Numerator;
+  // Compute the intersection of all the fast math flags of the entire tree up
+  // to the point that the input complex numbers are specified. The matched
+  // patterns guarantee every interior node is an FP instruction whose leaves
+  // are exactly Op0R/Op0I/Op1R/Op1I, so the cast below cannot fail.
+  auto computeFMF = [&]() {
+    SmallVector<Instruction *> Worklist = {RealI, ImagI};
+    FastMathFlags Flags;
+    Flags.set();
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.back();
+      Worklist.pop_back();
+      Flags &= I->getFastMathFlags();
+      for (Use &U : I->operands()) {
+        Value *V = U.get();
+        if (V == Op0R || V == Op0I || V == Op1R || V == Op1I)
+          continue;
+        Worklist.push_back(cast<Instruction>(V));
+      }
+    }
+    return Flags;
+  };
+
+  Intrinsic::ID NewIntrinsic = Intrinsic::not_intrinsic;
+  // Check for complex multiply:
+  // real = op0.real * op1.real - op0.imag * op1.imag
+  // imag = op0.real * op1.imag + op0.imag * op1.real
+  if (match(Real, m_FSub(m_OneUse(m_FMul(m_Value(Op0R), m_Value(Op1R))),
+                         m_OneUse(m_FMul(m_Value(Op0I), m_Value(Op1I)))))) {
+    if (match(
+            Imag,
+            m_c_FAdd(m_OneUse(m_c_FMul(m_Specific(Op0R), m_Specific(Op1I))),
+                     m_OneUse(m_c_FMul(m_Specific(Op1R), m_Specific(Op0I)))))) {
+      NewIntrinsic = Intrinsic::experimental_complex_fmul;
+    }
+  }
+  // Check for complex div:
+  // real = (op0.real * op1.real + op0.imag * op1.imag) / scale
+  // imag = (op0.imag * op1.real - op0.real * op1.imag) / scale
+  // where scale = op1.real * op1.real + op1.imag * op1.imag
+  else if (match(Imag, m_FDiv(m_Value(Numerator), m_Value(Scale)))) {
+    if (match(Scale,
+              m_FAdd(m_OneUse(m_FMul(m_Value(Op1R), m_Deferred(Op1R))),
+                     m_OneUse(m_FMul(m_Value(Op1I), m_Deferred(Op1I)))))) {
+      // The matching of Op1R and Op1I is temporary; the scale expression is
+      // symmetric in op1.real and op1.imag, so we may need to reverse the
+      // assignments.
+      auto checkNumerator = [&]() {
+        return match(Numerator,
+                     m_OneUse(m_FSub(
+                         m_OneUse(m_c_FMul(m_Value(Op0I), m_Specific(Op1R))),
+                         m_OneUse(m_c_FMul(m_Value(Op0R), m_Specific(Op1I))))));
+      };
+      bool ImagMatches = checkNumerator();
+      if (!ImagMatches) {
+        std::swap(Op1R, Op1I);
+        ImagMatches = checkNumerator();
+      }
+      if (ImagMatches &&
+          match(Real,
+                m_FDiv(m_OneUse(m_c_FAdd(m_OneUse(m_c_FMul(m_Specific(Op0R),
+                                                           m_Specific(Op1R))),
+                                         m_OneUse(m_c_FMul(m_Specific(Op0I),
+                                                           m_Specific(Op1I))))),
+                       m_Specific(Scale)))) {
+        NewIntrinsic = Intrinsic::experimental_complex_fdiv;
+      }
+    }
+  }
+
+  // Make sure we matched an intrinsic.
+  if (NewIntrinsic == Intrinsic::not_intrinsic)
+    return false;
+
+  // Use the computation tree to capture all of the fast-math flags.
+  IRBuilderBase::FastMathFlagGuard FMFGuard(Builder);
+  Builder.setFastMathFlags(computeFMF());
+
+  Value *Op0 = Builder.CreateComplexValue(Op0R, Op0I);
+  Value *Op1 = Builder.CreateComplexValue(Op1R, Op1I);
+
+  // Create new intrinsics. From our pattern matching of only the direct
+  // arithmetic formulas, we have to create them with complex-range=limited.
+  Value *Result;
+  switch (NewIntrinsic) {
+  case Intrinsic::experimental_complex_fmul:
+    Result = Builder.CreateComplexMul(Op0, Op1, true);
+    break;
+  case Intrinsic::experimental_complex_fdiv:
+    Result = Builder.CreateComplexDiv(Op0, Op1, true);
+    break;
+  default:
+    llvm_unreachable("Unexpected complex intrinsic");
+  }
+
+  replaceInstUsesWith(*RealI,
+                      Builder.CreateExtractElement(Result, uint64_t(0)));
+  replaceInstUsesWith(*ImagI,
+                      Builder.CreateExtractElement(Result, uint64_t(1)));
+
+  return true;
+}