diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1407,6 +1407,41 @@
   return nullptr;
 }
 
+/// If the base vector and the inserted scalar are extended by the same opcode
+/// from the same type, insert in the narrow type and then extend the result.
+/// TODO: This can be extended to include other cast opcodes, but particularly
+///       if we create a wider insertelement, make sure codegen is not harmed.
+static Instruction *narrowInsElt(InsertElementInst &InsElt,
+                                 InstCombiner::BuilderTy &Builder) {
+  // We are creating a vector extend. If the original vector extend has another
+  // use, that would mean we end up with 2 vector extends, so avoid that.
+  // TODO: We could ease the use-clause to "if at least one op has one use"
+  //       (assuming that the source types match - see next TODO comment).
+  Value *Vec = InsElt.getOperand(0);
+  if (!Vec->hasOneUse())
+    return nullptr;
+
+  Value *Scalar = InsElt.getOperand(1);
+  Value *X, *Y;
+  CastInst::CastOps CastOpcode;
+  if (match(Vec, m_FPExt(m_Value(X))) && match(Scalar, m_FPExt(m_Value(Y))))
+    CastOpcode = Instruction::FPExt;
+  else if (match(Vec, m_SExt(m_Value(X))) && match(Scalar, m_SExt(m_Value(Y))))
+    CastOpcode = Instruction::SExt;
+  else if (match(Vec, m_ZExt(m_Value(X))) && match(Scalar, m_ZExt(m_Value(Y))))
+    CastOpcode = Instruction::ZExt;
+  else
+    return nullptr;
+
+  // TODO: We can allow mismatched types by creating an intermediate cast.
+  if (X->getType()->getScalarType() != Y->getType())
+    return nullptr;
+
+  // inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index)
+  Value *NewInsElt = Builder.CreateInsertElement(X, Y, InsElt.getOperand(2));
+  return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType());
+}
+
 Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
   Value *VecOp = IE.getOperand(0);
   Value *ScalarOp = IE.getOperand(1);
@@ -1526,6 +1561,9 @@
   if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE))
     return IdentityShuf;
 
+  if (Instruction *Ext = narrowInsElt(IE, Builder))
+    return Ext;
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/insert-ext.ll b/llvm/test/Transforms/InstCombine/insert-ext.ll
--- a/llvm/test/Transforms/InstCombine/insert-ext.ll
+++ b/llvm/test/Transforms/InstCombine/insert-ext.ll
@@ -6,9 +6,8 @@
 
 define <2 x double> @fpext_fpext(<2 x half> %x, half %y, i32 %index) {
 ; CHECK-LABEL: @fpext_fpext(
-; CHECK-NEXT:    [[V:%.*]] = fpext <2 x half> [[X:%.*]] to <2 x double>
-; CHECK-NEXT:    [[S:%.*]] = fpext half [[Y:%.*]] to double
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x double> [[V]], double [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x half> [[X:%.*]], half [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = fpext <2 x half> [[TMP1]] to <2 x double>
 ; CHECK-NEXT:    ret <2 x double> [[I]]
 ;
   %v = fpext <2 x half> %x to <2 x double>
@@ -19,9 +18,8 @@
 
 define <2 x i32> @sext_sext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext(
-; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
-; CHECK-NEXT:    [[S:%.*]] = sext i8 [[Y:%.*]] to i32
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i32> [[V]], i32 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = sext <2 x i8> [[TMP1]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[I]]
 ;
   %v = sext <2 x i8> %x to <2 x i32>
@@ -32,9 +30,8 @@
 
 define <2 x i12> @zext_zext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext(
-; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i12>
-; CHECK-NEXT:    [[S:%.*]] = zext i8 [[Y:%.*]] to i12
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i12> [[V]], i12 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y:%.*]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = zext <2 x i8> [[TMP1]] to <2 x i12>
 ; CHECK-NEXT:    ret <2 x i12> [[I]]
 ;
   %v = zext <2 x i8> %x to <2 x i12>
@@ -43,6 +40,8 @@
   ret <2 x i12> %i
 }
 
+; negative test - need same source type
+
 define <2 x double> @fpext_fpext_types(<2 x half> %x, float %y, i32 %index) {
 ; CHECK-LABEL: @fpext_fpext_types(
 ; CHECK-NEXT:    [[V:%.*]] = fpext <2 x half> [[X:%.*]] to <2 x double>
@@ -56,6 +55,8 @@
   ret <2 x double> %i
 }
 
+; negative test - need same source type
+
 define <2 x i32> @sext_sext_types(<2 x i16> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext_types(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i16> [[X:%.*]] to <2 x i32>
@@ -69,6 +70,8 @@
   ret <2 x i32> %i
 }
 
+; negative test - need same extend opcode
+
 define <2 x i12> @sext_zext(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_zext(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i12>
@@ -82,6 +85,8 @@
   ret <2 x i12> %i
 }
 
+; negative test - don't trade scalar extend for vector extend
+
 define <2 x i32> @sext_sext_use1(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @sext_sext_use1(
 ; CHECK-NEXT:    [[V:%.*]] = sext <2 x i8> [[X:%.*]] to <2 x i32>
@@ -99,10 +104,10 @@
 
 define <2 x i32> @zext_zext_use2(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext_use2(
-; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[S:%.*]] = zext i8 [[Y:%.*]] to i32
 ; CHECK-NEXT:    call void @use(i32 [[S]])
-; CHECK-NEXT:    [[I:%.*]] = insertelement <2 x i32> [[V]], i32 [[S]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[X:%.*]], i8 [[Y]], i32 [[INDEX:%.*]]
+; CHECK-NEXT:    [[I:%.*]] = zext <2 x i8> [[TMP1]] to <2 x i32>
 ; CHECK-NEXT:    ret <2 x i32> [[I]]
 ;
   %v = zext <2 x i8> %x to <2 x i32>
@@ -112,6 +117,8 @@
   ret <2 x i32> %i
 }
 
+; negative test - don't create an extra extend
+
 define <2 x i32> @zext_zext_use3(<2 x i8> %x, i8 %y, i32 %index) {
 ; CHECK-LABEL: @zext_zext_use3(
 ; CHECK-NEXT:    [[V:%.*]] = zext <2 x i8> [[X:%.*]] to <2 x i32>