[InstCombine] Allow constant-vector shift amounts in the trunc (lshr (sext A), C) fold

Adds ConstantExpr::getUMin(C1, C2). It is declared in Constants.h next to
getAnd/getOr/getXor. It is defined in Constants.cpp as
select(icmp ult C1, C2), C1, C2, built from ConstantExpr::getICmp and
getSelect.

In InstCombineCasts.cpp, the trunc (lshr (sext A), C) fold now binds the
shift amount with m_Constant instead of m_APInt:
- The "shift is small enough" test now uses
  m_SpecificInt_ICMP(ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)) on C.
- The ashr amount is computed as umin(C, MaxAmt) and truncated to A's type.
- MaxAmt is DestWidth - 1 when A already has the destination type.
- MaxAmt is AWidth - 1 on the hasOneUse() path, which then emits a
  sign-extending integer cast to DestTy.

In cast.ll, the CHECK lines for three tests change from the unfolded
sext/lshr/trunc sequence to a single ashr on the narrow vector type. The
tests are trunc_lshr_sext_uniform_undef, trunc_lshr_sext_nonuniform and
trunc_lshr_sext_nonuniform_undef.

Per the TODO added by this patch, undef shift-amount elements are currently
clamped to MaxAmt rather than being passed through as undef.

NOTE(review): in this copy, the vector constant operands of several CHECK
lines appear to be missing. For example, "lshr <2 x i32> [[B]]," and
"ashr <2 x i8> [[A:%.*]]," end at the comma. The patch's line breaks also
appear to be collapsed. Restore both from the original patch before running
git apply or FileCheck.

diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -959,6 +959,7 @@ static Constant *getAnd(Constant *C1, Constant *C2); static Constant *getOr(Constant *C1, Constant *C2); static Constant *getXor(Constant *C1, Constant *C2); + static Constant *getUMin(Constant *C1, Constant *C2); static Constant *getShl(Constant *C1, Constant *C2, bool HasNUW = false, bool HasNSW = false); static Constant *getLShr(Constant *C1, Constant *C2, bool isExact = false); diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -2560,6 +2560,11 @@ return get(Instruction::Xor, C1, C2); } +Constant *ConstantExpr::getUMin(Constant *C1, Constant *C2) { + Constant *Cmp = ConstantExpr::getICmp(CmpInst::ICMP_ULT, C1, C2); + return getSelect(Cmp, C1, C2); +} + Constant *ConstantExpr::getShl(Constant *C1, Constant *C2, bool HasNUW, bool HasNSW) { unsigned Flags = (HasNUW ? OverflowingBinaryOperator::NoUnsignedWrap : 0) | diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -827,23 +827,30 @@ return CastInst::CreateIntegerCast(Shift, DestTy, false); } - const APInt *C; - if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) { + Constant *C; + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) { unsigned AWidth = A->getType()->getScalarSizeInBits(); unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); // If the shift is small enough, all zero bits created by the shift are // removed by the trunc. - if (C->getZExtValue() <= MaxShiftAmt) { + // TODO: Support passing through undef shift amounts - these currently get + // clamped to MaxAmt. 
+ if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, + APInt(SrcWidth, MaxShiftAmt)))) { // trunc (lshr (sext A), C) --> ashr A, C if (A->getType() == DestTy) { - unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); - return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); + Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + return BinaryOperator::CreateAShr(A, ShAmt); } // The types are mismatched, so create a cast after shifting: // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) if (Src->hasOneUse()) { - unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1); + Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false); + Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); + ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); Value *Shift = Builder.CreateAShr(A, ShAmt); return CastInst::CreateIntegerCast(Shift, DestTy, true); } diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll --- a/llvm/test/Transforms/InstCombine/cast.ll +++ b/llvm/test/Transforms/InstCombine/cast.ll @@ -1559,9 +1559,7 @@ define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_uniform_undef( -; ALL-NEXT: [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32> -; ALL-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], ; ALL-NEXT: ret <2 x i8> [[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1572,9 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_nonuniform(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_nonuniform( -; ALL-NEXT: [[B:%.*]] = sext <2 x i8> [[A:%.*]] to <2 x i32> -; ALL-NEXT: [[C:%.*]] = lshr <2 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <2 x i32> [[C]] to <2 x i8> +; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], ; ALL-NEXT: ret <2 x i8> 
[[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1585,9 +1581,7 @@ define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef( -; ALL-NEXT: [[B:%.*]] = sext <3 x i8> [[A:%.*]] to <3 x i32> -; ALL-NEXT: [[C:%.*]] = lshr <3 x i32> [[B]], -; ALL-NEXT: [[D:%.*]] = trunc <3 x i32> [[C]] to <3 x i8> +; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], ; ALL-NEXT: ret <3 x i8> [[D]] ; %B = sext <3 x i8> %A to <3 x i32>