Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1411,45 +1411,43 @@
 
-/// Return a Constant* for the specified floating-point constant if it fits
-/// in the specified FP type without changing its value.
-static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
+/// Return true if the specified floating-point constant can be converted to
+/// the specified FP semantics without changing its value.
+static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
   bool losesInfo;
   APFloat F = CFP->getValueAPF();
   (void)F.convert(Sem, APFloat::rmNearestTiesToEven, &losesInfo);
-  if (!losesInfo)
-    return ConstantFP::get(CFP->getContext(), F);
-  return nullptr;
+  return !losesInfo;
 }
 
-static Constant *shrinkFPConstant(ConstantFP *CFP) {
+static Type *shrinkFPConstant(ConstantFP *CFP) {
   if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
     return nullptr;  // No constant folding of this.
   // See if the value can be truncated to half and then reextended.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEhalf()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEhalf()))
+    return Type::getHalfTy(CFP->getContext());
   // See if the value can be truncated to float and then reextended.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEsingle()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEsingle()))
+    return Type::getFloatTy(CFP->getContext());
   if (CFP->getType()->isDoubleTy())
     return nullptr;  // Won't shrink.
-  if (Constant *NewCFP = fitsInFPType(CFP, APFloat::IEEEdouble()))
-    return NewCFP;
+  if (fitsInFPType(CFP, APFloat::IEEEdouble()))
+    return Type::getDoubleTy(CFP->getContext());
   // Don't try to shrink to various long double types.
   return nullptr;
 }
 
-/// Look through floating-point extensions until we get the source value.
-static Value *lookThroughFPExtensions(Value *V) {
-  while (auto *FPExt = dyn_cast<FPExtInst>(V))
-    V = FPExt->getOperand(0);
+/// Find the minimum FP type we can safely truncate to.
+static Type *getMinimumFPType(Value *V) {
+  if (auto *FPExt = dyn_cast<FPExtInst>(V))
+    return FPExt->getOperand(0)->getType();
 
-  // If this value is a constant, return the constant in the smallest FP type
-  // that can accurately represent it.  This allows us to turn
+  // If this value is a constant, return the smallest FP type that can
+  // accurately represent it.  This allows us to turn
   // (float)((double)X+2.0) into x+2.0f.
   if (auto *CFP = dyn_cast<ConstantFP>(V))
-    if (Constant *NewCFP = shrinkFPConstant(CFP))
-      return NewCFP;
+    if (Type *T = shrinkFPConstant(CFP))
+      return T;
 
-  return V;
+  return V->getType();
 }
 
 Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
@@ -1464,11 +1462,11 @@
   // is explained below in the various case statements.
   BinaryOperator *OpI = dyn_cast<BinaryOperator>(CI.getOperand(0));
   if (OpI && OpI->hasOneUse()) {
-    Value *LHSOrig = lookThroughFPExtensions(OpI->getOperand(0));
-    Value *RHSOrig = lookThroughFPExtensions(OpI->getOperand(1));
+    Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
+    Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
     unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
-    unsigned LHSWidth = LHSOrig->getType()->getFPMantissaWidth();
-    unsigned RHSWidth = RHSOrig->getType()->getFPMantissaWidth();
+    unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
+    unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
     unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
     unsigned DstWidth = CI.getType()->getFPMantissaWidth();
     switch (OpI->getOpcode()) {
@@ -1494,12 +1492,10 @@
         // could be tightened for those cases, but they are rare (the main
         // case of interest here is (float)((double)float + float)).
         if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::Create(OpI->getOpcode(), LHSOrig, RHSOrig);
+            BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
@@ -1511,12 +1507,10 @@
         // rounding can possibly occur; we can safely perform the operation
         // in the destination format if it can represent both sources.
         if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::CreateFMul(LHSOrig, RHSOrig);
+            BinaryOperator::CreateFMul(LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
@@ -1529,33 +1523,35 @@
         // condition used here is a good conservative first pass.
         // TODO: Tighten bound via rigorous analysis of the unbalanced case.
         if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
-          if (LHSOrig->getType() != CI.getType())
-            LHSOrig = Builder.CreateFPExt(LHSOrig, CI.getType());
-          if (RHSOrig->getType() != CI.getType())
-            RHSOrig = Builder.CreateFPExt(RHSOrig, CI.getType());
+          Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), CI.getType());
+          Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), CI.getType());
           Instruction *RI =
-            BinaryOperator::CreateFDiv(LHSOrig, RHSOrig);
+            BinaryOperator::CreateFDiv(LHS, RHS);
           RI->copyFastMathFlags(OpI);
           return RI;
         }
         break;
-      case Instruction::FRem:
+      case Instruction::FRem: {
         // Remainder is straightforward.  Remainder is always exact, so the
         // type of OpI doesn't enter into things at all.  We simply evaluate
         // in whichever source type is larger, then convert to the
         // destination type.
         if (SrcWidth == OpWidth)
           break;
-        if (LHSWidth < SrcWidth)
-          LHSOrig = Builder.CreateFPExt(LHSOrig, RHSOrig->getType());
-        else if (RHSWidth <= SrcWidth)
-          RHSOrig = Builder.CreateFPExt(RHSOrig, LHSOrig->getType());
-        if (LHSOrig != OpI->getOperand(0) || RHSOrig != OpI->getOperand(1)) {
-          Value *ExactResult = Builder.CreateFRem(LHSOrig, RHSOrig);
-          if (Instruction *RI = dyn_cast<Instruction>(ExactResult))
-            RI->copyFastMathFlags(OpI);
-          return CastInst::CreateFPCast(ExactResult, CI.getType());
+        Value *LHS, *RHS;
+        if (LHSWidth == SrcWidth) {
+          LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
+          RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
+        } else {
+          LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
+          RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
         }
+
+        Value *ExactResult = Builder.CreateFRem(LHS, RHS);
+        if (Instruction *RI = dyn_cast<Instruction>(ExactResult))
+          RI->copyFastMathFlags(OpI);
+        return CastInst::CreateFPCast(ExactResult, CI.getType());
+      }
     }
 
     // (fptrunc (fneg x)) -> (fneg (fptrunc x))