diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -1815,20 +1815,16 @@ // Round NTZ down to the next byte. If we have 11 trailing zeros, then // we need all the bits down to bit 8. Likewise, round NLZ. If we // have 14 leading zeros, round to 8. - NLZ &= ~7; - NTZ &= ~7; + NLZ = alignDown(NLZ, 8); + NTZ = alignDown(NTZ, 8); // If we need exactly one byte, we can do this transformation. if (BitWidth - NLZ - NTZ == 8) { - unsigned ResultBit = NTZ; - unsigned InputBit = BitWidth - NTZ - 8; - // Replace this with either a left or right shift to get the byte into // the right place. - unsigned ShiftOpcode = InputBit > ResultBit ? ISD::SRL : ISD::SHL; + unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL; if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) { EVT ShiftAmtTy = getShiftAmountTy(VT, DL); - unsigned ShiftAmount = - InputBit > ResultBit ? InputBit - ResultBit : ResultBit - InputBit; + unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ; SDValue ShAmt = TLO.DAG.getConstant(ShiftAmount, dl, ShiftAmtTy); SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt); return TLO.CombineTo(Op, NewOp); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -800,22 +800,21 @@ // Round NTZ down to the next byte. If we have 11 trailing zeros, then // we need all the bits down to bit 8. Likewise, round NLZ. If we // have 14 leading zeros, round to 8. - NLZ &= ~7; - NTZ &= ~7; + NLZ = alignDown(NLZ, 8); + NTZ = alignDown(NTZ, 8); // If we need exactly one byte, we can do this transformation. - if (BitWidth-NLZ-NTZ == 8) { - unsigned ResultBit = NTZ; - unsigned InputBit = BitWidth-NTZ-8; - + if (BitWidth - NLZ - NTZ == 8) { // Replace this with either a left or right shift to get the byte into // the right place. Instruction *NewVal; - if (InputBit > ResultBit) - NewVal = BinaryOperator::CreateLShr(II->getArgOperand(0), - ConstantInt::get(I->getType(), InputBit-ResultBit)); + if (NLZ > NTZ) + NewVal = BinaryOperator::CreateLShr( + II->getArgOperand(0), + ConstantInt::get(I->getType(), NLZ - NTZ)); else - NewVal = BinaryOperator::CreateShl(II->getArgOperand(0), - ConstantInt::get(I->getType(), ResultBit-InputBit)); + NewVal = BinaryOperator::CreateShl( + II->getArgOperand(0), + ConstantInt::get(I->getType(), NTZ - NLZ)); NewVal->takeName(I); return InsertNewInstWith(NewVal, *I); }