diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -636,10 +636,12 @@
 
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
 /// truncate ahead of binary operators.
-/// TODO: Transforms for truncated shifts should be moved into here.
 Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
   Type *SrcTy = Trunc.getSrcTy();
   Type *DestTy = Trunc.getType();
+  unsigned SrcWidth = SrcTy->getScalarSizeInBits();
+  unsigned DestWidth = DestTy->getScalarSizeInBits();
+
   if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
     return nullptr;
 
@@ -682,7 +684,32 @@
     }
     break;
   }
-
+  case Instruction::LShr:
+  case Instruction::AShr: {
+    // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
+    // Note: this fold sits behind the shouldChangeType() early exit at the
+    // top of narrowBinOp, so for scalar types it only fires when that passes.
+    Value *A;
+    Constant *C;
+    if (match(BinOp0, m_Trunc(m_Value(A))) && match(BinOp1, m_Constant(C))) {
+      unsigned MaxShiftAmt = SrcWidth - DestWidth;
+      // If the shift is small enough, all zero/sign bits created by the shift
+      // are removed by the trunc.
+      if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
+                                      APInt(SrcWidth, MaxShiftAmt)))) {
+        auto *OldShift = cast<BinaryOperator>(Trunc.getOperand(0));
+        bool IsExact = OldShift->isExact();
+        auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+        ShAmt = Constant::mergeUndefsWith(ShAmt, C);
+        Value *Shift =
+            OldShift->getOpcode() == Instruction::AShr
+                ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
+                : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
+        return CastInst::CreateTruncOrBitCast(Shift, DestTy);
+      }
+    }
+    break;
+  }
   default:
     break;
   }
@@ -870,26 +897,6 @@
     // TODO: Mask high bits with 'and'.
   }
 
-  // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
-  if (match(Src, m_OneUse(m_Shr(m_Trunc(m_Value(A)), m_Constant(C))))) {
-    unsigned MaxShiftAmt = SrcWidth - DestWidth;
-
-    // If the shift is small enough, all zero/sign bits created by the shift are
-    // removed by the trunc.
-    if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
-                                    APInt(SrcWidth, MaxShiftAmt)))) {
-      auto *OldShift = cast<BinaryOperator>(Src);
-      bool IsExact = OldShift->isExact();
-      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
-      ShAmt = Constant::mergeUndefsWith(ShAmt, C);
-      Value *Shift =
-          OldShift->getOpcode() == Instruction::AShr
-              ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
-              : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
-      return CastInst::CreateTruncOrBitCast(Shift, DestTy);
-    }
-  }
-
   if (Instruction *I = narrowBinOp(Trunc))
     return I;
 