diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -204,6 +204,12 @@
   /// Try to replace undefined constant C or undefined elements in C with
   /// Replacement. If no changes are made, the constant C is returned.
   static Constant *replaceUndefsWith(Constant *C, Constant *Replacement);
+
+  /// Merges undefs of a Constant with another Constant, along with the
+  /// undefs already present. Other doesn't have to be the same type as C, but
+  /// both must either be scalars or vectors with the same element count. If no
+  /// changes are made, the constant C is returned.
+  static Constant *mergeUndefsWith(Constant *C, Constant *Other);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -738,5 +738,45 @@
   return ConstantVector::get(NewC);
 }
+Constant *Constant::mergeUndefsWith(Constant *C, Constant *Other) {
+  assert(C && Other && "Expected non-nullptr constant arguments");
+  // A wholly-undef C already has every element undef; nothing to merge.
+  if (match(C, m_Undef()))
+    return C;
+
+  Type *Ty = C->getType();
+  // A wholly-undef Other makes every element of the result undef.
+  if (match(Other, m_Undef()))
+    return UndefValue::get(Ty);
+
+  // Only fixed-width vectors can be walked element by element; scalars and
+  // scalable vectors are returned unchanged.
+  auto *VTy = dyn_cast<FixedVectorType>(Ty);
+  if (!VTy)
+    return C;
+
+  Type *EltTy = VTy->getElementType();
+  unsigned NumElts = VTy->getNumElements();
+  assert(isa<FixedVectorType>(Other->getType()) &&
+         cast<FixedVectorType>(Other->getType())->getNumElements() == NumElts &&
+         "Type mismatch");
+
+  // Copy C element by element, replacing an element with undef when the
+  // matching element of Other is undef. Only rebuild if something changed.
+  bool FoundExtraUndef = false;
+  SmallVector<Constant *, 32> NewC(NumElts);
+  for (unsigned I = 0; I != NumElts; ++I) {
+    NewC[I] = C->getAggregateElement(I);
+    Constant *OtherEltC = Other->getAggregateElement(I);
+    assert(NewC[I] && OtherEltC && "Unknown vector element");
+    if (!match(NewC[I], m_Undef()) && match(OtherEltC, m_Undef())) {
+      NewC[I] = UndefValue::get(EltTy);
+      FoundExtraUndef = true;
+    }
+  }
+  if (FoundExtraUndef)
+    return ConstantVector::get(NewC);
+  return C;
+}
//===----------------------------------------------------------------------===// // ConstantInt diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -810,8 +810,6 @@ // If the shift is small enough, all zero bits created by the shift are // removed by the trunc. - // TODO: Support passing through undef shift amounts - these currently get - // clamped to MaxAmt. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { // trunc (lshr (sext A), C) --> ashr A, C @@ -819,6 +817,7 @@ Constant *MaxAmt = ConstantInt::get(SrcTy, DestWidth - 1, false); Constant *ShAmt = ConstantExpr::getUMin(C, MaxAmt); ShAmt = ConstantExpr::getTrunc(ShAmt, A->getType()); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt) : BinaryOperator::CreateAShr(A, ShAmt); } @@ -841,13 +840,12 @@ // If the shift is small enough, all zero/sign bits created by the shift are // removed by the trunc. - // TODO: Support passing through undef shift amounts - these currently get - // zero'd by getIntegerCast. if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE, APInt(SrcWidth, MaxShiftAmt)))) { auto *OldShift = cast(Src); - auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); bool IsExact = OldShift->isExact(); + auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); + ShAmt = Constant::mergeUndefsWith(ShAmt, C); Value *Shift = OldShift->getOpcode() == Instruction::AShr ? 
Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) diff --git a/llvm/test/Transforms/InstCombine/cast.ll b/llvm/test/Transforms/InstCombine/cast.ll --- a/llvm/test/Transforms/InstCombine/cast.ll +++ b/llvm/test/Transforms/InstCombine/cast.ll @@ -1570,7 +1570,7 @@ define <2 x i8> @trunc_lshr_sext_uniform_undef(<2 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_uniform_undef( -; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], +; ALL-NEXT: [[D:%.*]] = ashr <2 x i8> [[A:%.*]], ; ALL-NEXT: ret <2 x i8> [[D]] ; %B = sext <2 x i8> %A to <2 x i32> @@ -1592,7 +1592,7 @@ define <3 x i8> @trunc_lshr_sext_nonuniform_undef(<3 x i8> %A) { ; ALL-LABEL: @trunc_lshr_sext_nonuniform_undef( -; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], +; ALL-NEXT: [[D:%.*]] = ashr <3 x i8> [[A:%.*]], ; ALL-NEXT: ret <3 x i8> [[D]] ; %B = sext <3 x i8> %A to <3 x i32> diff --git a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll --- a/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll +++ b/llvm/test/Transforms/InstCombine/trunc-shift-trunc.ll @@ -45,7 +45,7 @@ define <2 x i8> @trunc_lshr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_lshr_trunc_uniform_undef( -; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[C1:%.*]] = lshr <2 x i64> [[A:%.*]], ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ; @@ -131,7 +131,7 @@ define <2 x i8> @trunc_ashr_trunc_uniform_undef(<2 x i64> %a) { ; CHECK-LABEL: @trunc_ashr_trunc_uniform_undef( -; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], +; CHECK-NEXT: [[C1:%.*]] = ashr <2 x i64> [[A:%.*]], ; CHECK-NEXT: [[D:%.*]] = trunc <2 x i64> [[C1]] to <2 x i8> ; CHECK-NEXT: ret <2 x i8> [[D]] ;