diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3205,7 +3205,9 @@
 }
 
 /// Handle icmp with constant (but not simple integer constant) RHS.
-Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
+Instruction *
+InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I,
+                                                 const SimplifyQuery &Q) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Constant *RHSC = dyn_cast<Constant>(Op1);
   Instruction *LHSI = dyn_cast<Instruction>(Op0);
@@ -3235,14 +3237,26 @@
     // constant folded and the select turned into a bitwise or.
     Value *Op1 = nullptr, *Op2 = nullptr;
     ConstantInt *CI = nullptr;
-    if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
-      Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+
+    auto SimplifyOp = [&](Value *V) {
+      Value *Op = nullptr;
+      if (Constant *C = dyn_cast<Constant>(V)) {
+        Op = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+      } else if (RHSC->isNullValue()) {
+        // If null is being compared, check if V is a nonnull argument.
+        Argument *A = dyn_cast<Argument>(V);
+        if (A && A->hasNonNullAttr(true))
+          Op = SimplifyICmpInst(I.getPredicate(), A, RHSC, Q);
+      }
+      return Op;
+    };
+    Op1 = SimplifyOp(LHSI->getOperand(1));
+    if (Op1)
       CI = dyn_cast<ConstantInt>(Op1);
-    }
-    if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
-      Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+
+    Op2 = SimplifyOp(LHSI->getOperand(2));
+    if (Op2)
       CI = dyn_cast<ConstantInt>(Op2);
-    }
 
     // We only want to perform this transformation if it will not lead to
     // additional code. This is true if either both sides of the select
@@ -5640,7 +5654,7 @@
   if (Instruction *New = foldSignBitTest(I))
     return New;
 
-  if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
+  if (Instruction *Res = foldICmpInstWithConstantNotInt(I, Q))
     return Res;
 
   // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -648,7 +648,8 @@
   Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp);
   Instruction *foldICmpWithConstant(ICmpInst &Cmp);
   Instruction *foldICmpInstWithConstant(ICmpInst &Cmp);
-  Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp);
+  Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp,
+                                              const SimplifyQuery &SQ);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
   Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
diff --git a/llvm/test/Transforms/InstCombine/assume-icmp-null-select.ll b/llvm/test/Transforms/InstCombine/assume-icmp-null-select.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/assume-icmp-null-select.ll
@@ -0,0 +1,45 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -instcombine -S | FileCheck %s
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+define align 8 dereferenceable(24) i8* @example(i8* readonly align 8 dereferenceable(24) %x) {
+; CHECK-LABEL: @example(
+; CHECK-NEXT:    [[DOT0:%.*]] = bitcast i8* [[X:%.*]] to {}**
+; CHECK-NEXT:    [[TMP1:%.*]] = load {}*, {}** [[DOT0]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne {}* [[TMP1]], null
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP2]])
+; CHECK-NEXT:    ret i8* [[X]]
+;
+  %.0 = bitcast i8* %x to {}**
+  %1 = load {}*, {}** %.0, align 8
+  %2 = icmp eq {}* %1, null
+  %3 = getelementptr inbounds i8, i8* %x, i64 0
+  %.0.i = select i1 %2, i8* null, i8* %3
+  %4 = icmp ne i8* %.0.i, null
+  call void @llvm.assume(i1 %4)
+  ret i8* %.0.i
+}
+
+; TODO: this should be optimized to 'ret %x' as well.
+define align 8 dereferenceable(24) i8* @example_negative(i8* readonly align 8 %x) {
+; CHECK-LABEL: @example_negative(
+; CHECK-NEXT:    [[DOT0:%.*]] = bitcast i8* [[X:%.*]] to {}**
+; CHECK-NEXT:    [[TMP1:%.*]] = load {}*, {}** [[DOT0]], align 8
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq {}* [[TMP1]], null
+; CHECK-NEXT:    [[DOT0_I:%.*]] = select i1 [[TMP2]], i8* null, i8* [[X]]
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i8* [[DOT0_I]], null
+; CHECK-NEXT:    call void @llvm.assume(i1 [[TMP3]])
+; CHECK-NEXT:    ret i8* [[DOT0_I]]
+;
+  %.0 = bitcast i8* %x to {}**
+  %1 = load {}*, {}** %.0, align 8
+  %2 = icmp eq {}* %1, null
+  %3 = getelementptr inbounds i8, i8* %x, i64 0
+  %.0.i = select i1 %2, i8* null, i8* %3
+  %4 = icmp ne i8* %.0.i, null
+  call void @llvm.assume(i1 %4)
+  ret i8* %.0.i
+}
+
+declare void @llvm.assume(i1)
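---
Note (illustrative, not part of the patch): the new SimplifyOp lambda extends the
existing select fold so that a comparison against null can also fold when a select
arm is a function argument known to be nonnull, not only when it is a constant.
A minimal standalone sketch of the fold is below; @demo, %p, and %c are
hypothetical names chosen for illustration, and the %p arm is made foldable
here via an explicit nonnull attribute:

; Both arms of the compared select now simplify: 'icmp eq null, null' folds
; to true, and 'icmp eq %p, null' folds to false because %p carries nonnull,
; so -instcombine can reduce the whole body to 'ret i1 %c'.
define i1 @demo(i8* nonnull %p, i1 %c) {
  %s = select i1 %c, i8* null, i8* %p
  %r = icmp eq i8* %s, null
  ret i1 %r
}

Before this patch, only the constant-null arm simplified, so the select-of-icmps
rewrite did not fire and the icmp stayed in place, as in the @example test above
when the assume-derived fact was unavailable.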