diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -65,6 +65,53 @@ return NewShift; } +// If we have some pattern that leaves only some low bits set, and then performs +// left-shift of those bits, and none of the bits that remain after the final +// shift are modified by the mask, then the mask is redundant and we can omit it. +// Returns the replacement instruction, or nullptr if the fold does not apply. +// There are many variants to this pattern (only variant a is handled so far): +// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt +// All these patterns can be simplified to just: +// x << ShiftShAmt +// iff (i.e. when every bit the mask clears is shifted out anyway): +// a) (MaskShAmt+ShiftShAmt) u>= bitwidth(x) +static Instruction * +dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, + const SimplifyQuery &SQ) { + assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl && + "The input must be 'shl'!"); + + Value *Masked = OuterShift->getOperand(0); + Value *ShiftShAmt = OuterShift->getOperand(1); + + Value *MaskShAmt; + + // ((1 << MaskShAmt) - 1): a mask of the low MaskShAmt bits. + auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); + + Value *X; + if (!match(Masked, m_c_And(MaskA, m_Value(X)))) + return nullptr; + + // Can we simplify (MaskShAmt+ShiftShAmt) ? If not, bail out. + Value *SumOfShAmts = + SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, + SQ.getWithInstruction(OuterShift)); + if (!SumOfShAmts) + return nullptr; // Did not simplify. + // Is the total shift amount *not* smaller than the bit width? + // FIXME: could also rely on ConstantRange, not just a constant sum. + unsigned BitWidth = X->getType()->getScalarSizeInBits(); + if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE, + APInt(BitWidth, BitWidth)))) + return nullptr; + // All good, we can do this fold. + + // No 'NUW'/'NSW'! The new shl does not inherit OuterShift's flags: + // we no longer know that we won't shift-out non-0 bits.
+ return BinaryOperator::Create(OuterShift->getOpcode(), X, ShiftShAmt); +} + Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -629,6 +676,9 @@ if (Instruction *V = commonShiftTransforms(I)) return V; + if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, SQ)) + return V; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); diff --git a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-a.ll b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-a.ll --- a/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-a.ll +++ b/llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-a.ll @@ -25,7 +25,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]] +; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %t0 = shl i32 1, %nbits @@ -50,7 +50,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]] +; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %t0 = shl i32 1, %nbits @@ -77,7 +77,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) ; CHECK-NEXT: call void @use32(i32 [[T4]]) -; CHECK-NEXT: [[T5:%.*]] = shl i32 [[T3]], [[T4]] +; CHECK-NEXT: [[T5:%.*]] = shl i32 [[X]], [[T4]] ; CHECK-NEXT: ret i32 [[T5]] ; %t0 = add i32 %nbits, 1 @@ -109,7 +109,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]]) -; CHECK-NEXT: [[T5:%.*]] = shl <3
x i32> [[T3]], [[T4]] +; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]] ; CHECK-NEXT: ret <3 x i32> [[T5]] ; %t0 = add <3 x i32> %nbits, @@ -138,7 +138,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]]) -; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]] +; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]] ; CHECK-NEXT: ret <3 x i32> [[T5]] ; %t0 = add <3 x i32> %nbits, @@ -166,7 +166,7 @@ ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T3]]) ; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T4]]) -; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[T3]], [[T4]] +; CHECK-NEXT: [[T5:%.*]] = shl <3 x i32> [[X]], [[T4]] ; CHECK-NEXT: ret <3 x i32> [[T5]] ; %t0 = add <3 x i32> %nbits, @@ -198,7 +198,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl i32 [[T2]], [[T3]] +; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %x = call i32 @gen32() @@ -260,7 +260,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T3]]) ; CHECK-NEXT: call void @use32(i32 [[T4]]) ; CHECK-NEXT: call void @use32(i32 [[T5]]) -; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T4]], [[T5]] +; CHECK-NEXT: [[T6:%.*]] = shl i32 [[T1]], [[T5]] ; CHECK-NEXT: ret i32 [[T6]] ; %t0 = shl i32 1, %nbits0 @@ -291,7 +291,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl nuw i32 [[T2]], [[T3]] +; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %t0 = shl i32 1, %nbits @@ -316,7 +316,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl nsw i32 [[T2]], [[T3]] +; CHECK-NEXT:
[[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %t0 = shl i32 1, %nbits @@ -341,7 +341,7 @@ ; CHECK-NEXT: call void @use32(i32 [[T1]]) ; CHECK-NEXT: call void @use32(i32 [[T2]]) ; CHECK-NEXT: call void @use32(i32 [[T3]]) -; CHECK-NEXT: [[T4:%.*]] = shl nuw nsw i32 [[T2]], [[T3]] +; CHECK-NEXT: [[T4:%.*]] = shl i32 [[X]], [[T3]] ; CHECK-NEXT: ret i32 [[T4]] ; %t0 = shl i32 1, %nbits