Index: llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -51,6 +51,8 @@
 
   const TargetMachine *TM;
 
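+  // Cached "unsafe-fp-math" function attribute, set by initFunction().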
+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;
 
@@ -67,10 +69,11 @@
   /* Specialized optimizations */
 
   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // -fuse-native for sincos
   bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);
@@ -81,10 +84,11 @@
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);
 
   // __read_pipe/__write_pipe
@@ -104,7 +108,9 @@
 protected:
   CallInst *CI;
 
-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
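+  // Whether a libcall may be constant folded at a higher precision than the
+  // runtime library implementation provides (see evaluateCall).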
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -116,6 +122,7 @@
 
   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
 
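+  // Gather per-function state (currently the unsafe-fp-math attribute).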
+  void initFunction(const Function &F);
   void initNativeFuncs();
 
   // Replace a normal math function call with that native version
@@ -436,13 +443,18 @@
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }
 
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
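+// Read the function attributes once per function instead of on every call.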
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -610,45 +622,43 @@
   if (TDOFold(CI, FInfo))
     return true;
 
-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;
 
-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
+    // Copy fast-math flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());
 
-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each floating-point function call.
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for calls that are not FP operations.
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }
 
   return false;
@@ -727,7 +737,7 @@
 }
 }
 
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -759,7 +769,7 @@
   }
 
   // No unsafe math , no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;
 
   // 0x1111111 means that we don't do anything for this call.
@@ -818,7 +828,7 @@
     }
   }
 
-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;
 
   // Unsafe Math optimization
@@ -1012,10 +1022,14 @@
   return true;
 }
 
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // Skip the vector variants.
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
 
   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1077,8 +1091,11 @@
 }
 
 // fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
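+  // native_sqrt has relaxed precision, so only fold under unsafe math.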
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1095,10 +1112,16 @@
 }
 
 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
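+  // Only handle the f32 and f64 variants without a native or half prefix.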
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
   Value *CArgVal = CI->getArgOperand(0);
@@ -1540,6 +1563,8 @@
   if (skipFunction(F))
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
@@ -1564,6 +1589,7 @@
                                                   FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1590,6 +1616,8 @@
   if (skipFunction(F) || UseNative.empty())
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1610,6 +1638,7 @@
 
   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   for (auto &BB : F) {