diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -72,10 +72,12 @@ // There are many variants to this pattern: // a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt // b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt +// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt // All these patterns can be simplified to just: // x << ShiftShAmt // iff: // a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x) +// c) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt) static Instruction * dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, const SimplifyQuery &SQ) { @@ -89,24 +91,38 @@ auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); // (~(-1 << maskNbits)) auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); + // (-1 >> MaskShAmt) + auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt)); Value *X; - if (!match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) - return nullptr; - - // Can we simplify (MaskShAmt+ShiftShAmt) ? - Value *SumOfShAmts = - SimplifyBinOp(Instruction::BinaryOps::Add, MaskShAmt, ShiftShAmt, - SQ.getWithInstruction(OuterShift)); - if (!SumOfShAmts) - return nullptr; // Did not simplify. - // Is the total shift amount *not* smaller than the bit width? - // FIXME: could also rely on ConstantRange. - unsigned BitWidth = X->getType()->getScalarSizeInBits(); - if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE, - APInt(BitWidth, BitWidth)))) - return nullptr; - // All good, we can do this fold. + if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) { + // Can we simplify (MaskShAmt+ShiftShAmt) ? + Value *SumOfShAmts = + SimplifyBinOp(Instruction::BinaryOps::Add, MaskShAmt, ShiftShAmt, + SQ.getWithInstruction(OuterShift)); + if (!SumOfShAmts) + return nullptr; // Did not simplify. + // Is the total shift amount *not* smaller than the bit width? + // FIXME: could also rely on ConstantRange. + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE, + APInt(BitWidth, BitWidth)))) + return nullptr; + // All good, we can do this fold. + } else if (match(Masked, m_c_And(MaskC, m_Value(X)))) { + // Can we simplify (ShiftShAmt-MaskShAmt) ? + Value *ShAmtsDiff = + SimplifyBinOp(Instruction::BinaryOps::Sub, ShiftShAmt, MaskShAmt, + SQ.getWithInstruction(OuterShift)); + if (!ShAmtsDiff) + return nullptr; // Did not simplify. + // Is the difference non-negative? (is ShiftShAmt u>= MaskShAmt ?) + // FIXME: could also rely on ConstantRange. + if (!match(ShAmtsDiff, m_NonNegative())) + return nullptr; + // All good, we can do this fold. + } else + return nullptr; // Don't know anything about this pattern. // No 'NUW'/'NSW'! return BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt); diff --git a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll --- a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll +++ b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll @@ -21,7 +21,7 @@ ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]] ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) -; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]] +; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]] ; CHECK-NEXT: ret i32 [[T2]] ; %t0 = lshr i32 -1, %nbits @@ -40,7 +40,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) -; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[T2]] +; CHECK-NEXT: [[T3:%.*]] = shl i32 [[X]], [[T2]] ; CHECK-NEXT: ret i32 [[T3]] ; %t0 = lshr i32 -1, %nbits @@ -65,7 +65,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) -; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]] +; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]] ; CHECK-NEXT: ret <3 x i32> [[T3]] ; %t0 = lshr <3 x i32> , %nbits @@ -86,7 +86,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) -; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]] +; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]] ; CHECK-NEXT: ret <3 x i32> [[T3]] ; %t0 = lshr <3 x i32> , %nbits @@ -107,7 +107,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) -; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]] +; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]] ; CHECK-NEXT: ret <3 x i32> [[T3]] ; %t0 = lshr <3 x i32> , %nbits @@ -131,7 +131,7 @@ ; CHECK-NEXT: [[T1:%.*]] = and i32 [[X]], [[T0]] ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) -; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]] +; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]] ; CHECK-NEXT: ret i32 [[T2]] ; %x = call i32 @gen32() @@ -151,7 +151,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) -; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T2]], [[NBITS0]] +; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[NBITS0]] ; CHECK-NEXT: ret i32 [[T3]] ; %t0 = lshr i32 -1, %nbits0 @@ -192,7 +192,7 @@ ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]] ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) -; CHECK-NEXT: [[T2:%.*]] = shl nuw i32 [[T1]], [[NBITS]] +; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]] ; CHECK-NEXT: ret i32 [[T2]] ; %t0 = lshr i32 -1, %nbits @@ -209,7 +209,7 @@ ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]] ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) -; CHECK-NEXT: [[T2:%.*]] = shl nsw i32 [[T1]], [[NBITS]] +; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]] ; CHECK-NEXT: ret i32 [[T2]] ; %t0 = lshr i32 -1, %nbits @@ -226,7 +226,7 @@ ; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]] ; CHECK-NEXT: call void @use32(i32 [[T0]]) ; CHECK-NEXT: call void @use32(i32 [[T1]]) -; CHECK-NEXT: [[T2:%.*]] = shl nuw nsw i32 [[T1]], [[NBITS]] +; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]] ; CHECK-NEXT: ret i32 [[T2]] ; %t0 = lshr i32 -1, %nbits