diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -186,7 +186,7 @@ } static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold); + bool AssumeNonZero, bool DoFold); Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -486,15 +486,19 @@ // (shl Op1, Log2(Op0)) // if Log2(Op1) folds away -> // (shl Op0, Log2(Op1)) - if (takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*DoFold*/ true); + if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); return Shl; } - if (takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*DoFold*/ true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); // We can only propegate nuw flag. Shl->setHasNoUnsignedWrap(HasNUW); @@ -1181,7 +1185,7 @@ // actual instructions, otherwise return a non-null dummy value. Return nullptr // on failure. static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold) { + bool AssumeNonZero, bool DoFold) { auto IfFold = [DoFold](function_ref Fn) { if (!DoFold) return reinterpret_cast(-1); @@ -1207,14 +1211,22 @@ // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(X << Y) -> log2(X) + Y // FIXME: Require one use unless X is 1? - if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) - return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) { + bool Usable = AssumeNonZero; + if (!AssumeNonZero) { + BinaryOperator *BO = cast(Op); + // nuw will be set if the `shl` is trivially non-zero. + Usable = BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap(); + } + if (Usable) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: missed optimization: if one of the hands of select is/contains @@ -1222,8 +1234,10 @@ // FIXME: can both hands contain undef? // FIXME: Require one use? if (SelectInst *SI = dyn_cast(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1232,11 +1246,13 @@ // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) auto *MinMax = dyn_cast(Op); if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { - return Builder.CreateBinaryIntrinsic( - MinMax->getIntrinsicID(), LogX, LogY); + return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, + LogY); }); return nullptr; @@ -1357,8 +1373,10 @@ } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. - if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, + /*AssumeNonZero*/ true, /*DoFold*/ true); return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); } diff --git a/llvm/test/Transforms/InstCombine/mul-pow2.ll b/llvm/test/Transforms/InstCombine/mul-pow2.ll --- a/llvm/test/Transforms/InstCombine/mul-pow2.ll +++ b/llvm/test/Transforms/InstCombine/mul-pow2.ll @@ -102,3 +102,37 @@ %r = mul <2 x i8> %x, %s ret <2 x i8> %r } + + +define i8 @shl_add_log_may_cause_poison_pr62175_fail(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_fail( +; CHECK-NEXT: [[SHL:%.*]] = shl i8 4, [[X:%.*]] +; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[SHL]], [[Y:%.*]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +} + +define i8 @shl_add_log_may_cause_poison_pr62175_with_nuw(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nuw( +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[X:%.*]], 2 +; CHECK-NEXT: [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl nuw i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +} + +define i8 @shl_add_log_may_cause_poison_pr62175_with_nsw(i8 %x, i8 %y) { +; CHECK-LABEL: @shl_add_log_may_cause_poison_pr62175_with_nsw( +; CHECK-NEXT: [[TMP1:%.*]] = add i8 [[X:%.*]], 2 +; CHECK-NEXT: [[MUL:%.*]] = shl i8 [[Y:%.*]], [[TMP1]] +; CHECK-NEXT: ret i8 [[MUL]] +; + %shl = shl nsw i8 4, %x + %mul = mul i8 %y, %shl + ret i8 %mul +}