Index: llvm/include/llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h
===================================================================
--- llvm/include/llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h
+++ llvm/include/llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h
@@ -8,7 +8,7 @@
 /// \file
 ///
 /// AggressiveInstCombiner - Combine expression patterns to form expressions
-/// with fewer, simple instructions. This pass does not modify the CFG.
+/// with fewer, simple instructions.
 ///
 //===----------------------------------------------------------------------===//
Index: llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
===================================================================
--- llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -28,6 +29,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -396,52 +398,6 @@
   return true;
 }
 
-/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
-/// pessimistic codegen that has to account for setting errno and can enable
-/// vectorization.
-static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
-                     TargetLibraryInfo &TLI) {
-  // Match a call to sqrt mathlib function.
-  auto *Call = dyn_cast<CallInst>(&I);
-  if (!Call)
-    return false;
-
-  Module *M = Call->getModule();
-  LibFunc Func;
-  if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
-    return false;
-
-  if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
-    return false;
-
-  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
-  // (because NNAN or the operand arg must not be less than -0.0) and (3) we
-  // would not end up lowering to a libcall anyway (which could change the
-  // value of errno), then:
-  // (1) errno won't be set.
-  // (2) it is safe to convert this to an intrinsic call.
-  Type *Ty = Call->getType();
-  Value *Arg = Call->getArgOperand(0);
-  if (TTI.haveFastSqrt(Ty) &&
-      (Call->hasNoNaNs() ||
-       CannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI))) {
-    IRBuilder<> Builder(&I);
-    IRBuilderBase::FastMathFlagGuard Guard(Builder);
-    Builder.setFastMathFlags(Call->getFastMathFlags());
-
-    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
-    Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
-    I.replaceAllUsesWith(NewSqrt);
-
-    // Explicitly erase the old call because a call with side effects is not
-    // trivially dead.
-    I.eraseFromParent();
-    return true;
-  }
-
-  return false;
-}
-
 // Check if this array of constants represents a cttz table.
 // Iterate over the elements from \p Table by trying to find/match all
 // the numbers from 0 to \p InputBits that should represent cttz results.
@@ -913,12 +869,150 @@
   return true;
 }
 
+/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
+/// pessimistic codegen that has to account for setting errno and can enable
+/// vectorization.
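+///
+/// Illustrative IR sketch (operand names are hypothetical), assuming the
+/// no-NaN precondition checked below holds for a double call:
+///   %v = call nnan double @sqrt(double %x)
+/// becomes:
+///   %v = call nnan double @llvm.sqrt.f64(double %x)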
+static bool foldSqrt(CallInst *Call, TargetTransformInfo &TTI,
+                     TargetLibraryInfo &TLI) {
+  Module *M = Call->getModule();
+
+  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
+  // (because NNAN or the operand arg must not be less than -0.0) and (3) we
+  // would not end up lowering to a libcall anyway (which could change the
+  // value of errno), then:
+  // (1) errno won't be set.
+  // (2) it is safe to convert this to an intrinsic call.
+  Type *Ty = Call->getType();
+  Value *Arg = Call->getArgOperand(0);
+  if (TTI.haveFastSqrt(Ty) &&
+      (Call->hasNoNaNs() ||
+       CannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI))) {
+    IRBuilder<> Builder(Call);
+    IRBuilderBase::FastMathFlagGuard Guard(Builder);
+    Builder.setFastMathFlags(Call->getFastMathFlags());
+
+    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
+    Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
+    Call->replaceAllUsesWith(NewSqrt);
+
+    // Explicitly erase the old call because a call with side effects is not
+    // trivially dead.
+    Call->eraseFromParent();
+    return true;
+  }
+
+  return false;
+}
+
+/// Try to fold strcmp(P, "x") calls.
+static bool foldStrCmp(CallInst *CI, IRBuilderBase &B, DominatorTree &DT,
+                       bool &MadeCFGChange) {
+  Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1);
+
+  // Trivial cases are optimized during InstCombine.
+  if (Str1P == Str2P)
+    return false;
+
+  StringRef Str1, Str2;
+  bool HasStr1 = getConstantStringInfo(Str1P, Str1);
+  bool HasStr2 = getConstantStringInfo(Str2P, Str2);
+
+  bool Match = !HasStr1 && HasStr2 && Str2.size() == 1;
+  if (!Match)
+    return false;
+
+  // For strcmp(P, "x") do the following transformation:
+  //
+  // (before)
+  // dst = strcmp(P, "x")
+  //
+  // (after)
+  // v0 = P[0] - 'x'
+  // [if v0 == 0]
+  // v1 = P[1]
+  // dst = phi(v0, v1)
+  //
+
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+
+  B.SetInsertPoint(CI);
+  BasicBlock *InitialBB = B.GetInsertBlock();
+  Value *Str1FirstCharacterValue =
+      B.CreateIntCast(B.CreateLoad(B.getInt8Ty(), Str1P), B.getInt32Ty(), true);
+  Value *Str2FirstCharacterValue = ConstantInt::get(
+      B.getInt32Ty(), static_cast<signed char>(Str2[0]), true);
+  Value *FirstCharacterSub =
+      B.CreateNSWSub(Str1FirstCharacterValue, Str2FirstCharacterValue);
+  Value *IsFirstCharacterSubZero = B.CreateICmpEQ(
+      FirstCharacterSub, ConstantInt::get(B.getInt32Ty(), 0, true));
+  Instruction *IsFirstCharacterSubZeroBBTerminator = SplitBlockAndInsertIfThen(
+      IsFirstCharacterSubZero, CI, /*Unreachable*/ false,
+      /*BranchWeights*/ nullptr, &DTU);
+
+  B.SetInsertPoint(IsFirstCharacterSubZeroBBTerminator);
+  B.GetInsertBlock()->setName("strcmp_fold_sub_is_zero");
+  BasicBlock *IsFirstCharacterSubZeroBB = B.GetInsertBlock();
+  Value *Str1SecondCharacterValue = B.CreateIntCast(
+      B.CreateLoad(B.getInt8Ty(),
+                   B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), Str1P, 1)),
+      B.getInt32Ty(), true);
+
+  B.SetInsertPoint(CI);
+  B.GetInsertBlock()->setName("strcmp_fold_sub_join");
+
+  PHINode *Result = B.CreatePHI(B.getInt32Ty(), 2);
+  Result->addIncoming(FirstCharacterSub, InitialBB);
+  Result->addIncoming(Str1SecondCharacterValue, IsFirstCharacterSubZeroBB);
+
+  CI->replaceAllUsesWith(Result);
+  CI->eraseFromParent();
+
+  MadeCFGChange = true;
+
+  return true;
+}
+
+static bool foldLibraryCalls(Instruction &I, TargetTransformInfo &TTI,
+                             TargetLibraryInfo &TLI, DominatorTree &DT,
+                             bool &MadeCFGChange) {
+  CallInst *CI = dyn_cast<CallInst>(&I);
+  if (!CI)
+    return false;
+
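+  // Only handle calls that are compatible with the C calling convention;
+  // the libcall prototypes matched below assume it.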
+  bool IsCallingConvC = TargetLibraryInfoImpl::isCallingConvCCompatible(CI);
+  if (!IsCallingConvC)
+    return false;
+
+  Module *M = I.getModule();
+
+  LibFunc Func;
+  Function *Callee = CI->getCalledFunction();
+  if (!Callee)
+    return false;
+  if (!TLI.getLibFunc(*Callee, Func) || !isLibFuncEmittable(M, &TLI, Func))
+    return false;
+
+  IRBuilder<> Builder(I.getParent());
+
+  switch (Func) {
+  case LibFunc_sqrt:
+  case LibFunc_sqrtf:
+  case LibFunc_sqrtl:
+    return foldSqrt(CI, TTI, TLI);
+  case LibFunc_strcmp:
+    return foldStrCmp(CI, Builder, DT, MadeCFGChange);
+  default:
+    break;
+  }
+
+  return false;
+}
+
 /// This is the entry point for folds that could be implemented in regular
 /// InstCombine, but they are separated because they are not expected to
 /// occur frequently and/or have more than a constant-length pattern match.
 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
                                 TargetTransformInfo &TTI,
-                                TargetLibraryInfo &TLI, AliasAnalysis &AA) {
+                                TargetLibraryInfo &TLI, AliasAnalysis &AA,
+                                bool &MadeCFGChange) {
   bool MadeChange = false;
   for (BasicBlock &BB : F) {
     // Ignore unreachable basic blocks.
@@ -943,7 +1037,7 @@
       // NOTE: This function introduces erasing of the instruction `I`, so it
      // needs to be called at the end of this sequence, otherwise we may
      // introduce bugs.
-      MadeChange |= foldSqrt(I, TTI, TLI);
+      MadeChange |= foldLibraryCalls(I, TTI, TLI, DT, MadeCFGChange);
     }
   }
 
@@ -959,12 +1053,12 @@
 /// handled in the callers of this function.
 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
                     TargetLibraryInfo &TLI, DominatorTree &DT,
-                    AliasAnalysis &AA) {
+                    AliasAnalysis &AA, bool &ChangedCFG) {
   bool MadeChange = false;
   const DataLayout &DL = F.getParent()->getDataLayout();
   TruncInstCombine TIC(AC, TLI, DL, DT);
   MadeChange |= TIC.run(F);
-  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA);
+  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, ChangedCFG);
   return MadeChange;
 }
 
@@ -975,12 +1069,21 @@
   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
   auto &AA = AM.getResult<AAManager>(F);
-  if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
+
+  bool MadeCFGChange = false;
+
+  if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
     // No changes, all analyses are preserved.
     return PreservedAnalyses::all();
   }
+
   // Mark all the analyses that instcombine updates as preserved.
   PreservedAnalyses PA;
-  PA.preserveSet<CFGAnalyses>();
+
+  if (MadeCFGChange)
+    PA.preserve<DominatorTreeAnalysis>();
+  else
+    PA.preserveSet<CFGAnalyses>();
+
   return PA;
 }
Index: llvm/test/Transforms/AggressiveInstCombine/strcmp.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/AggressiveInstCombine/strcmp.ll
@@ -0,0 +1,51 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; TODO: Test that ...
+; RUN: opt < %s -passes=aggressive-instcombine -S | FileCheck %s
+
+declare i32 @strcmp(ptr, ptr)
+
+@s1 = constant [2 x i8] c"0\00"
+
+; Fold strcmp(C, "x").
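+; Both tests below should lower the call to a signed subtract of the first
+; character, a conditional load of the second character, and a phi joining
+; the two values.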
+
+define i1 @fold_strcmp_s1_1(ptr %C) {
+; CHECK-LABEL: @fold_strcmp_s1_1(
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr %C, align 1
+; CHECK-NEXT:    [[SEXT1:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = sub nsw i32 [[SEXT1]], 48
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[ADD]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label %strcmp_fold_sub_is_zero, label %strcmp_fold_sub_join
+; CHECK-LABEL: strcmp_fold_sub_is_zero:
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr %C, i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[SEXT2:%.*]] = sext i8 [[TMP3]] to i32
+; CHECK-NEXT:    br label %strcmp_fold_sub_join
+; CHECK-LABEL: strcmp_fold_sub_join:
+; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ [[ADD]], %0 ], [ [[SEXT2]], %strcmp_fold_sub_is_zero ]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[PHI]], 0
+; CHECK-NEXT:    ret i1 [[CMP2]]
+;
+  %call = call i32 @strcmp(ptr %C, ptr noundef @s1)
+  %cmp = icmp eq i32 %call, 0
+  ret i1 %cmp
+}
+
+define i32 @fold_strcmp_s1_2(ptr %C) {
+; CHECK-LABEL: @fold_strcmp_s1_2(
+; CHECK-NEXT:    [[TMP1:%.*]] = load i8, ptr %C, align 1
+; CHECK-NEXT:    [[SEXT1:%.*]] = sext i8 [[TMP1]] to i32
+; CHECK-NEXT:    [[ADD:%.*]] = sub nsw i32 [[SEXT1]], 48
+; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[ADD]], 0
+; CHECK-NEXT:    br i1 [[CMP1]], label %strcmp_fold_sub_is_zero, label %strcmp_fold_sub_join
+; CHECK-LABEL: strcmp_fold_sub_is_zero:
+; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds i8, ptr %C, i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = load i8, ptr [[TMP2]], align 1
+; CHECK-NEXT:    [[SEXT2:%.*]] = sext i8 [[TMP3]] to i32
+; CHECK-NEXT:    br label %strcmp_fold_sub_join
+; CHECK-LABEL: strcmp_fold_sub_join:
+; CHECK-NEXT:    [[PHI:%.*]] = phi i32 [ [[ADD]], %0 ], [ [[SEXT2]], %strcmp_fold_sub_is_zero ]
+; CHECK-NEXT:    ret i32 [[PHI]]
+;
+  %call = call i32 @strcmp(ptr %C, ptr noundef @s1)
+  ret i32 %call
+}
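+
+; Illustrative negative test (a sketch, not autogenerated): the fold is
+; guarded by Str2.size() == 1, so a two-character constant string should be
+; left as a plain strcmp call.
+@s2 = constant [3 x i8] c"01\00"
+
+define i32 @no_fold_strcmp_s2(ptr %C) {
+; CHECK-LABEL: @no_fold_strcmp_s2(
+; CHECK-NEXT:    [[CALL:%.*]] = call i32 @strcmp(ptr %C, ptr noundef @s2)
+; CHECK-NEXT:    ret i32 [[CALL]]
+;
+  %call = call i32 @strcmp(ptr %C, ptr noundef @s2)
+  ret i32 %call
+}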