Index: llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -14,6 +14,7 @@
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
@@ -1546,6 +1547,24 @@
   return nullptr;
 }
 
+/// Replace \p Cmp with a constant when an llvm.assume associated with \p Op
+/// is valid at \p Cmp and its condition implies that \p Cmp is true or false.
+Instruction *InstCombiner::foldICmpWithDominatingAssume(ICmpInst &Cmp,
+                                                        Value *Op) {
+  for (auto &AssumeVH : AC.assumptionsFor(Op)) {
+    if (!AssumeVH)
+      continue;
+
+    CallInst *Assume = cast<CallInst>(AssumeVH);
+    if (Optional<bool> Imp =
+            isImpliedCondition(Assume->getArgOperand(0), &Cmp, DL))
+      if (isValidAssumeForContext(Assume, &Cmp, &DT))
+        return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp));
+  }
+
+  return nullptr;
+}
+
 /// Fold icmp (trunc X, Y), C.
 Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
                                                  TruncInst *Trunc,
@@ -5711,6 +5730,11 @@
   if (Instruction *Res = foldVectorCmp(I, Builder))
     return Res;
 
+  if (Instruction *Res = foldICmpWithDominatingAssume(I, Op0))
+    return Res;
+  if (Instruction *Res = foldICmpWithDominatingAssume(I, Op1))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -931,6 +931,7 @@
 
   Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp);
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
+  Instruction *foldICmpWithDominatingAssume(ICmpInst &Cmp, Value *Op);
   Instruction *foldICmpWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
Index: llvm/test/Transforms/InstCombine/assume_icmp.ll
===================================================================
--- llvm/test/Transforms/InstCombine/assume_icmp.ll
+++ llvm/test/Transforms/InstCombine/assume_icmp.ll
@@ -8,22 +8,14 @@
 ; CHECK-LABEL: @basic_ugt(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i32 [[X:%.*]], [[Y:%.*]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP1]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP2]])
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp uge i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP3]])
-; CHECK-NEXT:    [[CMP4:%.*]] = icmp ult i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP4]])
-; CHECK-NEXT:    [[CMP5:%.*]] = icmp ule i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP5]])
-; CHECK-NEXT:    [[CMP6:%.*]] = icmp ugt i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP6]])
-; CHECK-NEXT:    [[CMP7:%.*]] = icmp uge i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP7]])
-; CHECK-NEXT:    [[CMP8:%.*]] = icmp ult i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP8]])
-; CHECK-NEXT:    [[CMP9:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP9]])
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    ret void
 ;
   %cmp1 = icmp ugt i32 %x, %y
@@ -56,20 +48,16 @@
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP1]])
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[X]], [[Y]]
 ; CHECK-NEXT:    call void @use(i1 [[CMP2]])
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp uge i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP3]])
-; CHECK-NEXT:    [[CMP4:%.*]] = icmp ult i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP4]])
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 false)
 ; CHECK-NEXT:    [[CMP5:%.*]] = icmp ule i32 [[X]], [[Y]]
 ; CHECK-NEXT:    call void @use(i1 [[CMP5]])
-; CHECK-NEXT:    [[CMP6:%.*]] = icmp ugt i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP6]])
+; CHECK-NEXT:    call void @use(i1 false)
 ; CHECK-NEXT:    [[CMP7:%.*]] = icmp uge i32 [[Y]], [[X]]
 ; CHECK-NEXT:    call void @use(i1 [[CMP7]])
 ; CHECK-NEXT:    [[CMP8:%.*]] = icmp ult i32 [[Y]], [[X]]
 ; CHECK-NEXT:    call void @use(i1 [[CMP8]])
-; CHECK-NEXT:    [[CMP9:%.*]] = icmp ule i32 [[Y]], [[X]]
-; CHECK-NEXT:    call void @use(i1 [[CMP9]])
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    ret void
 ;
   %cmp1 = icmp uge i32 %x, %y
@@ -102,14 +90,10 @@
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp ugt i32 [[Z:%.*]], [[Y]]
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP1]])
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    [[CMP3:%.*]] = icmp ugt i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP3]])
-; CHECK-NEXT:    [[CMP4:%.*]] = icmp uge i32 [[X]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP4]])
-; CHECK-NEXT:    [[CMP5:%.*]] = icmp ugt i32 [[Z]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP5]])
-; CHECK-NEXT:    [[CMP6:%.*]] = icmp uge i32 [[Z]], [[Y]]
-; CHECK-NEXT:    call void @use(i1 [[CMP6]])
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
+; CHECK-NEXT:    call void @use(i1 true)
 ; CHECK-NEXT:    ret void
 ;
   %cmp1 = icmp ugt i32 %x, %y