Index: lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -612,12 +612,25 @@
     }
   }
 
-  // sqrt(X) * sqrt(X) -> X
-  if (AllowReassociate && (Op0 == Op1))
-    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0))
-      if (II->getIntrinsicID() == Intrinsic::sqrt)
+  if (Op0 == Op1) {
+    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
+      // sqrt(X) * sqrt(X) -> X
+      if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt)
         return ReplaceInstUsesWith(I, II->getOperand(0));
 
+      // fabs(X) * fabs(X) -> X * X
+      // No fast-math flags needed: X * X does not depend on the sign of X.
+      if (II->getIntrinsicID() == Intrinsic::fabs) {
+        Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0),
+                                                          II->getOperand(0),
+                                                          I.getName());
+        FMulVal->copyFastMathFlags(&I);
+
+        return FMulVal;
+      }
+    }
+  }
+
   // Under unsafe algebra do:
   // X * log2(0.5*Y) = X*log2(Y) - X
   if (AllowReassociate) {
Index: test/Transforms/InstCombine/fmul.ll
===================================================================
--- test/Transforms/InstCombine/fmul.ll
+++ test/Transforms/InstCombine/fmul.ll
@@ -152,3 +152,32 @@
 ; CHECK-NEXT: %mul2 = fmul double %sqrt, %f
 ; CHECK-NEXT: ret double %mul2
 }
+
+declare float @llvm.fabs.f32(float) nounwind readnone
+
+; CHECK-LABEL: @fabs_squared(
+; CHECK: %mul = fmul float %x, %x
+define float @fabs_squared(float %x) {
+  %x.fabs = call float @llvm.fabs.f32(float %x)
+  %mul = fmul float %x.fabs, %x.fabs
+  ret float %mul
+}
+
+; CHECK-LABEL: @fabs_squared_fast(
+; CHECK: %mul = fmul fast float %x, %x
+define float @fabs_squared_fast(float %x) {
+  %x.fabs = call float @llvm.fabs.f32(float %x)
+  %mul = fmul fast float %x.fabs, %x.fabs
+  ret float %mul
+}
+
+; CHECK-LABEL: @fabs_x_fabs(
+; CHECK: call float @llvm.fabs.f32(float %x)
+; CHECK: call float @llvm.fabs.f32(float %y)
+; CHECK: %mul = fmul float %x.fabs, %y.fabs
+define float @fabs_x_fabs(float %x, float %y) {
+  %x.fabs = call float @llvm.fabs.f32(float %x)
+  %y.fabs = call float @llvm.fabs.f32(float %y)
+  %mul = fmul float %x.fabs, %y.fabs
+  ret float %mul
+}