AMDGPULibCalls: drop the cached `CI` member; replaceCall() becomes static and takes the instruction being replaced (Instruction* or FPMathOperator*), and the pow/sqrt/sin-cos folds read operands via FPMathOperator::getOperand. NOTE(review): template arguments (cast<...>, dyn_cast<...>, SmallVector <...>) were stripped when this patch was exported and have been restored below; please confirm the context lines against the target file.
Index: llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp =================================================================== --- llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp +++ llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp @@ -106,15 +106,17 @@ FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo); protected: - CallInst *CI; - bool isUnsafeMath(const FPMathOperator *FPOp) const; bool canIncreasePrecisionOfConstantFold(const FPMathOperator *FPOp) const; - void replaceCall(Value *With) { - CI->replaceAllUsesWith(With); - CI->eraseFromParent(); + static void replaceCall(Instruction *I, Value *With) { + I->replaceAllUsesWith(With); + I->eraseFromParent(); + } + + static void replaceCall(FPMathOperator *I, Value *With) { + replaceCall(cast<Instruction>(I), With); } public: @@ -494,7 +496,7 @@ DEBUG_WITH_TYPE("usenative", dbgs() << " replace " << *aCI << " with native version of sin/cos"); - replaceCall(sinval); + replaceCall(aCI, sinval); return true; } } @@ -502,7 +504,6 @@ } bool AMDGPULibCalls::useNative(CallInst *aCI) { - CI = aCI; Function *Callee = aCI->getCalledFunction(); if (!Callee || aCI->isNoBuiltin()) return false; @@ -592,7 +593,6 @@ // This function returns false if no change; return true otherwise. bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) { - this->CI = CI; Function *Callee = CI->getCalledFunction(); // Ignore indirect calls. 
if (!Callee || CI->isNoBuiltin()) @@ -707,7 +707,7 @@ nval = ConstantDataVector::get(context, tmp); } LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n"); - replaceCall(nval); + replaceCall(CI, nval); return true; } } else { @@ -717,7 +717,7 @@ if (CF->isExactlyValue(tr[i].input)) { Value *nval = ConstantFP::get(CF->getType(), tr[i].result); LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n"); - replaceCall(nval); + replaceCall(CI, nval); return true; } } @@ -748,8 +748,8 @@ ConstantFP *CF; ConstantInt *CINT; Type *eltType; - Value *opr0 = CI->getArgOperand(0); - Value *opr1 = CI->getArgOperand(1); + Value *opr0 = FPOp->getOperand(0); + Value *opr1 = FPOp->getOperand(1); ConstantAggregateZero *CZero = dyn_cast<ConstantAggregateZero>(opr1); if (getVecSize(FInfo) == 1) { @@ -776,37 +776,37 @@ if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) { // pow/powr/pown(x, 0) == 1 - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n"); + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1\n"); Constant *cnval = ConstantFP::get(eltType, 1.0); if (getVecSize(FInfo) > 1) { cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); } - replaceCall(cnval); + replaceCall(FPOp, cnval); return true; } if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) { // pow/powr/pown(x, 1.0) = x - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n"); - replaceCall(opr0); + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << "\n"); + replaceCall(FPOp, opr0); return true; } if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) { // pow/powr/pown(x, 2.0) = x*x - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0 - << "\n"); + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << *opr0 << " * " + << *opr0 << "\n"); Value *nval = B.CreateFMul(opr0, opr0, "__pow2"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) { // 
pow/powr/pown(x, -1.0) = 1.0/x - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1 / " << *opr0 << "\n"); + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> 1 / " << *opr0 << "\n"); Constant *cnval = ConstantFP::get(eltType, 1.0); if (getVecSize(FInfo) > 1) { cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval); } Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } @@ -817,11 +817,11 @@ getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT : AMDGPULibFunc::EI_RSQRT, FInfo))) { - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << FInfo.getName() + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << FInfo.getName() << '(' << *opr0 << ")\n"); Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt" : "__pow2rsqrt"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } } @@ -874,10 +874,10 @@ } nval = B.CreateFDiv(cnval, nval, "__1powprod"); } - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << ((ci_opr1 < 0) ? 
"1/prod(" : "prod(") << *opr0 << ")\n"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } @@ -1001,7 +1001,7 @@ if (const auto *vTy = dyn_cast<FixedVectorType>(rTy)) nTy = FixedVectorType::get(nTyS, vTy); unsigned size = nTy->getScalarSizeInBits(); - opr_n = CI->getArgOperand(1); + opr_n = FPOp->getOperand(1); if (opr_n->getType()->isIntegerTy()) opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou"); else @@ -1013,9 +1013,9 @@ nval = B.CreateBitCast(nval, opr0->getType()); } - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } @@ -1036,7 +1036,7 @@ int ci_opr1 = (int)CINT->getSExtValue(); if (ci_opr1 == 1) { // rootn(x, 1) = x LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n"); - replaceCall(opr0); + replaceCall(CI, opr0); return true; } if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x) @@ -1045,7 +1045,7 @@ getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) { LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n"); Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt"); - replaceCall(nval); + replaceCall(CI, nval); return true; } } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x) @@ -1054,7 +1054,7 @@ getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) { LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n"); Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt"); - replaceCall(nval); + replaceCall(CI, nval); return true; } } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x @@ -1062,7 +1062,7 @@ Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0), opr0, "__rootn2div"); - replaceCall(nval); + replaceCall(CI, nval); return true; } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x) Module *M = CI->getModule(); @@ -1071,7 +1071,7 @@ LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0 << ")\n"); 
Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt"); - replaceCall(nval); + replaceCall(CI, nval); return true; } } @@ -1096,13 +1096,15 @@ if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) && (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) { + Module *M = B.GetInsertBlock()->getModule(); + if (FunctionCallee FPExpr = getNativeFunction( - CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) { - Value *opr0 = CI->getArgOperand(0); - LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " + M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) { + Value *opr0 = FPOp->getOperand(0); + LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> " << "sqrt(" << *opr0 << ")\n"); Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt"); - replaceCall(nval); + replaceCall(FPOp, nval); return true; } } @@ -1122,7 +1124,8 @@ bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN; - Value *CArgVal = CI->getArgOperand(0); + Value *CArgVal = FPOp->getOperand(0); + CallInst *CI = cast<CallInst>(FPOp); BasicBlock * const CBB = CI->getParent(); int const MaxScan = 30; @@ -1138,7 +1141,7 @@ CArgVal->replaceAllUsesWith(AvailableVal); if (CArgVal->getNumUses() == 0) LI->eraseFromParent(); - CArgVal = CI->getArgOperand(0); + CArgVal = FPOp->getOperand(0); } } } @@ -1508,12 +1511,12 @@ } } - LLVMContext &context = CI->getParent()->getParent()->getContext(); + LLVMContext &context = aCI->getContext(); Constant *nval0, *nval1; if (FuncVecSize == 1) { - nval0 = ConstantFP::get(CI->getType(), DVal0[0]); + nval0 = ConstantFP::get(aCI->getType(), DVal0[0]); if (hasTwoResults) - nval1 = ConstantFP::get(CI->getType(), DVal1[0]); + nval1 = ConstantFP::get(aCI->getType(), DVal1[0]); } else { if (getArgType(FInfo) == AMDGPULibFunc::F32) { SmallVector <float, 0> FVal0, FVal1; @@ -1544,7 +1547,7 @@ new StoreInst(nval1, aCI->getArgOperand(1), aCI); } - replaceCall(nval0); + replaceCall(aCI, nval0); return true; }