diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -690,6 +690,7 @@
 
   Value *Src = CI.getOperand(0);
   Type *DestTy = CI.getType(), *SrcTy = Src->getType();
+  ConstantInt *Cst = nullptr;
 
   // Attempt to truncate the entire input expression tree to the destination
   // type. Only do this if the dest type is a simple type, don't convert the
@@ -758,7 +759,7 @@
   // more efficiently. Support vector types. Cleanup code by using m_OneUse.
 
   // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion.
-  Value *A = nullptr; ConstantInt *Cst = nullptr;
+  Value *A = nullptr;
   if (Src->hasOneUse() &&
       match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) {
     // We have three types to worry about here, the type of A, the source of
@@ -843,6 +844,47 @@
   if (Instruction *I = foldVecTruncToExtElt(CI, *this))
     return I;
 
+  // Whenever an element is extracted from a vector, and then truncated,
+  // canonicalize by converting it to a bitcast followed by an
+  // extractelement.
+  //
+  // Example (little endian):
+  //   trunc (extractelement <4 x i64> %X, 0) to i32
+  //   --->
+  //   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+  //
+  // On big-endian targets the index of the last narrow element of the group
+  // is used instead, i.e. (Idx + 1) * Ratio - 1.
+  Value *VecOp = nullptr;
+  if (match(Src,
+            m_OneUse(m_ExtractElement(m_Value(VecOp), m_ConstantInt(Cst))))) {
+    Type *VecOpTy = VecOp->getType();
+    unsigned DestScalarSize = DestTy->getScalarSizeInBits();
+    unsigned VecOpScalarSize = VecOpTy->getScalarSizeInBits();
+    unsigned VecNumElts = VecOpTy->getVectorNumElements();
+
+    // A badly fit destination size would result in an invalid cast.
+    if (VecOpScalarSize % DestScalarSize == 0) {
+      // Do the index arithmetic in 64 bits so that neither a large constant
+      // index nor the widened element count can wrap in 32-bit arithmetic.
+      uint64_t TruncRatio = VecOpScalarSize / DestScalarSize;
+      uint64_t BitCastNumElts = VecNumElts * TruncRatio;
+      uint64_t VecOpIdx = Cst->getZExtValue();
+      uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
+                                         : VecOpIdx * TruncRatio;
+
+      // Skip the fold unless the widened element count fits in 32 bits and
+      // the new index lands inside the widened vector.
+      if ((BitCastNumElts >> 32) == 0 && NewIdx < BitCastNumElts) {
+        Type *BitCastTo =
+            VectorType::get(DestTy, static_cast<unsigned>(BitCastNumElts));
+        Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
+        return ExtractElementInst::Create(
+            BitCast, Builder.getInt32(static_cast<uint32_t>(NewIdx)));
+      }
+    }
+  }
+
   return nullptr;
 }
 