Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -23,6 +23,12 @@
 #define DEBUG_TYPE "instcombine"
 
+static cl::opt<unsigned> MaxMaskedUsersFold(
+    "max-masked-users-fold",
+    cl::desc("Maximum number of users a masked value can have "
+             "so that it is allowed to check for redundant bit masking."),
+    cl::init(4), cl::Hidden);
+
 /// Similar to getICmpCode but for FCmpInst. This encodes a fcmp predicate into
 /// a four bit mask.
 static unsigned getFCmpCode(FCmpInst::Predicate CC) {
@@ -1409,6 +1415,8 @@
   if (Instruction *X = foldShuffledBinop(I))
     return X;
 
+  if (Instruction *Shift = foldRedundantShiftedMasks(&I))
+    return Shift;
   // See if we can simplify any instructions used by the instruction whose sole
   // purpose is to compute bits we don't care about.
   if (SimplifyDemandedInstructionBits(I))
@@ -2787,3 +2795,134 @@
   return Changed ? &I : nullptr;
 }
+
+// Fold expressions x1 and x2 alike:
+//   x1 = ( and x, 0x00FF )
+//   x2 = (( shl x, 8 ) and 0xFF00 )
+// into
+//   x2 = shl x1, 8   ; reuse the computation of x1
+//
+// Returns the replacement shift instruction, or nullptr if the pattern does
+// not apply.
+Instruction *InstCombiner::foldRedundantShiftedMasks(BinaryOperator *AND) {
+  // Fewer than 2 allowed users per masked value disables this fold.
+  if (MaxMaskedUsersFold < 2)
+    return nullptr;
+
+  // 1st: check our desired pattern / structure.
+  if (!AND || AND->getOpcode() != Instruction::And)
+    return nullptr;
+
+  Instruction *SHIFT = dyn_cast<Instruction>(AND->getOperand(0));
+  if (!SHIFT || (SHIFT->getNumOperands() != 2) || (!SHIFT->hasOneUse()))
+    return nullptr;
+
+  unsigned N0Opcode = SHIFT->getOpcode();
+  switch (N0Opcode) {
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+    break;
+  default:
+    return nullptr;
+  }
+
+  ConstantInt *ShiftAmount = dyn_cast<ConstantInt>(SHIFT->getOperand(1));
+  if (!ShiftAmount)
+    return nullptr;
+
+  const ConstantInt *Mask = dyn_cast<ConstantInt>(AND->getOperand(1));
+  if (!Mask)
+    return nullptr;
+
+  Value *MaskedValue = SHIFT->getOperand(0);
+  if (!MaskedValue || MaskedValue->hasNUsesOrMore(MaxMaskedUsersFold))
+    return nullptr;
+
+  for (User *OtherUserV : MaskedValue->users()) {
+    BinaryOperator *OtherUser = dyn_cast<BinaryOperator>(OtherUserV);
+    if (!OtherUser)
+      continue;
+
+    if (OtherUser == SHIFT || (OtherUser->getOpcode() != Instruction::And))
+      continue;
+
+    // If we are not in the same BB, don't do it for now.
+    if (OtherUser->getParent() != SHIFT->getParent())
+      continue;
+
+    ConstantInt *OtherMask = dyn_cast<ConstantInt>(OtherUser->getOperand(1));
+    if (!OtherMask)
+      continue;
+
+    const APInt &MaskValue = Mask->getValue();
+    const APInt &ShiftValue = ShiftAmount->getValue();
+    const APInt &OtherMaskValue = OtherMask->getValue();
+
+    KnownBits MaskedValueBits = computeKnownBits(MaskedValue, 0, nullptr);
+    if (OtherMaskValue.getBitWidth() != MaskedValueBits.Zero.getBitWidth()) {
+      LLVM_DEBUG(dbgs() << "Check why MaskedValueBits has different width:";
+                 MaskedValue->dump(); AND->dump(); MaskValue.dump();
+                 OtherUser->dump(););
+      return nullptr;
+    }
+    // Ignore mask bits that are already known to be zero in the masked value.
+    const APInt EffectiveOtherMask = OtherMaskValue & ~MaskedValueBits.Zero;
+
+    KnownBits ShiftedValueBits = computeKnownBits(SHIFT, 0, nullptr);
+    if (MaskValue.getBitWidth() != ShiftedValueBits.Zero.getBitWidth()) {
+      LLVM_DEBUG(dbgs() << "Check why ShiftedValueBits has different width:";
+                 MaskedValue->dump(); AND->dump(); MaskValue.dump();
+                 OtherUser->dump(););
+      return nullptr;
+    }
+    const APInt EffectiveMask = MaskValue & ~ShiftedValueBits.Zero;
+
+    // 2nd: check if the masks and shifted masks match.
+    bool CanReduce = false;
+    switch (N0Opcode) {
+    case Instruction::Shl:
+      // Shift by the shift amount (not by the mask) when comparing masks.
+      CanReduce = (EffectiveOtherMask.shl(ShiftValue) == EffectiveMask) ||
+                  (EffectiveMask.lshr(ShiftValue) == EffectiveOtherMask);
+      break;
+    case Instruction::AShr:
+      if (!MaskedValueBits.Zero.isSignBitSet()) {
+        CanReduce = (EffectiveOtherMask.ashr(ShiftValue) == EffectiveMask);
+        break;
+      } else // Same as LShr.
+        N0Opcode = Instruction::LShr;
+      LLVM_FALLTHROUGH;
+    case Instruction::LShr:
+      CanReduce = (EffectiveOtherMask.lshr(ShiftValue) == EffectiveMask) ||
+                  (EffectiveMask.shl(ShiftValue) == EffectiveOtherMask);
+      break;
+    default:
+      return nullptr;
+    }
+
+    if (!CanReduce)
+      continue;
+
+    LLVM_DEBUG(errs() << "\tValue being masked and shift-masked: ";
+               MaskedValue->dump();
+               errs() << "\t\tValue zero bits: 0x"
+                      << MaskedValueBits.Zero.toString(16, false)
+                      << "\n\n\t\tApplied mask: 0x"
+                      << OtherMaskValue.toString(16, false) << " : ";
+               OtherUser->dump();
+               errs() << "\t\tEffective mask: 0x"
+                      << EffectiveOtherMask.toString(16, false)
+                      << "\n\n\tShifted by: " << ShiftValue.getZExtValue()
+                      << " : ";
+               SHIFT->dump();
+               errs() << "\t\tAnd masked by: 0x"
+                      << MaskValue.toString(16, false) << " : ";
+               AND->dump();
+               errs() << "\t\tEffective mask to shifted value: 0x"
+                      << EffectiveMask.toString(16, false);
+               errs() << "\n\tCan just shift the masked value from ";
+               OtherUser->dump(););
+
+    // 3rd: if OtherUser (the new producer) runs after this SHIFT, then we
+    // must move it higher.
+    if (!DT.dominates(OtherUser, SHIFT))
+      OtherUser->moveBefore(SHIFT);
+
+    return BinaryOperator::Create(
+        static_cast<Instruction::BinaryOps>(N0Opcode), OtherUser, ShiftAmount);
+  }
+  return nullptr;
+}
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -368,6 +368,7 @@
   Instruction *visitICmpInst(ICmpInst &I);
   Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1,
                                    BinaryOperator &I);
+  Instruction *foldRedundantShiftedMasks(BinaryOperator *AND);
   Instruction *commonCastTransforms(CastInst &CI);
   Instruction *commonPointerCastTransforms(CastInst &CI);
   Instruction *visitTrunc(TruncInst &CI);
Index: test/Transforms/InstCombine/D48278.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/D48278.ll
@@ -0,0 +1,96 @@
+; RUN: opt -S -O3 %s -o - | FileCheck %s
+
+; https://reviews.llvm.org/D48278
+; Fold redundant masking operations of shifted value
+; In a case where
+;   x1 = a & 0xFF
+;   x2 = a << 8 & 0xFF00
+; we can see x2 as:
+;   x2 = a << 8 & 0xFF << 8
+; that can be translated to
+;   x2 = (a & 0xFF) << 8
+; that is
+;   x2 = x1 << 8
+
+define i32 @shl(i16 %a) {
+entry:
+  %0 = sext i16 %a to i32
+  %1 = shl i32 %0, 8
+  %2 = and i32 %0, 172
+  %3 = and i32 %1, 44032
+  %4 = or i32 %2, %3
+  ret i32 %4
+}
+
+define i32 @lshr(i16 %a) {
+entry:
+  %0 = sext i16 %a to i32
+  %1 = lshr i32 %0, 8
+  %2 = and i32 %1, 172
+  %3 = and i32 %0, 44032
+  %4 = or i32 %2, %3
+  ret i32 %4
+}
+
+define i32 @ashr(i16 %a) {
+entry:
+  %0 = sext i16 %a to i32
+  %1 = and i32 %0, 44032
+  %2 = ashr i32 %0, 8
+  %3 = and i32 %2, 172
+  %4 = or i32 %1, %3
+  ret i32 %4
+}
+
+define i32 @shl_nogood(i16 %a) {
+entry:
+  %0 = sext i16 %a to i32
+  %1 = and i32 %0, 172
+  %2 = shl i32 %0, %1
+  %3 = and i32 %2, 44032
+  %4 = or i32 %1, %3
+  ret i32 %4
+}
+
+define i32 @shl_nogood2(i16 %a) {
+entry:
+  %0 = sext i16 %a to i32
+  %1 = and i32 %0, 172
+  %2 = shl i32 %0, 8
+  %3 = and i32 %2, %0
+  %4 = or i32 %1, %3
+  ret i32 %4
+}
+
+; CHECK-LABEL: @shl
+; CHECK: %0 = and i16 %a, 172
+; CHECK-NEXT: %1 = shl nuw i16 %0, 8
+; CHECK-NEXT: %2 = or i16 %1, %0
+; CHECK-NEXT: %3 = zext i16 %2 to i32
+
+; CHECK-LABEL: @lshr
+; CHECK: %0 = and i16 %a, -21504
+; CHECK-NEXT: %1 = zext i16 %0 to i32
+; CHECK-NEXT: %2 = lshr exact i32 %1, 8
+; CHECK-NEXT: %3 = or i32 %2, %1
+
+; CHECK-LABEL: @ashr
+; CHECK: %0 = and i16 %a, -21504
+; CHECK-NEXT: %1 = zext i16 %0 to i32
+; CHECK-NEXT: %2 = lshr exact i32 %1, 8
+; CHECK-NEXT: %3 = or i32 %2, %1
+; CHECK-NEXT: ret i32 %3
+
+; CHECK-LABEL: @shl_nogood
+; CHECK: %0 = sext i16 %a to i32
+; CHECK-NEXT: %1 = and i32 %0, 172
+; CHECK-NEXT: %2 = shl i32 %0, %1
+; CHECK-NEXT: %3 = and i32 %2, 44032
+; CHECK-NEXT: %4 = or i32 %3, %1
+
+; CHECK-LABEL: @shl_nogood2
+; CHECK: %0 = sext i16 %a to i32
+; CHECK-NEXT: %1 = shl nsw i32 %0, 8
+; CHECK-NEXT: %2 = or i32 %1, 172
+; CHECK-NEXT: %3 = and i32 %2, %0