diff --git a/llvm/include/llvm/Analysis/ConstantFolding.h b/llvm/include/llvm/Analysis/ConstantFolding.h
--- a/llvm/include/llvm/Analysis/ConstantFolding.h
+++ b/llvm/include/llvm/Analysis/ConstantFolding.h
@@ -19,6 +19,7 @@
 #ifndef LLVM_ANALYSIS_CONSTANTFOLDING_H
 #define LLVM_ANALYSIS_CONSTANTFOLDING_H

+#include "llvm/ADT/APFloat.h"
 #include <stdint.h>

 namespace llvm {
@@ -26,6 +27,7 @@
 template <typename T> class ArrayRef;
 class CallBase;
 class Constant;
+class ConstantFP;
 class DSOLocalEquivalent;
 class DataLayout;
 class Function;
@@ -95,6 +97,12 @@
                                      Constant *RHS, const DataLayout &DL,
                                      const Instruction *I);

+/// If \p Operand is a denormal that \p DenormMode flushes, return the
+/// corresponding zero; otherwise return \p Operand. The output mode of
+/// \p DenormMode is used if \p IsOutput is true, else the input mode.
+ConstantFP *FlushFPConstant(ConstantFP *Operand, DenormalMode DenormMode,
+                            bool IsOutput);
+
 /// Attempt to constant fold a select instruction with the specified
 /// operands. The constant result is returned if successful; if not, null is
 /// returned.
diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -240,7 +240,9 @@

 /// Given operands for an FCmpInst, fold the result or return null.
+/// Denormal FP constant operands are flushed per \p DenormMode's input mode.
 Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                        FastMathFlags FMF, const SimplifyQuery &Q);
+                        FastMathFlags FMF, const SimplifyQuery &Q,
+                        DenormalMode DenormMode = DenormalMode::getIEEE());

 /// Given operands for a SelectInst, fold the result or return null.
 Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -16,7 +16,6 @@
 //===----------------------------------------------------------------------===//

 #include "llvm/Analysis/ConstantFolding.h"
-#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -1347,36 +1346,30 @@
 // attributes. If so, return a zero with the correct sign, otherwise return the
 // original constant. Inputs and outputs to floating point instructions can have
 // their mode set separately, so the direction is also needed.
-Constant *FlushFPConstant(Constant *Operand, const llvm::Function *F,
-                          bool IsOutput) {
-  if (F == nullptr)
-    return Operand;
-  if (auto *CFP = dyn_cast<ConstantFP>(Operand)) {
-    const APFloat &APF = CFP->getValueAPF();
-    Type *Ty = CFP->getType();
-    DenormalMode DenormMode = F->getDenormalMode(Ty->getFltSemantics());
-    DenormalMode::DenormalModeKind Mode =
-        IsOutput ? DenormMode.Output : DenormMode.Input;
-    switch (Mode) {
-    default:
-      llvm_unreachable("unknown denormal mode");
-      return Operand;
-    case DenormalMode::IEEE:
-      return Operand;
-    case DenormalMode::PreserveSign:
-      if (APF.isDenormal()) {
-        return ConstantFP::get(
-            Ty->getContext(),
-            APFloat::getZero(Ty->getFltSemantics(), APF.isNegative()));
-      }
-      return Operand;
-    case DenormalMode::PositiveZero:
-      if (APF.isDenormal()) {
-        return ConstantFP::get(Ty->getContext(),
-                               APFloat::getZero(Ty->getFltSemantics(), false));
-      }
-      return Operand;
+ConstantFP *llvm::FlushFPConstant(ConstantFP *Operand, DenormalMode DenormMode,
+                                  bool IsOutput) {
+  const APFloat &APF = Operand->getValueAPF();
+  Type *Ty = Operand->getType();
+  DenormalMode::DenormalModeKind Mode =
+      IsOutput ? DenormMode.Output : DenormMode.Input;
+  switch (Mode) {
+  default:
+    llvm_unreachable("unknown denormal mode");
+  case DenormalMode::IEEE:
+    return Operand;
+  case DenormalMode::PreserveSign:
+    if (APF.isDenormal()) {
+      return ConstantFP::get(
+          Ty->getContext(),
+          APFloat::getZero(Ty->getFltSemantics(), APF.isNegative()));
     }
+    return Operand;
+  case DenormalMode::PositiveZero:
+    if (APF.isDenormal()) {
+      return ConstantFP::get(Ty->getContext(),
+                             APFloat::getZero(Ty->getFltSemantics(), false));
+    }
+    return Operand;
   }
   return Operand;
 }
@@ -1387,10 +1380,21 @@
   if (auto *BB = I->getParent()) {
     if (auto *F = BB->getParent()) {
       if (Instruction::isBinaryOp(Opcode)) {
-        Constant *Op0 = FlushFPConstant(LHS, F, false);
-        Constant *Op1 = FlushFPConstant(RHS, F, false);
+        // Flush denormal FP constants according to F's denormal mode for the
+        // constant's semantics; other constants (and null) pass through.
+        auto Flush = [F](Constant *Cst, bool IsOutput) -> Constant * {
+          auto *CFP = dyn_cast_or_null<ConstantFP>(Cst);
+          if (!CFP)
+            return Cst;
+          return FlushFPConstant(
+              CFP, F->getDenormalMode(CFP->getType()->getFltSemantics()),
+              IsOutput);
+        };
+        Constant *Op0 = Flush(LHS, /*IsOutput=*/false);
+        Constant *Op1 = Flush(RHS, /*IsOutput=*/false);
         Constant *C = ConstantFoldBinaryOpOperands(Opcode, Op0, Op1, DL);
-        return FlushFPConstant(C, F, true);
+        // Tolerate a null result in case the fold fails.
+        return Flush(C, /*IsOutput=*/true);
       }
     }
   }
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3894,13 +3894,20 @@
 /// If not, this returns null.
 static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
                                FastMathFlags FMF, const SimplifyQuery &Q,
-                               unsigned MaxRecurse) {
+                               unsigned MaxRecurse, DenormalMode DenormMode) {
   CmpInst::Predicate Pred = (CmpInst::Predicate)Predicate;
   assert(CmpInst::isFPPredicate(Pred) && "Not an FP compare!");

   if (Constant *CLHS = dyn_cast<Constant>(LHS)) {
-    if (Constant *CRHS = dyn_cast<Constant>(RHS))
+    if (Constant *CRHS = dyn_cast<Constant>(RHS)) {
+      // Flush denormal constant operands using the input mode of DenormMode
+      // so that the fold honors the caller-supplied denormal handling.
+      if (auto *CFP = dyn_cast<ConstantFP>(CLHS))
+        CLHS = FlushFPConstant(CFP, DenormMode, /*IsOutput=*/false);
+      if (auto *CFP = dyn_cast<ConstantFP>(CRHS))
+        CRHS = FlushFPConstant(CFP, DenormMode, /*IsOutput=*/false);
       return ConstantFoldCompareInstOperands(Pred, CLHS, CRHS, Q.DL, Q.TLI);
+    }

     // If we have a constant, make sure it is on the RHS.
     std::swap(LHS, RHS);
@@ -4101,8 +4108,10 @@
 }

 Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
-                              FastMathFlags FMF, const SimplifyQuery &Q) {
-  return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
+                              FastMathFlags FMF, const SimplifyQuery &Q,
+                              DenormalMode DenormMode) {
+  return ::simplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit,
+                            DenormMode);
 }

 static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
@@ -5571,7 +5580,12 @@
                               const SimplifyQuery &Q, unsigned MaxRecurse) {
   if (CmpInst::isIntPredicate((CmpInst::Predicate)Predicate))
     return simplifyICmpInst(Predicate, LHS, RHS, Q, MaxRecurse);
-  return simplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse);
+  // This entry point is not given a denormal mode, so conservatively assume
+  // IEEE semantics (denormal constants are not flushed).
+  // FIXME: Derive the denormal mode from the enclosing function's attributes
+  // (e.g. via the query's context instruction) to fold denormals here too.
+  return simplifyFCmpInst(Predicate, LHS, RHS, FastMathFlags(), Q, MaxRecurse,
+                          DenormalMode::getIEEE());
 }

 Value *llvm::simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
@@ -6405,10 +6419,18 @@
     Result = simplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0],
                               NewOps[1], Q);
     break;
-  case Instruction::FCmp:
+  case Instruction::FCmp: {
+    // Honor the enclosing function's denormal mode when a constant FP operand
+    // is involved and the instruction is attached to a function.
+    DenormalMode Mode = DenormalMode::getIEEE();
+    if (isa<ConstantFP>(NewOps[0]) || isa<ConstantFP>(NewOps[1]))
+      if (const BasicBlock *BB = I->getParent())
+        if (const Function *F = BB->getParent())
+          Mode = F->getDenormalMode(NewOps[0]->getType()->getFltSemantics());
     Result = simplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0],
-                              NewOps[1], I->getFastMathFlags(), Q);
+                              NewOps[1], I->getFastMathFlags(), Q, Mode);
     break;
+  }
   case Instruction::Select:
     Result = simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q);
     break;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6691,8 +6691,11 @@

   const CmpInst::Predicate Pred = I.getPredicate();
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  DenormalMode Mode = DenormalMode::getIEEE();
+  if (isa<ConstantFP>(Op0) || isa<ConstantFP>(Op1))
+    Mode = I.getFunction()->getDenormalMode(Op0->getType()->getFltSemantics());
   if (Value *V = simplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
-                                  SQ.getWithInstruction(&I)))
+                                  SQ.getWithInstruction(&I), Mode))
     return replaceInstUsesWith(I, V);

   // Simplify 'fcmp pred X, X'
diff --git a/llvm/test/Transforms/InstSimplify/constant-fold-fp-denormal.ll b/llvm/test/Transforms/InstSimplify/constant-fold-fp-denormal.ll
--- a/llvm/test/Transforms/InstSimplify/constant-fold-fp-denormal.ll
+++ b/llvm/test/Transforms/InstSimplify/constant-fold-fp-denormal.ll
@@ -763,7 +763,7 @@
 define i1 @fcmp_double_positive_zero() #6 {
 ; CHECK-LABEL: @fcmp_double_positive_zero(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %cmp = fcmp une double 0x0, 0x8000000000000
@@ -773,7 +773,7 @@
 define i1 @fcmp_float_positive_zero() #6 {
 ; CHECK-LABEL: @fcmp_float_positive_zero(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %cmp = fcmp une double 0x0, 0x8000000000000