diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -2264,6 +2264,36 @@ Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy)); } + // Simplify nonnull arguments if possible. + // ptr = select _, null, p + // f(nonnull ptr) + // => + // f(nonnull p) + for (unsigned ArgNo = 0, N = Call.getNumArgOperands(); ArgNo != N; ++ArgNo) { + if (!Call.paramHasAttr(ArgNo, Attribute::NonNull)) + continue; + + Value *V = Call.getArgOperand(ArgNo); + Value *Cond, *TrueVal, *FalseVal; + if (!match(V, m_Select(m_Value(Cond), m_Value(TrueVal), m_Value(FalseVal)))) + continue; + + bool IsTrueNull = isa(TrueVal); + if (IsTrueNull || isa(FalseVal)) { + replaceUse(Call.getArgOperandUse(ArgNo), IsTrueNull ? FalseVal : TrueVal); + + if (Call.paramHasAttr(ArgNo, Attribute::NoUndef)) { + // ptr = select cond, null, p + // f(nonnull noundef ptr) + // => + // llvm.assume(!cond) + // f(nonnull noundef p) + Builder.CreateAssumption(IsTrueNull ? Builder.CreateNot(Cond) : Cond); + } + Changed = true; + } + } + if (isAllocLikeFn(&Call, &TLI)) return visitAllocSite(Call); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -2844,6 +2844,31 @@ return nullptr; Value *ResultOp = RI.getOperand(0); + + auto Attrs = RI.getFunction()->getAttributes(); + if (Attrs.hasAttribute(AttributeList::ReturnIndex, Attribute::NonNull)) { + // ptr = select _, null, p + // ret nonnull ptr + // => + // ret nonnull p + Value *Cond, *TrueVal, *FalseVal; + if (match(ResultOp, + m_Select(m_Value(Cond), m_Value(TrueVal), m_Value(FalseVal))) && + (isa(TrueVal) || + isa(FalseVal))) { + bool IsTrueNull = isa(TrueVal); + if (Attrs.hasAttribute(AttributeList::ReturnIndex, Attribute::NoUndef)) { + // ptr = select cond, null, p + // ret noundef nonnull ptr + // => + // llvm.assume(!cond) + // ret noundef nonnull p + Builder.CreateAssumption(IsTrueNull ? Builder.CreateNot(Cond) : Cond); + } + return replaceOperand(RI, 0, IsTrueNull ? FalseVal : TrueVal); + } + } + Type *VTy = ResultOp->getType(); if (!VTy->isIntegerTy() || isa(ResultOp)) return nullptr; diff --git a/llvm/test/Transforms/InstCombine/nonnull-select.ll b/llvm/test/Transforms/InstCombine/nonnull-select.ll --- a/llvm/test/Transforms/InstCombine/nonnull-select.ll +++ b/llvm/test/Transforms/InstCombine/nonnull-select.ll @@ -5,11 +5,8 @@ define nonnull i32* @pr48975(i32** %.0) { ; CHECK-LABEL: @pr48975( -; CHECK-NEXT: [[DOT1:%.*]] = load i32*, i32** [[DOT0:%.*]], align 8 -; CHECK-NEXT: [[DOT2:%.*]] = icmp eq i32* [[DOT1]], null -; CHECK-NEXT: [[DOT3:%.*]] = bitcast i32** [[DOT0]] to i32* -; CHECK-NEXT: [[DOT4:%.*]] = select i1 [[DOT2]], i32* null, i32* [[DOT3]] -; CHECK-NEXT: ret i32* [[DOT4]] +; CHECK-NEXT: [[DOT3:%.*]] = bitcast i32** [[DOT0:%.*]] to i32* +; CHECK-NEXT: ret i32* [[DOT3]] ; %.1 = load i32*, i32** %.0, align 8 %.2 = icmp eq i32* %.1, null @@ -20,8 +17,7 @@ define nonnull i32* @nonnull_ret(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_ret( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* [[P:%.*]], i32* null -; CHECK-NEXT: ret i32* [[RES]] +; CHECK-NEXT: ret i32* [[P:%.*]] ; %res = select i1 %cond, i32* %p, i32* null ret i32* %res @@ -29,8 +25,7 @@ define nonnull i32* @nonnull_ret2(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_ret2( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* null, i32* [[P:%.*]] -; CHECK-NEXT: ret i32* [[RES]] +; CHECK-NEXT: ret i32* [[P:%.*]] ; %res = select i1 %cond, i32* null, i32* %p ret i32* %res @@ -38,8 +33,8 @@ define nonnull noundef i32* @nonnull_noundef_ret(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_noundef_ret( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* [[P:%.*]], i32* null -; CHECK-NEXT: ret i32* [[RES]] +; CHECK-NEXT: call void @llvm.assume(i1 [[COND:%.*]]) +; CHECK-NEXT: ret i32* [[P:%.*]] ; %res = select i1 %cond, i32* %p, i32* null ret i32* %res @@ -47,8 +42,9 @@ define nonnull noundef i32* @nonnull_noundef_ret2(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_noundef_ret2( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* null, i32* [[P:%.*]] -; CHECK-NEXT: ret i32* [[RES]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[COND:%.*]], true +; CHECK-NEXT: call void @llvm.assume(i1 [[TMP1]]) +; CHECK-NEXT: ret i32* [[P:%.*]] ; %res = select i1 %cond, i32* null, i32* %p ret i32* %res @@ -57,8 +53,7 @@ define void @nonnull_call(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_call( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* [[P:%.*]], i32* null -; CHECK-NEXT: call void @f(i32* nonnull [[RES]]) +; CHECK-NEXT: call void @f(i32* nonnull [[P:%.*]]) ; CHECK-NEXT: ret void ; %res = select i1 %cond, i32* %p, i32* null @@ -68,8 +63,7 @@ define void @nonnull_call2(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_call2( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* null, i32* [[P:%.*]] -; CHECK-NEXT: call void @f(i32* nonnull [[RES]]) +; CHECK-NEXT: call void @f(i32* nonnull [[P:%.*]]) ; CHECK-NEXT: ret void ; %res = select i1 %cond, i32* null, i32* %p @@ -79,8 +73,8 @@ define void @nonnull_noundef_call(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_noundef_call( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* [[P:%.*]], i32* null -; CHECK-NEXT: call void @f(i32* noundef nonnull [[RES]]) +; CHECK-NEXT: call void @llvm.assume(i1 [[COND:%.*]]) +; CHECK-NEXT: call void @f(i32* noundef nonnull [[P:%.*]]) ; CHECK-NEXT: ret void ; %res = select i1 %cond, i32* %p, i32* null @@ -90,8 +84,9 @@ define void @nonnull_noundef_call2(i1 %cond, i32* %p) { ; CHECK-LABEL: @nonnull_noundef_call2( -; CHECK-NEXT: [[RES:%.*]] = select i1 [[COND:%.*]], i32* null, i32* [[P:%.*]] -; CHECK-NEXT: call void @f(i32* noundef nonnull [[RES]]) +; CHECK-NEXT: [[TMP1:%.*]] = xor i1 [[COND:%.*]], true +; CHECK-NEXT: call void @llvm.assume(i1 [[TMP1]]) +; CHECK-NEXT: call void @f(i32* noundef nonnull [[P:%.*]]) ; CHECK-NEXT: ret void ; %res = select i1 %cond, i32* null, i32* %p