Index: llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -51,6 +51,8 @@
 
   const TargetMachine *TM;
 
+  bool UnsafeFPMath = false;
+
   // -fuse-native.
   bool AllNative = false;
 
@@ -67,10 +69,11 @@
   /* Specialized optimizations */
 
   // pow/powr/pown
-  bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
+                const FuncInfo &FInfo);
 
   // rootn
-  bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
 
   // -fuse-native for sincos
   bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);
@@ -81,10 +84,11 @@
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
   // sqrt
-  bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
+  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
+                 const FuncInfo &FInfo);
 
   // sin/cos
-  bool fold_sincos(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo,
+  bool fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo,
                    AliasAnalysis *AA);
 
   // __read_pipe/__write_pipe
@@ -104,7 +108,9 @@
 protected:
   CallInst *CI;
 
-  bool isUnsafeMath(const CallInst *CI) const;
+  bool isUnsafeMath(const FPMathOperator *FPOp) const;
+
+  bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const;
 
   void replaceCall(Value *With) {
     CI->replaceAllUsesWith(With);
@@ -116,6 +122,7 @@
 
   bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
 
+  void initFunction(const Function &F);
   void initNativeFuncs();
 
   // Replace a normal math function call with that native version
@@ -436,13 +443,18 @@
   return AMDGPULibFunc::parse(FMangledName, FInfo);
 }
 
-bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
-  if (auto Op = dyn_cast<FPMathOperator>(CI))
-    if (Op->isFast())
-      return true;
-  const Function *F = CI->getParent()->getParent();
-  Attribute Attr = F->getFnAttribute("unsafe-fp-math");
-  return Attr.getValueAsBool();
+bool AMDGPULibCalls::isUnsafeMath(const FPMathOperator *FPOp) const {
+  return UnsafeFPMath || FPOp->isFast();
+}
+
+bool AMDGPULibCalls::canIncreasePrecisionOfConstantFold(
+    const FPMathOperator *FPOp) const {
+  // TODO: Refine to approxFunc or contract
+  return isUnsafeMath(FPOp);
+}
+
+void AMDGPULibCalls::initFunction(const Function &F) {
+  UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
 }
 
 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
@@ -610,45 +622,43 @@
   if (TDOFold(CI, FInfo))
     return true;
 
-  // Under unsafe-math, evaluate calls if possible.
-  // According to Brian Sumner, we can do this for all f32 function calls
-  // using host's double function calls.
-  if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
-    return true;
+  if (FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI)) {
+    // Under unsafe-math, evaluate calls if possible.
+    // According to Brian Sumner, we can do this for all f32 function calls
+    // using host's double function calls.
+    if (canIncreasePrecisionOfConstantFold(FPOp) && evaluateCall(CI, FInfo))
+      return true;
 
-  // Copy fast flags from the original call.
-  if (const FPMathOperator *FPOp = dyn_cast<FPMathOperator>(CI))
+    // Copy fast flags from the original call.
     B.setFastMathFlags(FPOp->getFastMathFlags());
 
-  // Specialized optimizations for each function call
-  switch (FInfo.getId()) {
-  case AMDGPULibFunc::EI_POW:
-  case AMDGPULibFunc::EI_POWR:
-  case AMDGPULibFunc::EI_POWN:
-    return fold_pow(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_ROOTN:
-    // skip vector function
-    return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
-
-  case AMDGPULibFunc::EI_SQRT:
-    return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
-  case AMDGPULibFunc::EI_COS:
-  case AMDGPULibFunc::EI_SIN:
-    if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
-         getArgType(FInfo) == AMDGPULibFunc::F64)
-        && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
-      return fold_sincos(CI, B, FInfo, AA);
-
-    break;
-  case AMDGPULibFunc::EI_READ_PIPE_2:
-  case AMDGPULibFunc::EI_READ_PIPE_4:
-  case AMDGPULibFunc::EI_WRITE_PIPE_2:
-  case AMDGPULibFunc::EI_WRITE_PIPE_4:
-    return fold_read_write_pipe(CI, B, FInfo);
-
-  default:
-    break;
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_POW:
+    case AMDGPULibFunc::EI_POWR:
+    case AMDGPULibFunc::EI_POWN:
+      return fold_pow(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_ROOTN:
+      return fold_rootn(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_SQRT:
+      return fold_sqrt(FPOp, B, FInfo);
+    case AMDGPULibFunc::EI_COS:
+    case AMDGPULibFunc::EI_SIN:
+      return fold_sincos(FPOp, B, FInfo, AA);
+    default:
+      break;
+    }
+  } else {
+    // Specialized optimizations for each function call
+    switch (FInfo.getId()) {
+    case AMDGPULibFunc::EI_READ_PIPE_2:
+    case AMDGPULibFunc::EI_READ_PIPE_4:
+    case AMDGPULibFunc::EI_WRITE_PIPE_2:
+    case AMDGPULibFunc::EI_WRITE_PIPE_4:
+      return fold_read_write_pipe(CI, B, FInfo);
+    default:
+      break;
+    }
   }
 
   return false;
@@ -727,7 +737,7 @@
   }
 }
 
-bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
                               const FuncInfo &FInfo) {
   assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
           FInfo.getId() == AMDGPULibFunc::EI_POWR ||
@@ -759,7 +769,7 @@
   }
 
   // No unsafe math , no constant argument, do nothing
-  if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
+  if (!isUnsafeMath(FPOp) && !CF && !CINT && !CZero)
     return false;
 
   // 0x1111111 means that we don't do anything for this call.
@@ -818,7 +828,7 @@
     }
   }
 
-  if (!isUnsafeMath(CI))
+  if (!isUnsafeMath(FPOp))
     return false;
 
   // Unsafe Math optimization
@@ -1012,10 +1022,14 @@
   return true;
 }
 
-bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_rootn(FPMathOperator *FPOp, IRBuilder<> &B,
                                 const FuncInfo &FInfo) {
-  Value *opr0 = CI->getArgOperand(0);
-  Value *opr1 = CI->getArgOperand(1);
+  // skip vector function
+  if (getVecSize(FInfo) != 1)
+    return false;
+
+  Value *opr0 = FPOp->getOperand(0);
+  Value *opr1 = FPOp->getOperand(1);
 
   ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
   if (!CINT) {
@@ -1077,8 +1091,11 @@
 }
 
 // fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
                                const FuncInfo &FInfo) {
+  if (!isUnsafeMath(FPOp))
+    return false;
+
   if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
       (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
     if (FunctionCallee FPExpr = getNativeFunction(
@@ -1095,10 +1112,16 @@
 }
 
 // fold sin, cos -> sincos.
-bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
+bool AMDGPULibCalls::fold_sincos(FPMathOperator *FPOp, IRBuilder<> &B,
                                  const FuncInfo &fInfo, AliasAnalysis *AA) {
   assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
          fInfo.getId() == AMDGPULibFunc::EI_COS);
+
+  if ((getArgType(fInfo) != AMDGPULibFunc::F32 &&
+       getArgType(fInfo) != AMDGPULibFunc::F64) ||
+      fInfo.getPrefix() != AMDGPULibFunc::NOPFX)
+    return false;
+
   bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
 
   Value *CArgVal = CI->getArgOperand(0);
@@ -1540,6 +1563,8 @@
   if (skipFunction(F))
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
 
@@ -1564,6 +1589,7 @@
                                           FunctionAnalysisManager &AM) {
   AMDGPULibCalls Simplifier(&TM);
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   auto AA = &AM.getResult<AAManager>(F);
@@ -1590,6 +1616,8 @@
   if (skipFunction(F) || UseNative.empty())
     return false;
 
+  Simplifier.initFunction(F);
+
   bool Changed = false;
   for (auto &BB : F) {
     for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
@@ -1610,6 +1638,7 @@
 
   AMDGPULibCalls Simplifier;
   Simplifier.initNativeFuncs();
+  Simplifier.initFunction(F);
 
   bool Changed = false;
   for (auto &BB : F) {
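For context, a minimal standalone sketch of the caching pattern this patch introduces: the "unsafe-fp-math" function attribute is read once per function in initFunction and stored in UnsafeFPMath, and isUnsafeMath then only combines that cached bit with the call's own fast-math flags. The UnsafeMathQuery class below is an illustrative stand-in for AMDGPULibCalls, not code from the tree; it uses only existing LLVM APIs (Function::getFnAttribute, Attribute::getValueAsBool, FPMathOperator::isFast).

// Illustrative sketch only; not part of the patch.
#include "llvm/IR/Function.h"
#include "llvm/IR/Operator.h"

namespace {

class UnsafeMathQuery {
  // Per-function state, cached once before visiting any calls.
  bool UnsafeFPMath = false;

public:
  // Counterpart of AMDGPULibCalls::initFunction in the patch: read the
  // "unsafe-fp-math" attribute once instead of at every call site.
  void initFunction(const llvm::Function &F) {
    UnsafeFPMath = F.getFnAttribute("unsafe-fp-math").getValueAsBool();
  }

  // Counterpart of the new isUnsafeMath: the rewrite is allowed if either
  // the cached function-level attribute or the call's fast-math flags say so.
  bool isUnsafeMath(const llvm::FPMathOperator *FPOp) const {
    return UnsafeFPMath || FPOp->isFast();
  }
};

} // anonymous namespace

The apparent motivation is that the old isUnsafeMath walked from the call back to its parent function and re-queried the attribute for every folded call; caching it per function lets the helpers take an FPMathOperator directly instead of a CallInst.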