Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -164,7 +164,9 @@
 
     Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
 
-    Value *performFactorization(Instruction *I);
+    Value *performFactorization(Value *Op0, Value *Op1, unsigned Opcode,
+                                FastMathFlags FMF);
+    Value *performFactorizationAssociative(Instruction *I);
 
     /// Convert given addend to a Value
     Value *createAddendVal(const FAddend &A, bool& NeedNeg);
@@ -429,6 +431,118 @@
   return BreakNum;
 }
 
+// Try to factorize "I", an unsafe-algebra FAdd/FSub with at least one operand
+// that is itself an FAdd/FSub, by regrouping its three terms so that two of
+// them share a common factor. Return the simplified expression if successful;
+// otherwise, return nullptr.
+//
+// With op1..op4 each being FAdd or FSub, these rewrites are tried in order:
+//   1. (A op1 B) op2 C -> A op3 (B op4 C)   if "B op4 C" factorizes
+//      E.g. (A + X * C1) + X * C2 -> A + X * (C1 + C2)
+//   2. A op1 (B op2 C) -> (A op1 B) op3 C   if "A op1 B" factorizes
+//      E.g. X * C1 - (X * C2 + C) -> X * (C1 - C2) - C
+//   3. (A op1 B) op2 C -> (A op2 C) op1 B   if "A op2 C" factorizes
+//      E.g. (X * C1 - B) + X * C2 -> X * (C1 + C2) - B
+//   4. A op1 (B op2 C) -> (A op3 C) op1 B   if "A op3 C" factorizes
+//      E.g. X * C1 - (B + X * C2) -> X * (C1 - C2) - B
+Value *FAddCombine::performFactorizationAssociative(Instruction *I) {
+  assert(I->hasUnsafeAlgebra() &&
+         "This method can't be called without unsafe algebra");
+  const unsigned Opcode = I->getOpcode();
+  if (Opcode != Instruction::FAdd && Opcode != Instruction::FSub)
+    return nullptr;
+  if (!I->hasOneUse())
+    return nullptr;
+
+  const bool IsAdd = (Opcode == Instruction::FAdd);
+  const FastMathFlags FMF = I->getFastMathFlags();
+
+  // Return V as a binary operator if it is an FAdd/FSub, nullptr otherwise.
+  auto GetAddSub = [](Value *V) -> BinaryOperator * {
+    BinaryOperator *BO = dyn_cast<BinaryOperator>(V);
+    if (BO && (BO->getOpcode() == Instruction::FAdd ||
+               BO->getOpcode() == Instruction::FSub))
+      return BO;
+    return nullptr;
+  };
+
+  // FAdd when both operators agree, FSub otherwise: the operator obtained by
+  // distributing an outer +/- over an inner +/-.
+  auto CombineSigns = [](bool IsOuterAdd, bool IsInnerAdd) -> unsigned {
+    return IsOuterAdd == IsInnerAdd ? Instruction::FAdd : Instruction::FSub;
+  };
+
+  // Emit "LHS FinalOpcode RHS" carrying the fast-math flags of I.
+  auto EmitFinal = [&](unsigned FinalOpcode, Value *LHS,
+                       Value *RHS) -> Value * {
+    Value *NewV = (FinalOpcode == Instruction::FAdd) ? createFAdd(LHS, RHS)
+                                                     : createFSub(LHS, RHS);
+    // The created value is not guaranteed to be an instruction (it may have
+    // been folded); only an instruction can carry fast-math flags.
+    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+      FastMathFlags Flags;
+      Flags.setUnsafeAlgebra();
+      Flags &= FMF;
+      NewI->setFastMathFlags(Flags);
+    }
+    return NewV;
+  };
+
+  Value *LHS = I->getOperand(0);
+  Value *RHS = I->getOperand(1);
+  BinaryOperator *Op0 = GetAddSub(LHS);
+  BinaryOperator *Op1 = GetAddSub(RHS);
+
+  // 1. (A op1 B) op2 C -> A op3 (B op4 C)
+  if (Op0) {
+    Value *A = Op0->getOperand(0);
+    Value *B = Op0->getOperand(1);
+    Value *V;
+    unsigned FinalOpcode = Instruction::FAdd;
+    if (Op0->getOpcode() == Instruction::FAdd) {
+      // (A + B) + C -> A + (B + C)
+      // (A + B) - C -> A + (B - C)
+      V = performFactorization(B, RHS, Opcode, FMF);
+    } else if (IsAdd) {
+      // (A - B) + C -> A + (C - B)
+      V = performFactorization(RHS, B, Instruction::FSub, FMF);
+    } else {
+      // (A - B) - C -> A - (B + C)
+      V = performFactorization(B, RHS, Instruction::FAdd, FMF);
+      FinalOpcode = Instruction::FSub;
+    }
+    if (V)
+      return EmitFinal(FinalOpcode, A, V);
+  }
+
+  // 2. A op1 (B op2 C) -> (A op1 B) op3 C
+  //    A + (B + C) -> (A + B) + C        A + (B - C) -> (A + B) - C
+  //    A - (B + C) -> (A - B) - C        A - (B - C) -> (A - B) + C
+  if (Op1) {
+    const bool IsOp1Add = (Op1->getOpcode() == Instruction::FAdd);
+    if (Value *V = performFactorization(LHS, Op1->getOperand(0), Opcode, FMF))
+      return EmitFinal(CombineSigns(IsAdd, IsOp1Add), V, Op1->getOperand(1));
+  }
+
+  // 3. (A op1 B) op2 C -> (A op2 C) op1 B
+  if (Op0) {
+    if (Value *V = performFactorization(Op0->getOperand(0), RHS, Opcode, FMF))
+      return EmitFinal(Op0->getOpcode(), V, Op0->getOperand(1));
+  }
+
+  // 4. A op1 (B op2 C) -> (A op3 C) op1 B
+  //    A + (B + C) -> (A + C) + B        A + (B - C) -> (A - C) + B
+  //    A - (B + C) -> (A - C) - B        A - (B - C) -> (A + C) - B
+  if (Op1) {
+    const bool IsOp1Add = (Op1->getOpcode() == Instruction::FAdd);
+    if (Value *V = performFactorization(LHS, Op1->getOperand(1),
+                                        CombineSigns(IsAdd, IsOp1Add), FMF))
+      return EmitFinal(Opcode, V, Op1->getOperand(0));
+  }
+
+  return nullptr;
+}
+
 // Try to perform following optimization on the input instruction I. Return the
 // simplified expression if was successful; otherwise, return 0.
 //
@@ -437,12 +551,13 @@
 //   (x * y) +/- (x * z)               x * (y +/- z)
 //   (y / x) +/- (z / x)               (y +/- z) / x
 //
-Value *FAddCombine::performFactorization(Instruction *I) {
-  assert((I->getOpcode() == Instruction::FAdd ||
-          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
+Value *FAddCombine::performFactorization(Value *Op0, Value *Op1, unsigned Opcode,
+                                         FastMathFlags FMF) {
+  assert((Opcode == Instruction::FAdd ||
+          Opcode == Instruction::FSub) && "Expect add/sub");
 
-  Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
-  Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
+  Instruction *I0 = dyn_cast<Instruction>(Op0);
+  Instruction *I1 = dyn_cast<Instruction>(Op1);
 
   if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
     return nullptr;
 
@@ -487,11 +602,11 @@
 
   FastMathFlags Flags;
   Flags.setUnsafeAlgebra();
-  if (I0) Flags &= I->getFastMathFlags();
-  if (I1) Flags &= I->getFastMathFlags();
+  if (I0) Flags &= FMF;
+  if (I1) Flags &= FMF;
 
   // Create expression "NewAddSub = AddSub0 +/- AddsSub1"
-  Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
+  Value *NewAddSub = (Opcode == Instruction::FAdd) ?
                       createFAdd(AddSub0, AddSub1) :
                       createFSub(AddSub0, AddSub1);
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
@@ -598,7 +713,10 @@
   }
 
   // step 6: Try factorization as the last resort,
-  return performFactorization(I);
+  if (Value *V = performFactorization(I->getOperand(0), I->getOperand(1),
+                                      I->getOpcode(), I->getFastMathFlags()))
+    return V;
+  return performFactorizationAssociative(I);
 }
 
 Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
Index: test/Transforms/InstCombine/FAddFSubAssociativeFactorize.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/FAddFSubAssociativeFactorize.ll
@@ -0,0 +1,224 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+; ****************************************
+; Test (a op1 b) op2 c -> a op3 (b op4 c)
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc1(
+; CHECK: fadd fast half %a, %b
+; CHECK: fmul fast half %1, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc1(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (5 * a + 3 * b) + (2 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc2(
+; CHECK: %tmp2 = fmul fast half %a, 0xH4500
+; CHECK: fadd fast half %tmp2, %b
+; CHECK: ret
+define half @faddsubAssoc2(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (5 * a + 3 * b) - (2 * b)
+  ret half %tmp5 ; = 5 * a + b
+}
+
+; CHECK-LABEL: @faddsubAssoc3(
+; CHECK: %tmp2 = fmul fast half %a, 0xH4500
+; CHECK: fsub fast half %tmp2, %b
+; CHECK: ret
+define half @faddsubAssoc3(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (5 * a - 3 * b) + (2 * b)
+  ret half %tmp5 ; = 5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc4(
+; CHECK: fsub fast half %a, %b
+; CHECK: fmul fast half %1, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc4(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (5 * a - 3 * b) - (2 * b)
+  ret half %tmp5 ; = 5 * a - 5 * b
+}
+
+; ****************************************
+; Test a op1 (b op2 c) -> (a op3 b) op4 c
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc5(
+; CHECK: fadd fast half %b, %a
+; CHECK: fmul fast half %1, 0xH4500
+define half @faddsubAssoc5(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + (3 * b + 5 * a)
+  ret half %tmp5 ; = 5 * (a + b)
+}
+
+; CHECK-LABEL: @faddsubAssoc6(
+; CHECK: fmul fast half %a, 0xHC500
+; CHECK: fsub
+; CHECK: ret
+define half @faddsubAssoc6(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - (3 * b + 5 * a)
+  ret half %tmp5 ; = -b - 5 * a
+}
+
+; CHECK-LABEL: @faddsubAssoc7(
+; CHECK: fsub fast half %b, %a
+; CHECK: fmul fast half %1, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc7(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + (3 * b - 5 * a)
+  ret half %tmp5 ; = 5 * (b - a)
+}
+
+; CHECK-LABEL: @faddsubAssoc8(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc8(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - (3 * b - 5 * a)
+  ret half %tmp5 ; = -b + 5 * a
+}
+
+; ****************************************
+; Test (a op1 b) op2 c -> (a op3 c) op4 b
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc9(
+; CHECK: fadd fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc9(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (3 * b + 5 * a) + (2 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc10(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fadd fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc10(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp1, %tmp2 ; 3 * b + 5 * a
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (3 * b + 5 * a) - (2 * b)
+  ret half %tmp5 ; = b + 5 * a
+}
+
+; CHECK-LABEL: @faddsubAssoc11(
+; CHECK: fsub fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc11(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fadd fast half %tmp4, %tmp3 ; (3 * b - 5 * a) + (2 * b)
+  ret half %tmp5 ; = 5 * (b - a)
+}
+
+; CHECK-LABEL: @faddsubAssoc12(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half %b, {{.*}}
+; CHECK: ret
+define half @faddsubAssoc12(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp1, %tmp2 ; 3 * b - 5 * a
+  %tmp5 = fsub fast half %tmp4, %tmp3 ; (3 * b - 5 * a) - (2 * b)
+  ret half %tmp5 ; = b - 5 * a
+}
+
+; ****************************************
+; Test a op1 (b op2 c) -> (a op3 c) op4 b
+; ****************************************
+
+; CHECK-LABEL: @faddsubAssoc13(
+; CHECK: fadd fast half %b, %a
+; CHECK: fmul fast half {{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc13(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + ( 5 * a + 3 * b)
+  ret half %tmp5 ; = 5 * ( a + b )
+}
+
+; CHECK-LABEL: @faddsubAssoc14(
+; CHECK: fmul fast half %a, 0xH4500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc14(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fadd fast half %tmp3, %tmp4 ; 2 * b + ( 5 * a - 3 * b)
+  ret half %tmp5 ; = 5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc15(
+; CHECK: fmul fast half %a, 0xHC500
+; CHECK: fsub fast half {{.*}}, %b
+; CHECK: ret
+define half @faddsubAssoc15(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fadd fast half %tmp2, %tmp1 ; 5 * a + 3 * b
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - ( 5 * a + 3 * b)
+  ret half %tmp5 ; = -5 * a - b
+}
+
+; CHECK-LABEL: @faddsubAssoc16(
+; CHECK: fsub fast half %b, %a
+; CHECK: fmul fast half %{{.*}}, 0xH4500
+; CHECK: ret
+define half @faddsubAssoc16(half %a, half %b) {
+  %tmp1 = fmul fast half %b, 0xH4200 ; 3*b
+  %tmp2 = fmul fast half %a, 0xH4500 ; 5*a
+  %tmp3 = fmul fast half %b, 0xH4000 ; 2*b
+  %tmp4 = fsub fast half %tmp2, %tmp1 ; 5 * a - 3 * b
+  %tmp5 = fsub fast half %tmp3, %tmp4 ; 2 * b - ( 5 * a - 3 * b)
+  ret half %tmp5 ; = 5 * (b - a)
+}