diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -228,6 +228,11 @@
     return getTypeID() == ScalableVectorTyID || getTypeID() == FixedVectorTyID;
   }
 
+  /// True if this is an instance of ScalableVectorType.
+  inline bool isScalableVectorTy() const {
+    return getTypeID() == ScalableVectorTyID;
+  }
+
   /// Return true if this type could be converted with a lossless BitCast to
   /// type 'Ty'. For example, i8* to i32*. BitCasts are valid for types of the
   /// same size only where no re-interpretation of the bits is done.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -2363,13 +2363,21 @@
   // The bitcast must be to a vectorizable type, otherwise we can't make a new
   // type to extract from.
   Type *DestType = BitCast.getType();
-  if (!VectorType::isValidElementType(DestType))
-    return nullptr;
+  VectorType *VecType = cast<VectorType>(VecOp->getType());
+  if (VectorType::isValidElementType(DestType)) {
+    auto *NewVecType = VectorType::get(DestType, VecType);
+    auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc");
+    return ExtractElementInst::Create(NewBC, Index);
+  }
 
-  auto *NewVecType =
-      VectorType::get(DestType, cast<VectorType>(VecOp->getType()));
-  auto *NewBC = IC.Builder.CreateBitCast(VecOp, NewVecType, "bc");
-  return ExtractElementInst::Create(NewBC, Index);
+  // Only fold when DestType is a vector: for a scalar DestType this would undo
+  // the inverse transform done in visitBitCast.
+  // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest)
+  if (DestType->isVectorTy() && !VecType->isScalableVectorTy() &&
+      cast<FixedVectorType>(VecType)->getNumElements() == 1)
+    return CastInst::Create(Instruction::BitCast, VecOp, DestType);
+
+  return nullptr;
 }
 
 /// Change the type of a bitwise logic operation if we can eliminate a bitcast.
diff --git a/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll b/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
--- a/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/bitcast-inseltpoison.ll
@@ -353,14 +353,11 @@
   ret i64 %bc2
 }
 
-; TODO: This should return %A.
+; A bitcast round-trip through <1 x i64> (via extractelement 0) folds to %A.
 
 define <2 x i32> @bitcast_extelt3(<2 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt3(
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <2 x i32> [[A:%.*]] to <1 x i64>
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <1 x i64> [[BC1]], i64 0
-; CHECK-NEXT:    [[BC2:%.*]] = bitcast i64 [[EXT]] to <2 x i32>
-; CHECK-NEXT:    ret <2 x i32> [[BC2]]
+; CHECK-NEXT:    ret <2 x i32> [[A:%.*]]
 ;
   %bc1 = bitcast <2 x i32> %A to <1 x i64>
   %ext = extractelement <1 x i64> %bc1, i32 0
diff --git a/llvm/test/Transforms/InstCombine/bitcast.ll b/llvm/test/Transforms/InstCombine/bitcast.ll
--- a/llvm/test/Transforms/InstCombine/bitcast.ll
+++ b/llvm/test/Transforms/InstCombine/bitcast.ll
@@ -402,14 +402,11 @@
   ret i64 %bc2
 }
 
-; TODO: This should return %A.
+; A bitcast round-trip through <1 x i64> (via extractelement 0) folds to %A.
 
 define <2 x i32> @bitcast_extelt3(<2 x i32> %A) {
 ; CHECK-LABEL: @bitcast_extelt3(
-; CHECK-NEXT:    [[BC1:%.*]] = bitcast <2 x i32> [[A:%.*]] to <1 x i64>
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <1 x i64> [[BC1]], i64 0
-; CHECK-NEXT:    [[BC2:%.*]] = bitcast i64 [[EXT]] to <2 x i32>
-; CHECK-NEXT:    ret <2 x i32> [[BC2]]
+; CHECK-NEXT:    ret <2 x i32> [[A:%.*]]
 ;
   %bc1 = bitcast <2 x i32> %A to <1 x i64>
   %ext = extractelement <1 x i64> %bc1, i32 0
@@ -433,8 +430,7 @@
 
 define <2 x i32> @bitcast_extelt5(<1 x i64> %A) {
 ; CHECK-LABEL: @bitcast_extelt5(
-; CHECK-NEXT:    [[EXT:%.*]] = extractelement <1 x i64> [[A:%.*]], i64 0
-; CHECK-NEXT:    [[BC:%.*]] = bitcast i64 [[EXT]] to <2 x i32>
+; CHECK-NEXT:    [[BC:%.*]] = bitcast <1 x i64> [[A:%.*]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[BC]]
 ;
   %ext = extractelement <1 x i64> %A, i32 0