diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -8640,16 +8640,101 @@
 bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) {
   bool Changed = false;
 
+  // Sort by pointer operand type, then by value operand. Value operands must
+  // be compatible (same opcode, same parent block), otherwise it is
+  // definitely not profitable to try to vectorize them.
+  auto &&StoreSorter = [this](StoreInst *V, StoreInst *V2) {
+    if (V->getPointerOperandType()->getTypeID() <
+        V2->getPointerOperandType()->getTypeID())
+      return true;
+    if (V->getPointerOperandType()->getTypeID() >
+        V2->getPointerOperandType()->getTypeID())
+      return false;
+    // UndefValues are compatible with all other values.
+    if (isa<UndefValue>(V->getValueOperand()) ||
+        isa<UndefValue>(V2->getValueOperand()))
+      return false;
+    if (auto *I1 = dyn_cast<Instruction>(V->getValueOperand()))
+      if (auto *I2 = dyn_cast<Instruction>(V2->getValueOperand())) {
+        auto *NodeI1 = DT->getNode(I1->getParent());
+        auto *NodeI2 = DT->getNode(I2->getParent());
+        assert(NodeI1 && "Should only process reachable instructions");
+        assert(NodeI2 && "Should only process reachable instructions");
+        assert((NodeI1 == NodeI2) ==
+                   (NodeI1->getDFSNumIn() == NodeI2->getDFSNumIn()) &&
+               "Different nodes should have different DFS numbers");
+        if (NodeI1 != NodeI2)
+          return NodeI1->getDFSNumIn() < NodeI2->getDFSNumIn();
+        InstructionsState S = getSameOpcode({I1, I2});
+        if (S.getOpcode())
+          return false;
+        return I1->getOpcode() < I2->getOpcode();
+      }
+    if (isa<Constant>(V->getValueOperand()) &&
+        isa<Constant>(V2->getValueOperand()))
+      return false;
+    return V->getValueOperand()->getValueID() <
+           V2->getValueOperand()->getValueID();
+  };
+
+  auto &&AreCompatibleStores = [](StoreInst *V1, StoreInst *V2) {
+    if (V1 == V2)
+      return true;
+    if (V1->getPointerOperandType() != V2->getPointerOperandType())
+      return false;
+    // Undefs are compatible with any other value.
+    if (isa<UndefValue>(V1->getValueOperand()) ||
+        isa<UndefValue>(V2->getValueOperand()))
+      return true;
+    if (auto *I1 = dyn_cast<Instruction>(V1->getValueOperand()))
+      if (auto *I2 = dyn_cast<Instruction>(V2->getValueOperand())) {
+        if (I1->getParent() != I2->getParent())
+          return false;
+        InstructionsState S = getSameOpcode({I1, I2});
+        return S.getOpcode() > 0;
+      }
+    if (isa<Constant>(V1->getValueOperand()) &&
+        isa<Constant>(V2->getValueOperand()))
+      return true;
+    return V1->getValueOperand()->getValueID() ==
+           V2->getValueOperand()->getValueID();
+  };
+
   // Attempt to sort and vectorize each of the store-groups.
-  for (StoreListMap::iterator it = Stores.begin(), e = Stores.end(); it != e;
-       ++it) {
-    if (it->second.size() < 2)
+  for (auto &Pair : Stores) {
+    if (Pair.second.size() < 2)
       continue;
 
     LLVM_DEBUG(dbgs() << "SLP: Analyzing a store chain of length "
-                      << it->second.size() << ".\n");
+                      << Pair.second.size() << ".\n");
+
+    stable_sort(Pair.second, StoreSorter);
+
+    // Try to vectorize elements based on their compatibility.
+    for (ArrayRef<StoreInst *>::iterator IncIt = Pair.second.begin(),
+                                         E = Pair.second.end();
+         IncIt != E;) {
+
+      // Find the run of stores that are compatible with *IncIt.
+      ArrayRef<StoreInst *>::iterator SameTypeIt = IncIt;
+      Type *EltTy = (*IncIt)->getPointerOperand()->getType();
 
-    Changed |= vectorizeStores(it->second, R);
+      while (SameTypeIt != E && AreCompatibleStores(*SameTypeIt, *IncIt))
+        ++SameTypeIt;
+
+      // Try to vectorize them.
+      unsigned NumElts = (SameTypeIt - IncIt);
+      LLVM_DEBUG(dbgs() << "SLP: Trying to vectorize starting at stores ("
+                        << NumElts << ")\n");
+      if (NumElts > 1 && !EltTy->getPointerElementType()->isVectorTy() &&
+          vectorizeStores(makeArrayRef(IncIt, NumElts), R)) {
+        // Record success; the remaining runs of this group are still tried.
+        Changed = true;
+      }
+
+      // Continue with the first store not compatible with this run (or end).
+      IncIt = SameTypeIt;
+    }
   }
   return Changed;
 }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/stores-non-ordered.ll b/llvm/test/Transforms/SLPVectorizer/X86/stores-non-ordered.ll
--- a/llvm/test/Transforms/SLPVectorizer/X86/stores-non-ordered.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/stores-non-ordered.ll
@@ -4,24 +4,23 @@
 define i32 @non-ordered-stores(i32* noalias nocapture %in, i32* noalias nocapture %inn, i32* noalias nocapture %out) {
 ; CHECK-LABEL: @non-ordered-stores(
 ; CHECK-NEXT:    [[IN_ADDR:%.*]] = getelementptr inbounds i32, i32* [[IN:%.*]], i64 0
-; CHECK-NEXT:    [[LOAD_1:%.*]] = load i32, i32* [[IN_ADDR]], align 4
 ; CHECK-NEXT:    [[GEP_2:%.*]] = getelementptr inbounds i32, i32* [[IN_ADDR]], i64 2
-; CHECK-NEXT:    [[LOAD_3:%.*]] = load i32, i32* [[GEP_2]], align 4
 ; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32*> poison, i32* [[IN_ADDR]], i32 0
-; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32*> [[TMP1]], i32* [[IN_ADDR]], i32 1
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i32, <2 x i32*> [[TMP2]], <2 x i64> <i64 1, i64 3>
-; CHECK-NEXT:    [[TMP4:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP3]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <2 x i32*> [[TMP1]], i32* [[GEP_2]], i32 1
+; CHECK-NEXT:    [[TMP3:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP2]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
+; CHECK-NEXT:    [[TMP4:%.*]] = insertelement <2 x i32*> [[TMP1]], i32* [[IN_ADDR]], i32 1
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i32, <2 x i32*> [[TMP4]], <2 x i64> <i64 1, i64 3>
+; CHECK-NEXT:    [[TMP6:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP5]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
 ; CHECK-NEXT:    [[INN_ADDR:%.*]] = getelementptr inbounds i32, i32* [[INN:%.*]], i64 0
-; CHECK-NEXT:    [[LOAD_5:%.*]] = load i32, i32* [[INN_ADDR]], align 4
 ; CHECK-NEXT:    [[GEP_5:%.*]] = getelementptr inbounds i32, i32* [[INN_ADDR]], i64 2
-; CHECK-NEXT:    [[LOAD_7:%.*]] = load i32, i32* [[GEP_5]], align 4
-; CHECK-NEXT:    [[TMP5:%.*]] = insertelement <2 x i32*> poison, i32* [[INN_ADDR]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x i32*> [[TMP5]], i32* [[INN_ADDR]], i32 1
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, <2 x i32*> [[TMP6]], <2 x i64> <i64 1, i64 3>
-; CHECK-NEXT:    [[TMP8:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP7]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
-; CHECK-NEXT:    [[MUL_1:%.*]] = mul i32 [[LOAD_1]], [[LOAD_5]]
-; CHECK-NEXT:    [[MUL_3:%.*]] = mul i32 [[LOAD_3]], [[LOAD_7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = mul <2 x i32> [[TMP4]], [[TMP8]]
+; CHECK-NEXT:    [[TMP7:%.*]] = insertelement <2 x i32*> poison, i32* [[INN_ADDR]], i32 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertelement <2 x i32*> [[TMP7]], i32* [[GEP_5]], i32 1
+; CHECK-NEXT:    [[TMP9:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP8]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
+; CHECK-NEXT:    [[TMP10:%.*]] = insertelement <2 x i32*> [[TMP7]], i32* [[INN_ADDR]], i32 1
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i32, <2 x i32*> [[TMP10]], <2 x i64> <i64 1, i64 3>
+; CHECK-NEXT:    [[TMP12:%.*]] = call <2 x i32> @llvm.masked.gather.v2i32.v2p0i32(<2 x i32*> [[TMP11]], i32 4, <2 x i1> <i1 true, i1 true>, <2 x i32> undef)
+; CHECK-NEXT:    [[TMP13:%.*]] = mul <2 x i32> [[TMP3]], [[TMP9]]
+; CHECK-NEXT:    [[TMP14:%.*]] = mul <2 x i32> [[TMP6]], [[TMP12]]
 ; CHECK-NEXT:    br label [[BLOCK1:%.*]]
 ; CHECK:       block1:
 ; CHECK-NEXT:    [[GEP_X:%.*]] = getelementptr inbounds i32, i32* [[INN_ADDR]], i64 5
@@ -33,11 +32,11 @@
 ; CHECK-NEXT:    [[GEP_9:%.*]] = getelementptr inbounds i32, i32* [[OUT]], i64 2
 ; CHECK-NEXT:    [[GEP_10:%.*]] = getelementptr inbounds i32, i32* [[OUT]], i64 3
 ; CHECK-NEXT:    [[GEP_11:%.*]] = getelementptr inbounds i32, i32* [[OUT]], i64 4
-; CHECK-NEXT:    store i32 [[MUL_1]], i32* [[GEP_10]], align 4
 ; CHECK-NEXT:    store i32 [[LOAD_9]], i32* [[GEP_9]], align 4
-; CHECK-NEXT:    store i32 [[MUL_3]], i32* [[GEP_11]], align 4
-; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i32* [[GEP_7]] to <2 x i32>*
-; CHECK-NEXT:    store <2 x i32> [[TMP9]], <2 x i32>* [[TMP10]], align 4
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast i32* [[GEP_10]] to <2 x i32>*
+; CHECK-NEXT:    store <2 x i32> [[TMP13]], <2 x i32>* [[TMP15]], align 4
+; CHECK-NEXT:    [[TMP16:%.*]] = bitcast i32* [[GEP_7]] to <2 x i32>*
+; CHECK-NEXT:    store <2 x i32> [[TMP14]], <2 x i32>* [[TMP16]], align 4
 ; CHECK-NEXT:    ret i32 undef
 ;
   %in.addr = getelementptr inbounds i32, i32* %in, i64 0