Index: llvm/include/llvm/IR/Operator.h
===================================================================
--- llvm/include/llvm/IR/Operator.h
+++ llvm/include/llvm/IR/Operator.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -195,6 +196,27 @@
     return FMF;
   }
 
+  /// Return the fast-math flags implied by the string fp-math attributes of
+  /// \p F: "unsafe-fp-math" maps to reassoc and arcp, and "no-nans-fp-math",
+  /// "no-infs-fp-math", "no-signed-zeros-fp-math" and "approx-func-fp-math"
+  /// map to nnan, ninf, nsz and afn respectively. A flag is set only when its
+  /// attribute is present with the value "true".
+  static FastMathFlags getFromFuncAttr(const Function *F) {
+    auto IsSet = [F](StringRef Kind) {
+      // An absent attribute yields an empty Attribute, which reads as false.
+      return F->getFnAttribute(Kind).getValueAsBool();
+    };
+    FastMathFlags FMF;
+    bool UnsafeMath = IsSet("unsafe-fp-math");
+    FMF.setAllowReassoc(UnsafeMath);
+    FMF.setAllowReciprocal(UnsafeMath);
+    FMF.setNoNaNs(IsSet("no-nans-fp-math"));
+    FMF.setNoInfs(IsSet("no-infs-fp-math"));
+    FMF.setNoSignedZeros(IsSet("no-signed-zeros-fp-math"));
+    FMF.setApproxFunc(IsSet("approx-func-fp-math"));
+    return FMF;
+  }
+
   bool any() const { return Flags != 0; }
   bool none() const { return Flags == 0; }
   bool all() const { return Flags == ~0U; }
Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2682,6 +2682,24 @@
   }
 }
 
+/// Compute fast-math flags for a floating-point select that carries none of
+/// its own: the intersection of the flags of its true and false operands.
+/// An operand that is not an FPMathOperator (e.g. a function argument, a load
+/// or a floating-point constant) contributes the flags implied by the
+/// function's fp-math attributes instead.
+static FastMathFlags getFMFforSel(SelectInst &SI) {
+  const FastMathFlags FnFMF = FastMathFlags::getFromFuncAttr(SI.getFunction());
+  auto GetOperandFMF = [&FnFMF](Value *V) -> FastMathFlags {
+    if (auto *FPOp = dyn_cast<FPMathOperator>(V))
+      return FPOp->getFastMathFlags();
+    return FnFMF;
+  };
+
+  FastMathFlags FMF = GetOperandFMF(SI.getTrueValue());
+  FMF &= GetOperandFMF(SI.getFalseValue());
+  return FMF;
+}
+
 Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -2714,6 +2732,18 @@
   if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
     return I;
 
+  // If the select has no fast-math flags of its own, give it the
+  // intersection of its operands' flags. Returning &SI reports an in-place
+  // change; on revisit the flags are no longer empty, so this cannot loop.
+  // Later select folds (e.g. to fabs) can then rely on these flags.
+  if (isa<FPMathOperator>(SI) && SI.getFastMathFlags().none()) {
+    FastMathFlags FMF = getFMFforSel(SI);
+    if (FMF.any()) {
+      SI.setFastMathFlags(FMF);
+      return &SI;
+    }
+  }
+
   CmpInst::Predicate Pred;
 
   // Avoid potential infinite loops by checking for non-constant condition.
Index: llvm/test/Transforms/InstCombine/minmax-fold.ll
===================================================================
--- llvm/test/Transforms/InstCombine/minmax-fold.ll
+++ llvm/test/Transforms/InstCombine/minmax-fold.ll
@@ -926,7 +926,7 @@
 ; CHECK-NEXT:    [[CMP2_INV:%.*]] = fcmp fast oge float [[X]], 2.000000e+00
 ; CHECK-NEXT:    [[TMP2:%.*]] = select fast i1 [[CMP2_INV]], float 2.000000e+00, float [[X]]
 ; CHECK-NEXT:    [[CMP3:%.*]] = icmp ult i8 [[I:%.*]], 16
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP3]], float [[TMP1]], float [[TMP2]]
+; CHECK-NEXT:    [[R:%.*]] = select fast i1 [[CMP3]], float [[TMP1]], float [[TMP2]]
 ; CHECK-NEXT:    ret float [[R]]
 ;
   %cmp1 = fcmp fast ult float %x, 1.0
Index: llvm/test/Transforms/PhaseOrdering/fmf-select-abs.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/PhaseOrdering/fmf-select-abs.ll
@@ -0,0 +1,39 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -S -O1 | FileCheck %s
+
+; An if/else computing |x| through allocas must be reduced at -O1 to a single
+; llvm.fabs call. The select formed from the diamond has an operand (%x) that
+; is not a floating-point operation, so the select's fast-math flags have to
+; be derived from the function's fp-math attributes (#0).
+
+define double @foo(double %x) #0 {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = call reassoc nnan ninf nsz arcp afn double @llvm.fabs.f64(double [[X:%.*]])
+; CHECK-NEXT:    ret double [[TMP0]]
+;
+entry:
+  %retval = alloca double, align 8
+  %x.addr = alloca double, align 8
+  store double %x, double* %x.addr, align 8
+  %0 = load double, double* %x.addr, align 8
+  %cmp = fcmp fast olt double %0, 0.000000e+00
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %1 = load double, double* %x.addr, align 8
+  %fneg = fneg fast double %1
+  store double %fneg, double* %retval, align 8
+  br label %return
+
+if.else:                                          ; preds = %entry
+  %2 = load double, double* %x.addr, align 8
+  store double %2, double* %retval, align 8
+  br label %return
+
+return:                                           ; preds = %if.else, %if.then
+  %3 = load double, double* %retval, align 8
+  ret double %3
+}
+
+attributes #0 = { "approx-func-fp-math"="true" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "unsafe-fp-math"="true" }