diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -60,6 +60,7 @@ #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" @@ -1683,6 +1684,54 @@ } } + // Try to reassociate to sink a splat shuffle after a binary operation. + if (Inst.isAssociative() && Inst.isCommutative()) { + // Canonicalize shuffle operand as LHS. + if (auto *ShufR = dyn_cast(RHS)) + std::swap(LHS, RHS); + + Value *X; + Constant *MaskC; + const APInt *SplatIndex; + BinaryOperator *BO; + if (!match(LHS, m_OneUse(m_ShuffleVector(m_Value(X), m_Undef(), + m_Constant(MaskC)))) || + !match(MaskC, m_APIntAllowUndef(SplatIndex)) || + X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) || + BO->getOpcode() != Opcode) + return nullptr; + + Value *Y, *OtherOp; + if (isSplatValue(BO->getOperand(0), SplatIndex->getZExtValue())) { + Y = BO->getOperand(0); + OtherOp = BO->getOperand(1); + } else if (isSplatValue(BO->getOperand(1), SplatIndex->getZExtValue())) { + Y = BO->getOperand(1); + OtherOp = BO->getOperand(0); + } else { + return nullptr; + } + + // X and Y are splatted values, so perform the binary operation on those + // values followed by a splat followed by the 2nd binary operation: + // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp + Value *NewBO = Builder.CreateBinOp(Opcode, X, Y); + UndefValue *Undef = UndefValue::get(Inst.getType()); + Constant *NewMask = ConstantInt::get(MaskC->getType(), *SplatIndex); + Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask); + Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp); + + // Intersect FMF on both new binops. Other (poison-generating) flags are + // dropped to be safe. + if (isa(R)) { + R->copyFastMathFlags(&Inst); + R->andIRFlags(BO); + } + if (auto *NewInstBO = dyn_cast(NewBO)) + NewInstBO->copyIRFlags(R); + return R; + } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle.ll b/llvm/test/Transforms/InstCombine/vec_shuffle.ll --- a/llvm/test/Transforms/InstCombine/vec_shuffle.ll +++ b/llvm/test/Transforms/InstCombine/vec_shuffle.ll @@ -1457,9 +1457,9 @@ define <4 x i32> @splat_assoc_add(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @splat_assoc_add( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <4 x i32> [[X:%.*]], <4 x i32> undef, <4 x i32> zeroinitializer -; CHECK-NEXT: [[A:%.*]] = add <4 x i32> [[Y:%.*]], -; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[SPLATX]], [[A]] +; CHECK-NEXT: [[TMP1:%.*]] = add <4 x i32> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer +; CHECK-NEXT: [[R:%.*]] = add <4 x i32> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <4 x i32> [[R]] ; %splatx = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> zeroinitializer @@ -1468,11 +1468,13 @@ ret <4 x i32> %r } +; Non-zero splat index; commute operands; FMF intersect + define <2 x float> @splat_assoc_fmul(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fmul( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <2 x float> [[X:%.*]], <2 x float> undef, <2 x i32> -; CHECK-NEXT: [[A:%.*]] = fmul reassoc nsz <2 x float> [[Y:%.*]], -; CHECK-NEXT: [[R:%.*]] = fmul reassoc nnan nsz <2 x float> [[A]], [[SPLATX]] +; CHECK-NEXT: [[TMP1:%.*]] = fmul reassoc nsz <2 x float> [[X:%.*]], +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x float> [[TMP1]], <2 x float> undef, <2 x i32> +; CHECK-NEXT: [[R:%.*]] = fmul reassoc nsz <2 x float> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <2 x float> [[R]] ; %splatx = shufflevector <2 x float> %x, <2 x float> undef, <2 x i32> @@ -1481,12 +1483,13 @@ ret <2 x float> %r } +; Two splat shuffles; drop poison-generating flags + define <3 x i8> @splat_assoc_mul(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_mul( -; CHECK-NEXT: [[SPLATX:%.*]] = shufflevector <3 x i8> [[X:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[SPLATZ:%.*]] = shufflevector <3 x i8> [[Z:%.*]], <3 x i8> undef, <3 x i32> -; CHECK-NEXT: [[A:%.*]] = mul nsw <3 x i8> [[SPLATZ]], [[Y:%.*]] -; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[A]], [[SPLATX]] +; CHECK-NEXT: [[TMP1:%.*]] = mul <3 x i8> [[X:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <3 x i8> [[TMP1]], <3 x i8> undef, <3 x i32> +; CHECK-NEXT: [[R:%.*]] = mul <3 x i8> [[TMP2]], [[Y:%.*]] ; CHECK-NEXT: ret <3 x i8> [[R]] ; %splatx = shufflevector <3 x i8> %x, <3 x i8> undef, <3 x i32> @@ -1496,7 +1499,7 @@ ret <3 x i8> %r } -; Mismatched splat elements +; Negative test - mismatched splat elements define <3 x i8> @splat_assoc_or(<3 x i8> %x, <3 x i8> %y, <3 x i8> %z) { ; CHECK-LABEL: @splat_assoc_or( @@ -1513,7 +1516,7 @@ ret <3 x i8> %r } -; Not associative +; Negative test - not associative define <2 x float> @splat_assoc_fdiv(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fdiv( @@ -1528,7 +1531,7 @@ ret <2 x float> %r } -; Extra use +; Negative test - extra use define <2 x float> @splat_assoc_fadd(<2 x float> %x, <2 x float> %y) { ; CHECK-LABEL: @splat_assoc_fadd( @@ -1545,7 +1548,7 @@ ret <2 x float> %r } -; Narrowing splat +; Negative test - narrowing splat define <3 x i32> @splat_assoc_and(<4 x i32> %x, <3 x i32> %y) { ; CHECK-LABEL: @splat_assoc_and( @@ -1560,7 +1563,7 @@ ret <3 x i32> %r } -; Widening splat +; Negative test - widening splat define <5 x i32> @splat_assoc_xor(<4 x i32> %x, <5 x i32> %y) { ; CHECK-LABEL: @splat_assoc_xor( @@ -1575,7 +1578,7 @@ ret <5 x i32> %r } -; Opcode mismatch +; Negative test - opcode mismatch define <4 x i32> @splat_assoc_add_mul(<4 x i32> %x, <4 x i32> %y) { ; CHECK-LABEL: @splat_assoc_add_mul( diff --git a/llvm/test/Transforms/LoopVectorize/induction.ll b/llvm/test/Transforms/LoopVectorize/induction.ll --- a/llvm/test/Transforms/LoopVectorize/induction.ll +++ b/llvm/test/Transforms/LoopVectorize/induction.ll @@ -427,7 +427,7 @@ ; UNROLL: %[[i1:.+]] = or i64 %index, 1 ; UNROLL: %[[i2:.+]] = or i64 %index, 2 ; UNROLL: %[[i3:.+]] = or i64 %index, 3 -; UNROLL: %step.add3 = add <2 x i32> %vec.ind2, +; UNROLL: %[[add:.+]]= add <2 x i32> %[[splat:.+]], ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %index, i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i1]], i32 1 ; UNROLL: getelementptr inbounds %pair.i16, %pair.i16* %p, i64 %[[i2]], i32 1