Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -164,7 +164,8 @@
 
     Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
 
-    Value *performFactorization(Instruction *I);
+    Value *performFactorization(Value *Op0, Value *Op1, unsigned Opcode, FastMathFlags FMF);
+    Value *performFactorizationAssociative(Instruction *I);
 
     /// Convert given addend to a Value
     Value *createAddendVal(const FAddend &A, bool& NeedNeg);
@@ -429,6 +430,184 @@
   return BreakNum;
 }
 
+// Reassociate an unsafe-algebra FAdd/FSub so that two of its three addends can
+// be handed to performFactorization(). Returns the rewritten value, or nullptr
+// if I is not FAdd/FSub, does not have exactly one use, or nothing factorizes.
+// Each rewrite below is only legal because unsafeAlgebra is asserted on entry:
+// -> Transform (A op1 B) op2 C -> A op3 (B op4 C) if (B op4 C) factorizes
+//    Eg. (A + X * C1) + X * C2 -> A + X * (C1 + C2)
+// -> Transform A op1 (B op2 C) -> (A op3 B) op4 C if (A op3 B) factorizes
+//    Eg. X * C1 + (X * C2 - C) -> X * (C1 + C2) - C
+// -> Transform (A op1 B) op2 C -> (A op3 C) op4 B if (A op3 C) factorizes
+//    Eg. (X * C1 - B) + X * C2 -> X * (C1 + C2) - B
+// -> Transform A op1 (B op2 C) -> (A op3 C) op4 B if (A op3 C) factorizes
+//    Eg. X * C1 - (B + X * C2) -> X * (C1 - C2) - B
+Value *FAddCombine::performFactorizationAssociative(Instruction *I) {
+  assert(I->hasUnsafeAlgebra() && "This method can't be called without unsafe algebra");
+  if (I->getOpcode() != Instruction::FAdd && I->getOpcode() != Instruction::FSub)
+    return nullptr;
+  // Be overly conservative for now?
+  if (!I->hasOneUse())
+    return nullptr;
+
+  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I->getOperand(0));
+  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I->getOperand(1));
+  {
+    // (A op B) op C -> A op (B op C) Simplify?
+    bool Op0Factorizable = Op0 && (Op0->getOpcode() == Instruction::FAdd ||
+                                   Op0->getOpcode() == Instruction::FSub);
+    if (Op0Factorizable) {
+      unsigned FactorizeOpcode, FinalOpcode;
+      Value *OpFactor0, *OpFactor1;
+      bool IsOp0OpcodeAdd = (Op0->getOpcode() == Instruction::FAdd);
+      bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+      Value *A = Op0->getOperand(0);
+      Value *B = Op0->getOperand(1);
+      Value *C = I->getOperand(1);
+      // (a + b) + c -> a + ( b + c ) ?
+      // (a + b) - c -> a + ( b - c ) ?
+      if (IsOp0OpcodeAdd) {
+        OpFactor0 = B;
+        OpFactor1 = C;
+        FactorizeOpcode = I->getOpcode();
+        FinalOpcode = Instruction::FAdd;
+      }
+      // (a - b) + c -> a + (c - b) ?
+      else if (IsIOpcodeAdd) {
+        OpFactor0 = C;
+        OpFactor1 = B;
+        FactorizeOpcode = Instruction::FSub;
+        FinalOpcode = Instruction::FAdd;
+      }
+      // (a - b) - c -> a - (b + c) ?
+      else {
+        OpFactor0 = B;
+        OpFactor1 = C;
+        FactorizeOpcode = Instruction::FAdd;
+        FinalOpcode = Instruction::FSub;
+      }
+      if (Value *V = performFactorization(OpFactor0, OpFactor1, FactorizeOpcode, I->getFastMathFlags())) {
+        Value *NewV = (FinalOpcode == Instruction::FAdd) ?
+                      createFAdd(A, V) :
+                      createFSub(A, V);
+        FastMathFlags Flags;
+        Flags.setUnsafeAlgebra();
+        Flags &= I->getFastMathFlags();
+        // createFAdd/createFSub may fold to a constant, which carries no flags.
+        if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+          NewI->setFastMathFlags(Flags);
+        return NewV;
+      }
+    }
+  }
+  {
+    // Transform: "A op (B op C)" ==> "(A op B) op C" if "A op B" factorizes.
+    bool Op1Factorizable = Op1 && (Op1->getOpcode() == Instruction::FAdd ||
+                                   Op1->getOpcode() == Instruction::FSub);
+    if (Op1Factorizable) {
+      unsigned FactorizeOpcode, FinalOpcode;
+      bool IsOp1OpcodeAdd = (Op1->getOpcode() == Instruction::FAdd);
+      bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+      Value *A = I->getOperand(0);
+      Value *B = Op1->getOperand(0);
+      Value *C = Op1->getOperand(1);
+      // A + (B + C) -> (A + B) + C factorizes?
+      if (IsIOpcodeAdd && IsOp1OpcodeAdd) {
+        FactorizeOpcode = FinalOpcode = Instruction::FAdd;
+      }
+      // A + (B - C) -> (A + B) - C factorizes?
+      else if (IsIOpcodeAdd) {
+        FactorizeOpcode = Instruction::FAdd;
+        FinalOpcode = Instruction::FSub;
+      }
+      // A - (B + C) -> (A - B) - C factorizes?
+      else if (IsOp1OpcodeAdd) {
+        FactorizeOpcode = FinalOpcode = Instruction::FSub;
+      }
+      // A - (B - C) -> (A - B) + C factorizes?
+      else {
+        FactorizeOpcode = Instruction::FSub;
+        FinalOpcode = Instruction::FAdd;
+      }
+      if (Value *V = performFactorization(A, B, FactorizeOpcode, I->getFastMathFlags())) {
+        Value *NewV = (FinalOpcode == Instruction::FAdd) ?
+                      createFAdd(V, C) :
+                      createFSub(V, C);
+        FastMathFlags Flags;
+        Flags.setUnsafeAlgebra();
+        Flags &= I->getFastMathFlags();
+        // createFAdd/createFSub may fold to a constant, which carries no flags.
+        if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+          NewI->setFastMathFlags(Flags);
+        return NewV;
+      }
+    }
+  }
+  {
+    // We know that op1 and op2 can only be FAdd or FSub
+    // (A op1 B) op2 C -> (A op2 C) op1 B factorizes?
+    bool Op0Valid = Op0 && (Op0->getOpcode() == Instruction::FAdd ||
+                            Op0->getOpcode() == Instruction::FSub);
+    if (Op0Valid) {
+      Value *A = Op0->getOperand(0);
+      Value *B = Op0->getOperand(1);
+      Value *C = I->getOperand(1);
+      if (Value *V = performFactorization(A, C, I->getOpcode(), I->getFastMathFlags())) {
+        Value *NewV = (Op0->getOpcode() == Instruction::FAdd) ?
+                      createFAdd(V, B) :
+                      createFSub(V, B);
+        FastMathFlags Flags;
+        Flags.setUnsafeAlgebra();
+        Flags &= I->getFastMathFlags();
+        // createFAdd/createFSub may fold to a constant, which carries no flags.
+        if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+          NewI->setFastMathFlags(Flags);
+        return NewV;
+      }
+    }
+  }
+  {
+    // A op1 (B op2 C) -> (A op2 C) op1 B factorizes?
+    bool Op1Valid = Op1 && (Op1->getOpcode() == Instruction::FAdd ||
+                            Op1->getOpcode() == Instruction::FSub);
+    if (Op1Valid) {
+      Value *A = I->getOperand(0);
+      Value *B = Op1->getOperand(0);
+      Value *C = Op1->getOperand(1);
+      unsigned FactorizeOpcode, FinalOpcode;
+      bool IsOp1OpcodeAdd = (Op1->getOpcode() == Instruction::FAdd);
+      bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+      // A + (B + C) -> (A + C) + B simplifies?
+      // A + (B - C) -> (A - C) + B simplifies?
+      if (IsIOpcodeAdd) {
+        FactorizeOpcode = Op1->getOpcode();
+        FinalOpcode = Instruction::FAdd;
+      }
+      // A - (B + C) -> (A - C) - B
+      else if (IsOp1OpcodeAdd) {
+        FactorizeOpcode = FinalOpcode = Instruction::FSub;
+      } else {
+        // A - (B - C) -> (A + C) - B
+        FactorizeOpcode = Instruction::FAdd;
+        FinalOpcode = Instruction::FSub;
+      }
+      if (Value *V = performFactorization(A, C, FactorizeOpcode, I->getFastMathFlags())) {
+        Value *NewV = (FinalOpcode == Instruction::FAdd) ?
+                      createFAdd(V, B) :
+                      createFSub(V, B);
+        FastMathFlags Flags;
+        Flags.setUnsafeAlgebra();
+        Flags &= I->getFastMathFlags();
+        // createFAdd/createFSub may fold to a constant, which carries no flags.
+        if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+          NewI->setFastMathFlags(Flags);
+        return NewV;
+      }
+    }
+  }
+  return nullptr;
+}
+
 // Try to perform following optimization on the input instruction I. Return the
 // simplified expression if was successful; otherwise, return 0.
 //
@@ -437,12 +616,13 @@
 //   (x * y) +/- (x * z)               x * (y +/- z)
 //   (y / x) +/- (z / x)               (y +/- z) / x
 //
-Value *FAddCombine::performFactorization(Instruction *I) {
-  assert((I->getOpcode() == Instruction::FAdd ||
-          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
+Value *FAddCombine::performFactorization(Value *Op0, Value *Op1, unsigned Opcode,
+                                         FastMathFlags FMF) {
+  assert((Opcode == Instruction::FAdd ||
+          Opcode == Instruction::FSub) && "Expect add/sub");
 
-  Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
-  Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
+  Instruction *I0 = dyn_cast<Instruction>(Op0);
+  Instruction *I1 = dyn_cast<Instruction>(Op1);
 
   if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
     return nullptr;
@@ -487,11 +667,11 @@
 
   FastMathFlags Flags;
   Flags.setUnsafeAlgebra();
-  if (I0) Flags &= I->getFastMathFlags();
-  if (I1) Flags &= I->getFastMathFlags();
+  if (I0) Flags &= FMF;
+  if (I1) Flags &= FMF;
 
   // Create expression "NewAddSub = AddSub0 +/- AddsSub1"
-  Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
+  Value *NewAddSub = (Opcode == Instruction::FAdd) ?
                       createFAdd(AddSub0, AddSub1) :
                       createFSub(AddSub0, AddSub1);
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
@@ -598,7 +778,12 @@
   }
 
   // step 6: Try factorization as the last resort,
-  return performFactorization(I);
+  if (Value *V = performFactorization(I->getOperand(0),
+                                      I->getOperand(1),
+                                      I->getOpcode(),
+                                      I->getFastMathFlags()))
+    return V;
+  return performFactorizationAssociative(I);
 }
 
 Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {