[InstCombine] Fold bitcast (extractelement (bitcast X), Index) -> X

When the bitcast destination type is not a valid vector element type, the
bitcast-of-extractelement canonicalization now returns X directly if the
extracted-from vector is a one-use bitcast of a value X that already has the
destination type. The bitcast_extelt3 tests are updated to expect the fold.

NOTE(review): line breaks and leading indentation in this patch were rebuilt
from a whitespace-collapsed copy; confirm against the tree before applying.

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -2358,13 +2358,22 @@
   // The bitcast must be to a vectorizable type, otherwise we can't make a new
   // type to extract from.
   Type *DestType = BitCast.getType();
-  if (!VectorType::isValidElementType(DestType))
-    return nullptr;
+  if (VectorType::isValidElementType(DestType)) {
+    auto *NewVecType =
+        VectorType::get(DestType, cast<VectorType>(VecOp->getType()));
+    auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc");
+    return ExtractElementInst::Create(NewBC, Index);
+  }
 
-  auto *NewVecType =
-      VectorType::get(DestType, cast<VectorType>(VecOp->getType()));
-  auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc");
-  return ExtractElementInst::Create(NewBC, Index);
+  // Otherwise DestType cannot be a vector element type (e.g. it is itself a
+  // vector). If the source vector is a one-use bitcast of a value X that
+  // already has DestType, the whole cast chain is a no-op:
+  //   bitcast (extractelement (bitcast X), Index) -> X
+  Value *X;
+  if (match(VecOp, m_OneUse(m_BitCast(m_Value(X)))) && X->getType() == DestType)
+    return IC.replaceInstUsesWith(BitCast, X);
+
+  return nullptr;
 }
 
 /// Change the type of a bitwise logic operation if we can eliminate a bitcast.
diff --git a/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll b/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
--- a/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
@@ -358,14 +358,9 @@
   ret i64 %bc2
 }
 
-; TODO: This should return %A.
-
 define <2 x i32> @bitcast_extelt3(<2 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt3(
-; CHECK-NEXT: [[BC1:%.*]] = bitcast <2 x i32> [[A:%.*]] to <1 x i64>
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <1 x i64> [[BC1]], i64 0
-; CHECK-NEXT: [[BC2:%.*]] = bitcast i64 [[EXT]] to <2 x i32>
-; CHECK-NEXT: ret <2 x i32> [[BC2]]
+; CHECK-NEXT: ret <2 x i32> [[A:%.*]]
 ;
   %bc1 = bitcast <2 x i32> %A to <1 x i64>
   %ext = extractelement <1 x i64> %bc1, i32 0
diff --git a/llvm/test/Transforms/InstCombine/bitcast.ll b/llvm/test/Transforms/InstCombine/bitcast.ll
--- a/llvm/test/Transforms/InstCombine/bitcast.ll
+++ b/llvm/test/Transforms/InstCombine/bitcast.ll
@@ -358,14 +358,9 @@
   ret i64 %bc2
 }
 
-; TODO: This should return %A.
-
 define <2 x i32> @bitcast_extelt3(<2 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt3(
-; CHECK-NEXT: [[BC1:%.*]] = bitcast <2 x i32> [[A:%.*]] to <1 x i64>
-; CHECK-NEXT: [[EXT:%.*]] = extractelement <1 x i64> [[BC1]], i64 0
-; CHECK-NEXT: [[BC2:%.*]] = bitcast i64 [[EXT]] to <2 x i32>
-; CHECK-NEXT: ret <2 x i32> [[BC2]]
+; CHECK-NEXT: ret <2 x i32> [[A:%.*]]
 ;
   %bc1 = bitcast <2 x i32> %A to <1 x i64>
   %ext = extractelement <1 x i64> %bc1, i32 0