Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -164,7 +164,9 @@
 
     Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
 
-    Value *performFactorization(Instruction *I);
+    Value *performFactorization(Value *Op0, Value *Op1, unsigned Opcode,
+                                FastMathFlags FMF);
+    Value *performFactorizationAssociative(Instruction *I);
 
     /// Convert given addend to a Value
     Value *createAddendVal(const FAddend &A, bool& NeedNeg);
@@ -429,6 +431,168 @@
   return BreakNum;
 }
 
+// For an FAdd/FSub I, reassociate so that two factorizable terms are adjacent:
+// -> Transform (A op1 B) op2 C -> A op3 (B op4 C) if (B op4 C) factorizes
+//    Eg. (A + X * C1) + X * C2 -> A + X * (C1 + C2)
+// -> Transform A op1 (B op2 C) -> (A op3 B) op4 C if (A op3 B) factorizes
+//    Eg. X * C1 + (X * C2 - B) -> X * (C1 + C2) - B
+// -> Transform (A op1 B) op2 C -> (A op3 C) op4 B if (A op3 C) factorizes
+//    Eg. (X * C1 - B) + X * C2 -> X * (C1 + C2) - B
+// -> Transform A op1 (B op2 C) -> (A op3 C) op4 B if (A op3 C) factorizes
+//    Eg. X * C1 - (B + X * C2) -> X * (C1 - C2) - B
+// This method should only be called when unsafeAlgebra is set for the inst
+Value *FAddCombine::performFactorizationAssociative(Instruction *I) {
+  assert(I->hasUnsafeAlgebra() &&
+         "This method can't be called without unsafe algebra");
+  if (I->getOpcode() != Instruction::FAdd &&
+      I->getOpcode() != Instruction::FSub)
+    return nullptr;
+  // TODO: Overly conservative?
+  if (!I->hasOneUse())
+    return nullptr;
+
+  BinaryOperator *Op0 = dyn_cast<BinaryOperator>(I->getOperand(0));
+  BinaryOperator *Op1 = dyn_cast<BinaryOperator>(I->getOperand(1));
+  // Transform: "(A op B) op C" ==> "A op (B op C)" if B and C factorize.
+  if (Op0 && (Op0->getOpcode() == Instruction::FAdd ||
+              Op0->getOpcode() == Instruction::FSub)) {
+    unsigned FactorizeOpcode, FinalOpcode;
+    Value *OpFactor0, *OpFactor1;
+    bool IsOp0OpcodeAdd = (Op0->getOpcode() == Instruction::FAdd);
+    bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+    Value *A = Op0->getOperand(0);
+    Value *B = Op0->getOperand(1);
+    Value *C = I->getOperand(1);
+    // (a + b) +/- c -> a + (b +/- c) ?
+    if (IsOp0OpcodeAdd) {
+      OpFactor0 = B;
+      OpFactor1 = C;
+      FactorizeOpcode = I->getOpcode();
+      FinalOpcode = Instruction::FAdd;
+    }
+    // (a - b) + c -> a + (c - b) ?
+    else if (IsIOpcodeAdd) {
+      OpFactor0 = C;
+      OpFactor1 = B;
+      FactorizeOpcode = Instruction::FSub;
+      FinalOpcode = Instruction::FAdd;
+    }
+    // (a - b) - c -> a - (b + c) ?
+    else {
+      OpFactor0 = B;
+      OpFactor1 = C;
+      FactorizeOpcode = Instruction::FAdd;
+      FinalOpcode = Instruction::FSub;
+    }
+    if (Value *V = performFactorization(OpFactor0, OpFactor1, FactorizeOpcode,
+                                       I->getFastMathFlags())) {
+      Value *NewV = (FinalOpcode == Instruction::FAdd) ? createFAdd(A, V)
+                                                        : createFSub(A, V);
+      FastMathFlags Flags;
+      Flags.setUnsafeAlgebra();
+      Flags &= I->getFastMathFlags();
+      // createFAdd/createFSub may constant-fold; only flag real instructions.
+      if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+        NewI->setFastMathFlags(Flags);
+      return NewV;
+    }
+  }
+  // Transform: "A op (B op C)" ==> "(A op B) op C" if "A op B" factorizes.
+  if (Op1 && (Op1->getOpcode() == Instruction::FAdd ||
+              Op1->getOpcode() == Instruction::FSub)) {
+    unsigned FactorizeOpcode, FinalOpcode;
+    bool IsOp1OpcodeAdd = (Op1->getOpcode() == Instruction::FAdd);
+    bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+    Value *A = I->getOperand(0);
+    Value *B = Op1->getOperand(0);
+    Value *C = Op1->getOperand(1);
+    // A + (B + C) -> (A + B) + C factorizes?
+    if (IsIOpcodeAdd && IsOp1OpcodeAdd) {
+      FactorizeOpcode = FinalOpcode = Instruction::FAdd;
+    }
+    // A + (B - C) -> (A + B) - C factorizes?
+    else if (IsIOpcodeAdd) {
+      FactorizeOpcode = Instruction::FAdd;
+      FinalOpcode = Instruction::FSub;
+    }
+    // A - (B + C) -> (A - B) - C factorizes?
+    else if (IsOp1OpcodeAdd) {
+      FactorizeOpcode = FinalOpcode = Instruction::FSub;
+    }
+    // A - (B - C) -> (A - B) + C factorizes?
+    else {
+      FactorizeOpcode = Instruction::FSub;
+      FinalOpcode = Instruction::FAdd;
+    }
+    if (Value *V = performFactorization(A, B, FactorizeOpcode,
+                                       I->getFastMathFlags())) {
+      Value *NewV = (FinalOpcode == Instruction::FAdd) ? createFAdd(V, C)
+                                                        : createFSub(V, C);
+      FastMathFlags Flags;
+      Flags.setUnsafeAlgebra();
+      Flags &= I->getFastMathFlags();
+      if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+        NewI->setFastMathFlags(Flags);
+      return NewV;
+    }
+  }
+  // We know that op1 and op2 can only be FAdd or FSub
+  // (A op1 B) op2 C -> (A op2 C) op1 B factorizes?
+  if (Op0 && (Op0->getOpcode() == Instruction::FAdd ||
+              Op0->getOpcode() == Instruction::FSub)) {
+    Value *A = Op0->getOperand(0);
+    Value *B = Op0->getOperand(1);
+    Value *C = I->getOperand(1);
+    if (Value *V =
+            performFactorization(A, C, I->getOpcode(), I->getFastMathFlags())) {
+      Value *NewV = (Op0->getOpcode() == Instruction::FAdd) ? createFAdd(V, B)
+                                                             : createFSub(V, B);
+      FastMathFlags Flags;
+      Flags.setUnsafeAlgebra();
+      Flags &= I->getFastMathFlags();
+      if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+        NewI->setFastMathFlags(Flags);
+      return NewV;
+    }
+  }
+  // A op1 (B op2 C) -> (A op2 C) op1 B factorizes?
+  if (Op1 && (Op1->getOpcode() == Instruction::FAdd ||
+              Op1->getOpcode() == Instruction::FSub)) {
+    Value *A = I->getOperand(0);
+    Value *B = Op1->getOperand(0);
+    Value *C = Op1->getOperand(1);
+    unsigned FactorizeOpcode, FinalOpcode;
+    bool IsOp1OpcodeAdd = (Op1->getOpcode() == Instruction::FAdd);
+    bool IsIOpcodeAdd = (I->getOpcode() == Instruction::FAdd);
+    // A + (B + C) -> (A + C) + B factorizes?
+    // A + (B - C) -> (A - C) + B factorizes?
+    if (IsIOpcodeAdd) {
+      FactorizeOpcode = Op1->getOpcode();
+      FinalOpcode = Instruction::FAdd;
+    }
+    // A - (B + C) -> (A - C) - B factorizes?
+    else if (IsOp1OpcodeAdd) {
+      FactorizeOpcode = FinalOpcode = Instruction::FSub;
+    } else {
+      // A - (B - C) -> (A + C) - B factorizes?
+      FactorizeOpcode = Instruction::FAdd;
+      FinalOpcode = Instruction::FSub;
+    }
+    if (Value *V = performFactorization(A, C, FactorizeOpcode,
+                                       I->getFastMathFlags())) {
+      Value *NewV = (FinalOpcode == Instruction::FAdd) ? createFAdd(V, B)
+                                                        : createFSub(V, B);
+      FastMathFlags Flags;
+      Flags.setUnsafeAlgebra();
+      Flags &= I->getFastMathFlags();
+      if (Instruction *NewI = dyn_cast<Instruction>(NewV))
+        NewI->setFastMathFlags(Flags);
+      return NewV;
+    }
+  }
+  return nullptr;
+}
+
 // Try to perform following optimization on the input instruction I. Return the
 // simplified expression if was successful; otherwise, return 0.
 //
@@ -437,12 +601,13 @@
 //   (x * y) +/- (x * z)               x * (y +/- z)
 //   (y / x) +/- (z / x)               (y +/- z) / x
 //
-Value *FAddCombine::performFactorization(Instruction *I) {
-  assert((I->getOpcode() == Instruction::FAdd ||
-          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
+Value *FAddCombine::performFactorization(Value *Op0, Value *Op1,
+                                         unsigned Opcode, FastMathFlags FMF) {
+  assert((Opcode == Instruction::FAdd || Opcode == Instruction::FSub) &&
+         "Expect add/sub");
 
-  Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
-  Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
+  Instruction *I0 = dyn_cast<Instruction>(Op0);
+  Instruction *I1 = dyn_cast<Instruction>(Op1);
 
   if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
     return nullptr;
@@ -487,13 +652,15 @@
 
   FastMathFlags Flags;
   Flags.setUnsafeAlgebra();
-  if (I0) Flags &= I->getFastMathFlags();
-  if (I1) Flags &= I->getFastMathFlags();
+  if (I0)
+    Flags &= FMF;
+  if (I1)
+    Flags &= FMF;
 
   // Create expression "NewAddSub = AddSub0 +/- AddsSub1"
-  Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
-                      createFAdd(AddSub0, AddSub1) :
-                      createFSub(AddSub0, AddSub1);
+  Value *NewAddSub = (Opcode == Instruction::FAdd)
+                         ? createFAdd(AddSub0, AddSub1)
+                         : createFSub(AddSub0, AddSub1);
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
     const APFloat &F = CFP->getValueAPF();
     if (!F.isNormal())
@@ -598,7 +765,11 @@
   }
 
   // step 6: Try factorization as the last resort,
-  return performFactorization(I);
+  if (Value *V = performFactorization(I->getOperand(0), I->getOperand(1),
+                                      I->getOpcode(), I->getFastMathFlags()))
+    return V;
+  // Otherwise, reassociate so that a factorizable pair becomes adjacent.
+  return performFactorizationAssociative(I);
 }
 
 Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
Index: test/Transforms/InstCombine/FAddFSubAssociativeFactorize.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/FAddFSubAssociativeFactorize.ll
@@ -0,0 +1,224 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; ****************************************
+; Test (a op1 b) op2 c -> a op3 (b op4 c)
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc1(
+; CHECK: [[SUM:%.*]] = fadd fast half %a, %b
+; CHECK: fmul fast half [[SUM]], 0xH4500
+; CHECK: ret
+define half @faddsubAssoc1(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (5 * a + 3 * b) + (2 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc2(
+; CHECK: %tmp2 = fmul fast half %a, 0xH4500
+; CHECK: fadd fast half %tmp2, %b
+; CHECK: ret
+define half @faddsubAssoc2(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (5 * a + 3 * b) - (2 * b)
+  ret half %tmp5 ; = 5 * a + b
+}
+
+; CHECK-LABEL: @faddsubAssoc3(
+; CHECK: %tmp2 = fmul fast half %a, 0xH4500
+; CHECK: fsub fast half %tmp2, %b
+; CHECK: ret
+define half @faddsubAssoc3(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (5 * a - 3 * b) + (2 * b)
+  ret half %tmp5 ; = 5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc4(
+; CHECK: [[DIFF:%.*]] = fsub fast half %a, %b
+; CHECK: fmul fast half [[DIFF]], 0xH4500
+; CHECK: ret
+define half @faddsubAssoc4(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (5 * a - 3 * b) - (2 * b)
+  ret half %tmp5 ; = 5 * a - 5 * b
+}
+
+; ****************************************
+; Test a op1 (b op2 c) -> (a op3 b) op4 c
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc5(
+; CHECK: [[SUM:%.*]] = fadd fast half %b, %a
+; CHECK: fmul fast half [[SUM]], 0xH4500
+define half @faddsubAssoc5(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + (3 * b + 5 * a)
+  ret half %tmp5 ; = 5 * (a + b)
+}
+
+; CHECK-LABEL: @faddsubAssoc6(
+; CHECK: fmul fast half %a, 0xHC500
+; CHECK: fsub
+; CHECK: ret
+define half @faddsubAssoc6(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - (3 * b + 5 * a)
+  ret half %tmp5 ; = -b - 5 * a
+}
+
+; CHECK-LABEL: @faddsubAssoc7(
+; CHECK: [[DIFF:%.*]] = fsub fast half %b, %a
+; CHECK: fmul fast half [[DIFF]], 0xH4500
+; CHECK: ret
+define half @faddsubAssoc7(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + (3 * b - 5 * a)
+  ret half %tmp5 ; = 5 * (b - a)
+}
+
+; CHECK-LABEL: @faddsubAssoc8(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc8(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - (3 * b - 5 * a)
+  ret half %tmp5 ; = -b + 5 * a
+}
+
+; ****************************************
+; Test (a op1 b) op2 c -> (a op3 c) op4 b
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc9(
+; CHECK: fadd fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc9(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (3 * b + 5 * a) + (2 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc10(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fadd fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc10(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (3 * b + 5 * a) - (2 * b)
+  ret half %tmp5 ; = b + 5 * a
+}
+
+; CHECK-LABEL: @faddsubAssoc11(
+; CHECK: fsub fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc11(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (3 * b - 5 * a) + (2 * b)
+  ret half %tmp5 ; = 5 * (b - a)
+}
+
+; CHECK-LABEL: @faddsubAssoc12(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half %b, {{.*}}
+; CHECK: ret
+define half @faddsubAssoc12(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (3 * b - 5 * a) - (2 * b)
+  ret half %tmp5 ; = b - 5 * a
+}
+
+; ****************************************
+; Test a op1 (b op2 c) -> (a op3 c) op4 b
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc13(
+; CHECK: fadd fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc13(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + ( 5 * a + 3 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc14(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc14(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + ( 5 * a - 3 * b)
+  ret half %tmp5 ; = 5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc15(
+; CHECK: fmul fast half %a, 0xHC500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc15(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - ( 5 * a + 3 * b)
+  ret half %tmp5 ; = -5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc16(
+; CHECK: fsub fast half %b, %a
+; CHECK: fmul fast half %{{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc16(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - ( 5 * a - 3 * b)
+  ret half %tmp5 ; = 5 * (b - a)
+}