diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -545,6 +545,11 @@
     SmallVector<const SCEV *, 3> Ops = {Op0, Op1, Op2};
     return getMulExpr(Ops, Flags, Depth);
   }
+  /// Return a SCEV for \p LHS shifted left by the constant amount \p SA,
+  /// modeled as a multiply by (1 << SA). \p ShlFlags are the no-wrap flags
+  /// of the shift. Returns nullptr if \p SA is not less than the bit width.
+  const SCEV *getShlByConstantExpr(const SCEV *LHS, ConstantInt *SA,
+                                   SCEV::NoWrapFlags ShlFlags);
   const SCEV *getUDivExpr(const SCEV *LHS, const SCEV *RHS);
   const SCEV *getUDivExactExpr(const SCEV *LHS, const SCEV *RHS);
   const SCEV *getURemExpr(const SCEV *LHS, const SCEV *RHS);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2913,6 +2913,38 @@
   return getOrCreateMulExpr(Ops, Flags);
 }
 
+/// Return a SCEV for \p LHS shifted left by the constant amount \p SA,
+/// modeled as a multiply by (1 << SA). \p ShlFlags are the no-wrap flags of
+/// the shift being modeled; only the flags that remain valid for the
+/// multiply are transferred. Returns nullptr if \p SA >= the bit width.
+const SCEV *ScalarEvolution::getShlByConstantExpr(const SCEV *LHS,
+                                                  ConstantInt *SA,
+                                                  SCEV::NoWrapFlags ShlFlags) {
+  uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
+
+  // If the shift count is not less than the bitwidth, the result of
+  // the shift is undefined. Don't try to analyze it, because the
+  // resolution chosen here may differ from the resolution chosen in
+  // other parts of the compiler.
+  if (SA->getValue().uge(BitWidth))
+    return nullptr;
+
+  // We can safely preserve the nuw flag in all cases. It's also safe to turn a
+  // nuw nsw shl into a nuw nsw mul. However, nsw in isolation requires special
+  // handling. It can be preserved as long as we're not left shifting by
+  // bitwidth - 1.
+  auto Flags = SCEV::FlagAnyWrap;
+  if ((ShlFlags & SCEV::FlagNSW) &&
+      ((ShlFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
+    Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
+  if (ShlFlags & SCEV::FlagNUW)
+    Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
+
+  Constant *X = ConstantInt::get(
+      getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+  return getMulExpr(LHS, getSCEV(X), Flags);
+}
+
 /// Represents an unsigned remainder expression based on unsigned division.
 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
@@ -6179,38 +6211,41 @@
                   return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
                                            UTy);
               }
+
+        // Model xor(shl(x, C), (-1 << C)) as shl(xor(x, -1), C).
+        // This is a variant of the check for xor with -1, and it handles
+        // the case where instcombine has folded the shift over the xor.
+        if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
+          if (ConstantInt *SA = dyn_cast<ConstantInt>(LBO->getOperand(1))) {
+            uint32_t BitWidth =
+                cast<IntegerType>(SA->getType())->getBitWidth();
+            APInt MinusOne = APInt::getAllOnesValue(BitWidth);
+            // Looking for a left shift. Shift count must be less than
+            // bitwidth. And (-1 << SA) should be equal to CI.
+            if (LBO->getOpcode() == Instruction::Shl &&
+                SA->getValue().ult(BitWidth) &&
+                CI->getValue() == MinusOne.shl(SA->getValue())) {
+              // Only nsw carries over from the shift. nsw on shl(x, C) means
+              // the bits shifted out all equal the result's sign bit, which
+              // still holds after inverting every bit of x. nuw does not: it
+              // means the top C bits of x are zero, so those bits of ~x are
+              // ones and shl(~x, C) wraps in the unsigned sense for C > 0.
+              auto Flags =
+                  clearFlags(getNoWrapFlagsFromUB(LBO), SCEV::FlagNUW);
+              const SCEV *NotXSCEV = getNotSCEV(getSCEV(LBO->getOperand(0)));
+              if (const SCEV *S = getShlByConstantExpr(NotXSCEV, SA, Flags))
+                return S;
+            }
+          }
       }
       break;
 
     case Instruction::Shl:
       // Turn shift left of a constant amount into a multiply.
       if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
-        uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
-
-        // If the shift count is not less than the bitwidth, the result of
-        // the shift is undefined. Don't try to analyze it, because the
-        // resolution chosen here may differ from the resolution chosen in
-        // other parts of the compiler.
-        if (SA->getValue().uge(BitWidth))
-          break;
-
-        // We can safely preserve the nuw flag in all cases. It's also safe to
-        // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
-        // requires special handling. It can be preserved as long as we're not
-        // left shifting by bitwidth - 1.
-        auto Flags = SCEV::FlagAnyWrap;
-        if (BO->Op) {
-          auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
-          if ((MulFlags & SCEV::FlagNSW) &&
-              ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
-            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
-          if (MulFlags & SCEV::FlagNUW)
-            Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
-        }
-
-        Constant *X = ConstantInt::get(
-            getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-        return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
+        auto Flags = BO->Op ? getNoWrapFlagsFromUB(BO->Op) : SCEV::FlagAnyWrap;
+        if (const SCEV *S = getShlByConstantExpr(getSCEV(BO->LHS), SA, Flags))
+          return S;
       }
       break;
 
diff --git a/llvm/test/Analysis/ScalarEvolution/xor-shl.ll b/llvm/test/Analysis/ScalarEvolution/xor-shl.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Analysis/ScalarEvolution/xor-shl.ll
@@ -0,0 +1,75 @@
+; RUN: opt < %s -scalar-evolution -analyze -enable-new-pm=0 | FileCheck %s
+; RUN: opt < %s "-passes=print<scalar-evolution>" -disable-output 2>&1 | FileCheck %s
+
+; CHECK-LABEL: @test1a
+; CHECK: %y =
+; CHECK-NEXT: --> (-2 + (-2 * %x)) U: [0,-1) S: [-2147483648,2147483647)
+define i32 @test1a(i32 %x) {
+  %n = shl i32 %x, 1
+  %y = xor i32 %n, -2
+  ret i32 %y
+}
+
+; CHECK-LABEL: @test1b
+; CHECK: %y =
+; CHECK-NEXT: --> (-2 + (-2 * %x)) U: [0,-1) S: [-2147483648,2147483647)
+define i32 @test1b(i32 %x) {
+  %n = xor i32 %x, -1
+  %y = shl i32 %n, 1
+  ret i32 %y
+}
+
+; CHECK-LABEL: @test2a
+; CHECK: %y =
+; CHECK-NEXT: --> (-16 + (-16 * %x)) U: [0,-15) S: [-2147483648,2147483633)
+define i32 @test2a(i32 %x) {
+  %n = shl i32 %x, 4
+  %y = xor i32 %n, -16
+  ret i32 %y
+}
+
+; CHECK-LABEL: @test2b
+; CHECK: %y =
+; CHECK-NEXT: --> (-16 + (-16 * %x)) U: [0,-15) S: [-2147483648,2147483633)
+define i32 @test2b(i32 %x) {
+  %n = xor i32 %x, -1
+  %y = shl i32 %n, 4
+  ret i32 %y
+}
+
+; CHECK-LABEL: @test3a
+; CHECK: %y =
+; CHECK-NEXT: --> (-8 + (-8 * %x)) U: [0,-7) S: [-8,1)
+define i4 @test3a(i4 %x) {
+  %n = shl i4 %x, 3
+  %y = xor i4 %n, -8
+  ret i4 %y
+}
+
+; CHECK-LABEL: @test3b
+; CHECK: %y =
+; CHECK-NEXT: --> (-8 + (-8 * %x)) U: [0,-7) S: [-8,1)
+define i4 @test3b(i4 %x) {
+  %n = xor i4 %x, -1
+  %y = shl i4 %n, 3
+  ret i4 %y
+}
+
+; A shift by the full bit width is not modeled as a multiply.
+; CHECK-LABEL: @test4a
+; CHECK: %y =
+; CHECK-NEXT: --> %y U: [0,-7) S: [-8,1)
+define i4 @test4a(i4 %x) {
+  %n = shl i4 %x, 4
+  %y = xor i4 %n, -16
+  ret i4 %y
+}
+
+; CHECK-LABEL: @test4b
+; CHECK: %y =
+; CHECK-NEXT: --> %y U: [0,-7) S: [-8,1)
+define i4 @test4b(i4 %x) {
+  %n = xor i4 %x, -1
+  %y = shl i4 %n, 4
+  ret i4 %y
+}