diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -959,7 +959,6 @@
   static Constant *getAnd(Constant *C1, Constant *C2);
   static Constant *getOr(Constant *C1, Constant *C2);
   static Constant *getXor(Constant *C1, Constant *C2);
-  static Constant *getUMin(Constant *C1, Constant *C2);
   static Constant *getShl(Constant *C1, Constant *C2,
                           bool HasNUW = false, bool HasNSW = false);
   static Constant *getLShr(Constant *C1, Constant *C2, bool isExact = false);
@@ -1035,6 +1034,15 @@
     return getLShr(C1, C2, true);
   }
 
+  /// Return a constant with each element of C unsigned-clamped to the
+  /// inclusive range [MinValue, MaxValue]: MinValue <= C <= MaxValue.
+  /// If PassThruUndefs is true then undef elements of C are returned as undef,
+  /// otherwise they are set to MinValue. Undef elements that cannot be
+  /// identified (see getUndefMask) are always set to MinValue.
+  static Constant *getUClamp(Constant *C, const APInt &MinValue,
+                             const APInt &MaxValue,
+                             bool PassThruUndefs = false);
+
   /// Return the identity constant for a binary opcode.
   /// The identity constant C is defined as X op C = X and C op X = X for every
   /// X when the binary operation is commutative. If the binop is not
@@ -1108,6 +1116,14 @@
     Type *Ty ///< The integer type to cast to
   );
 
+  /// Return an i1 constant (an i1 vector with the same element count if C is
+  /// a vector) that is true where the element of C is undef and false
+  /// otherwise. C must be an integer or floating-point scalar or vector.
+  /// Scalable vectors can only be resolved if C is entirely undef or a splat
+  /// without undef elements.
+  /// Returns nullptr if the undef elements of C cannot be determined.
+  static Constant *getUndefMask(Constant *C);
+
   /// Return true if this is a convert constant expression
   bool isCast() const;
 
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -713,6 +713,50 @@
   }
 }
 
+Constant *ConstantExpr::getUndefMask(Constant *C) {
+  assert(C && "Expected non-nullptr constant argument");
+  Type *Ty = C->getType();
+  assert((Ty->isIntOrIntVectorTy() || Ty->isFPOrFPVectorTy()) &&
+         "Only int/fp scalar/vectors supported");
+
+  // The mask is i1 for a scalar C, or an i1 vector with C's element count.
+  Type *BoolSclTy = IntegerType::getInt1Ty(C->getContext());
+  Type *BoolTy = BoolSclTy;
+  if (Ty->isVectorTy())
+    BoolTy = VectorType::get(BoolSclTy, cast<VectorType>(Ty));
+
+  // Entire constant is undef - every element is undef.
+  if (match(C, m_Undef()))
+    return ConstantInt::getTrue(BoolTy);
+
+  // Scalar isn't undef - return false.
+  if (!Ty->isVectorTy())
+    return ConstantInt::getFalse(BoolTy);
+
+  // A splat without undef elements contains no undefs - return false.
+  if (C->getSplatValue(/*AllowUndefs*/ false))
+    return ConstantInt::getFalse(BoolTy);
+
+  // Anything else needs per-element inspection, which requires a fixed
+  // element count; give up on scalable vectors.
+  auto *VTy = dyn_cast<FixedVectorType>(Ty);
+  if (!VTy)
+    return nullptr;
+
+  unsigned NumElts = VTy->getNumElements();
+  Constant *TrueElt = ConstantInt::getTrue(BoolSclTy);
+  Constant *FalseElt = ConstantInt::getFalse(BoolSclTy);
+  SmallVector<Constant *, 16> NewC(NumElts);
+  for (unsigned i = 0; i != NumElts; ++i) {
+    Constant *EltC = C->getAggregateElement(i);
+    // Give up if an element can't be extracted.
+    if (!EltC)
+      return nullptr;
+    NewC[i] = match(EltC, m_Undef()) ? TrueElt : FalseElt;
+  }
+  return ConstantVector::get(NewC);
+}
+
 Constant *Constant::replaceUndefsWith(Constant *C, Constant *Replacement) {
   assert(C && Replacement && "Expected non-nullptr constant arguments");
   Type *Ty = C->getType();
@@ -2560,11 +2604,6 @@
   return get(Instruction::Xor, C1, C2);
 }
 
-Constant *ConstantExpr::getUMin(Constant *C1, Constant *C2) {
-  Constant *Cmp = ConstantExpr::getICmp(CmpInst::ICMP_ULT, C1, C2);
-  return getSelect(Cmp, C1, C2);
-}
-
 Constant *ConstantExpr::getShl(Constant *C1, Constant *C2,
                                bool HasNUW, bool HasNSW) {
   unsigned Flags = (HasNUW ? OverflowingBinaryOperator::NoUnsignedWrap : 0) |
@@ -2582,6 +2621,32 @@
              isExact ? PossiblyExactOperator::IsExact : 0);
 }
 
+Constant *ConstantExpr::getUClamp(Constant *C, const APInt &MinValue,
+                                  const APInt &MaxValue, bool PassThruUndefs) {
+  assert(MinValue.ule(MaxValue) && "Illegal clamp range");
+  Type *Ty = C->getType();
+  Constant *MinC = ConstantInt::get(Ty, MinValue);
+  Constant *MaxC = ConstantInt::get(Ty, MaxValue);
+
+  // Clamp from below. By using UGE, UndefValue elements of C compare as
+  // true, so the select also replaces them with MinC.
+  Constant *Result = C;
+  Constant *MinCmp = ConstantExpr::getICmp(CmpInst::ICMP_UGE, MinC, Result);
+  Result = getSelect(MinCmp, MinC, Result);
+  // Clamp from above.
+  Constant *MaxCmp = ConstantExpr::getICmp(CmpInst::ICMP_ULE, MaxC, Result);
+  Result = getSelect(MaxCmp, MaxC, Result);
+
+  if (PassThruUndefs) {
+    // Restore the undef elements of C. If they can't be identified, leave
+    // them clamped to MinValue, which is always a valid value for an undef.
+    if (Constant *IsUndef = ConstantExpr::getUndefMask(C))
+      Result = getSelect(IsUndef, UndefValue::get(Ty), Result);
+  }
+
+  return Result;
+}
+
 Constant *ConstantExpr::getBinOpIdentity(unsigned Opcode, Type *Ty,
                                          bool AllowRHSConstant) {
   assert(Instruction::isBinaryOp(Opcode) && "Only binops allowed");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -811,14 +811,13 @@
 
     // If the shift is small enough, all zero bits created by the shift are
     // removed by the trunc.
-    // TODO: Support passing through undef shift amounts - these currently get
-    // clamped to MaxAmt.
     if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                     APInt(SrcWidth, MaxShiftAmt)))) {
       // trunc (lshr (sext A), C) --> ashr A, C
       if (A->getType() == DestTy) {
-        Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false);
-        Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+        Constant *ShAmt =
+            ConstantExpr::getUClamp(C, APInt::getNullValue(SrcWidth),
+                                    APInt(SrcWidth, DestWidth - 1), true);
         ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
         return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
                        : BinaryOperator::CreateAShr(A, ShAmt);
@@ -826,8 +825,9 @@
       // The types are mismatched, so create a cast after shifting:
       // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
       if (Src->hasOneUse()) {
-        Constant *MaxAmt = ConstantInt::get(SrcTy, AWidth - 1, false);
-        Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt);
+        Constant *ShAmt =
+            ConstantExpr::getUClamp(C, APInt::getNullValue(SrcWidth),
+                                    APInt(SrcWidth, AWidth - 1), true);
         ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType());
         Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
         return CastInst::CreateIntegerCast(Shift, DestTy, true);
@@ -842,13 +842,14 @@
 
     // If the shift is small enough, all zero/sign bits created by the shift are
     // removed by the trunc.
-    // TODO: Support passing through undef shift amounts - these currently get
-    // zero'd by getIntegerCast.
     if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                     APInt(SrcWidth, MaxShiftAmt)))) {
       auto *OldShift = cast<Instruction>(Src);
-      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
       bool IsExact = OldShift->isExact();
+      auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true);
+      if (auto *IsUndef = ConstantExpr::getUndefMask(C))
+        ShAmt = ConstantExpr::getSelect(IsUndef, UndefValue::get(A->getType()),
+                                        ShAmt);
       Value *Shift =
           OldShift->getOpcode() == Instruction::AShr
               ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll
--- a/llvm/test/Transforms/InstCombine/cast.ll
+++ b/llvm/test/Transforms/InstCombine/cast.ll
@@ -1570,7 +1570,7 @@
 
 define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_uniform_undef(
-; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]],
+; ALL-NEXT:    [[D:%.*]] = ashr <2 x i8> [[A:%.*]],
 ; ALL-NEXT:    ret <2 x i8> [[D]]
 ;
   %B = sext <2 x i8> %A to <2 x i32>
@@ -1592,7 +1592,7 @@
 
 define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) {
 ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef(
-; ALL-NEXT:    [[D:%.*]] = ashr <3 x i8> [[A:%.*]],
+; ALL-NEXT:    [[D:%.*]] = ashr <3 x i8> [[A:%.*]],
 ; ALL-NEXT:    ret <3 x i8> [[D]]
 ;
   %B = sext <3 x i8> %A to <3 x i32>
diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
--- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll
@@ -45,7 +45,7 @@
 
 define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef(
-; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]],
+; CHECK-NEXT:    [[C1:%.*]] = lshr <2 x i64> [[A:%.*]],
 ; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;
@@ -131,7 +131,7 @@
 
 define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) {
 ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef(
-; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]],
+; CHECK-NEXT:    [[C1:%.*]] = ashr <2 x i64> [[A:%.*]],
 ; CHECK-NEXT:    [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8>
 ; CHECK-NEXT:    ret <2 x i8> [[D]]
 ;