diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -34,6 +34,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CodeMetrics.h"
 #include "llvm/Analysis/DemandedBits.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
@@ -84,8 +85,11 @@
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
 #include "llvm/Transforms/Utils/InjectTLIMappings.h"
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 #include "llvm/Transforms/Vectorize.h"
 #include <algorithm>
 #include <cassert>
@@ -659,8 +663,12 @@
     }
     MinBWs.clear();
     InstrElementSize.clear();
+    MemBounds.clear();
   }

+  /// Did we already version the current block?
+  bool Versioned = false;
+
   unsigned getTreeSize() const { return VectorizableTree.size(); }

   /// Perform LICM and CSE on the newly generated gather sequences.
@@ -1916,6 +1924,10 @@
   /// A list of scalars that we found that we need to keep as scalars.
   ValueSet MustGather;

+  /// Map of objects to the start and end pointers we need to generate
+  /// runtime checks for.
+  DenseMap<Value *, std::pair<const SCEV *, const SCEV *>> MemBounds;
+
   /// This POD struct describes one external user in the vectorized tree.
   struct ExternalUser {
     ExternalUser(Value *S, llvm::User *U, int L)
@@ -5283,6 +5295,103 @@
     scheduleBlock(BSIter.second.get());
   }

+  // If we collected pointer bounds for memory runtime checks, generate the
+  // checks.
+  // TODO: Refactor addRuntimeChecks in LoopUtils.cpp so it can be re-used
+  // here.
+  if (!MemBounds.empty()) {
+    Instruction *VL = VectorizableTree[0]->getMainOp();
+    BasicBlock *BB = VL->getParent();
+
+    DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+    std::string OriginalName = BB->getName().str();
+    auto *CheckBlock =
+        splitBlockBefore(BB, &*BB->getFirstNonPHI(), &DTU, LI, nullptr,
+                         OriginalName + ".slpmemcheck");
+    auto *MergeBlock = BB;
+    BB = splitBlockBefore(BB, BB->getTerminator(), &DTU, LI, nullptr,
+                          OriginalName + ".slpversioned");
+
+    ValueToValueMapTy VMap;
+    auto *Scalar = CloneBasicBlock(BB, VMap, "", BB->getParent());
+    Scalar->setName(OriginalName + ".scalar");
+    MergeBlock->setName(OriginalName + ".merge");
+    SmallVector<BasicBlock *, 1> Tmp;
+    Tmp.push_back(Scalar);
+    remapInstructionsInBlocks(Tmp, VMap);
+
+    Value *MemoryRuntimeCheck = nullptr;
+    Instruction *FirstInst = nullptr;
+    SCEVExpander Exp(*SE, BB->getParent()->getParent()->getDataLayout(),
+                     "memcheck");
+    SmallVector<std::pair<Value *, Value *>, 4> ExpandedBounds;
+    Type *PtrArithTy =
+        Type::getInt8PtrTy(BB->getParent()->getParent()->getContext(), 0);
+    for (auto &KV : MemBounds) {
+      ExpandedBounds.emplace_back(
+          Exp.expandCodeFor(KV.second.first, PtrArithTy,
+                            CheckBlock->getTerminator()),
+          Exp.expandCodeFor(KV.second.second, PtrArithTy,
+                            CheckBlock->getTerminator()));
+    }
+    auto GetFirstInst = [](Instruction *FirstInst, Value *V,
+                           Instruction *Loc) -> Instruction * {
+      if (FirstInst)
+        return FirstInst;
+      if (Instruction *I = dyn_cast<Instruction>(V))
+        return I->getParent() == Loc->getParent() ? I : nullptr;
+      return nullptr;
+    };
+
+    Instruction *Loc = CheckBlock->getTerminator();
+    LLVMContext &Ctx = VL->getContext();
+    IRBuilder<> ChkBuilder(CheckBlock->getTerminator());
+    for (unsigned i = 0; i < MemBounds.size(); ++i) {
+      for (unsigned j = i + 1; j < MemBounds.size(); ++j) {
+        Value *ALow = ExpandedBounds[i].first;
+        Value *AHigh = ExpandedBounds[i].second;
+        Value *BLow = ExpandedBounds[j].first;
+        Value *BHigh = ExpandedBounds[j].second;
+
+        unsigned AS0 = ALow->getType()->getPointerAddressSpace();
+        unsigned AS1 = BLow->getType()->getPointerAddressSpace();
+
+        Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+        Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+        Value *Start0 = ChkBuilder.CreateBitCast(ALow, PtrArithTy0, "bc");
+        Value *Start1 = ChkBuilder.CreateBitCast(BLow, PtrArithTy1, "bc");
+        Value *End0 = ChkBuilder.CreateBitCast(AHigh, PtrArithTy1, "bc");
+        Value *End1 = ChkBuilder.CreateBitCast(BHigh, PtrArithTy0, "bc");
+        // [A|B].Start points to the first accessed byte under base [A|B].
+        // [A|B].End points to the last accessed byte, plus one.
+        // There is no conflict when the intervals are disjoint:
+        // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+        //
+        // bound0 = (B.Start < A.End)
+        // bound1 = (A.Start < B.End)
+        // IsConflict = bound0 & bound1
+        Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+        FirstInst = GetFirstInst(FirstInst, Cmp0, Loc);
+        Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+        FirstInst = GetFirstInst(FirstInst, Cmp1, Loc);
+        Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+        FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+        if (MemoryRuntimeCheck) {
+          IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict,
+                                           "conflict.rdx");
+          FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+        }
+        MemoryRuntimeCheck = IsConflict;
+      }
+    }
+
+    ChkBuilder.CreateCondBr(MemoryRuntimeCheck, BB, Scalar);
+    CheckBlock->getTerminator()->eraseFromParent();
+    DTU.applyUpdates({{DT->Insert, CheckBlock, Scalar}});
+
+    Versioned = true;
+
+    MemBounds.clear();
+  }
+
   Builder.SetInsertPoint(&F->getEntryBlock().front());

   auto *VectorRoot = vectorizeTree(VectorizableTree[0].get());
@@ -5842,6 +5951,7 @@
     while (DepDest) {
       assert(isInSchedulingRegion(DepDest));
+      ScheduleData *DestBundle = DepDest->FirstInBundle;

       // We have two limits to reduce the complexity:
       // 1) AliasedCheckLimit: It's a small limit to reduce calls to
       //    SLP->isAliased (which is the expensive part in this loop).
@@ -5859,14 +5969,74 @@
       //    balance between reduced runtime and accurate dependencies.
       numAliased++;

-      DepDest->MemoryDependencies.push_back(BundleMember);
-      BundleMember->Dependencies++;
-      ScheduleData *DestBundle = DepDest->FirstInBundle;
-      if (!DestBundle->IsScheduled) {
-        BundleMember->incrementUnscheduledDeps(1);
+      bool CanVersion = false;
+      // If this bundle is not scheduled and no versioned code has been
+      // generated yet, try to collect the bounds of the accesses to
+      // generate runtime checks.
+      if (!DestBundle->IsScheduled && !SLP->Versioned) {
+        // FIXME: Naming.
+        auto GetPtr = [](Instruction *I) -> Value * {
+          if (auto *L = dyn_cast<LoadInst>(I))
+            return L->getPointerOperand();
+          if (auto *S = dyn_cast<StoreInst>(I))
+            return S->getPointerOperand();
+          return nullptr;
+        };
+        auto *Src = GetPtr(SrcInst);
+        auto *Dst = GetPtr(DepDest->Inst);
+
+        auto NoOrSingleSucc = [](BasicBlock *BB) {
+          return succ_begin(BB) == succ_end(BB) ||
+                 std::next(succ_begin(BB)) == succ_end(BB);
+        };
+        auto NoOrSinglePred = [](BasicBlock *BB) {
+          return pred_begin(BB) == pred_end(BB) ||
+                 std::next(pred_begin(BB)) == pred_end(BB);
+        };
+
+        if (NoOrSingleSucc(SrcInst->getParent()) &&
+            NoOrSinglePred(SrcInst->getParent()) &&
+            SrcInst->getParent() == DepDest->Inst->getParent() && Src &&
+            Dst) {
+          Value *SrcObj = getUnderlyingObject(Src);
+          Value *DstObj = getUnderlyingObject(Dst);
+          if (SrcObj && SrcObj != DstObj) {
+            const SCEV *SrcOffset = SLP->SE->getSCEV(Src);
+            const SCEV *DstOffset = SLP->SE->getSCEV(Dst);
+
+            auto SrcBound =
+                SLP->MemBounds.insert({SrcObj, {SrcOffset, SrcOffset}});
+            if (SLP->SE->isKnownPredicate(CmpInst::ICMP_ULT, SrcOffset,
+                                          SrcBound.first->second.first))
+              SrcBound.first->second.first = SLP->SE->getSCEV(Src);
+            if (SLP->SE->isKnownPredicate(CmpInst::ICMP_UGT, SrcOffset,
+                                          SrcBound.first->second.second))
+              SrcBound.first->second.second = SLP->SE->getSCEV(Src);
+
+            auto DstBound =
+                SLP->MemBounds.insert({DstObj, {DstOffset, DstOffset}});
+            if (SLP->SE->isKnownPredicate(CmpInst::ICMP_ULT, DstOffset,
+                                          DstBound.first->second.first))
+              DstBound.first->second.first = SLP->SE->getSCEV(Dst);
+            if (SLP->SE->isKnownPredicate(CmpInst::ICMP_UGT, DstOffset,
+                                          DstBound.first->second.second))
+              DstBound.first->second.second = SLP->SE->getSCEV(Dst);
+
+            CanVersion = true;
+          }
+        }
       }
-      if (!DestBundle->hasValidDependencies()) {
-        WorkList.push_back(DestBundle);
+      if (!CanVersion) {
+        DepDest->MemoryDependencies.push_back(BundleMember);
+        BundleMember->Dependencies++;
+        if (!DestBundle->IsScheduled) {
+          BundleMember->incrementUnscheduledDeps(1);
+        }
+        if (!DestBundle->hasValidDependencies()) {
+          WorkList.push_back(DestBundle);
+        }
       }
     }
     DepDest = DepDest->NextLoadStore;
@@ -6338,7 +6508,7 @@
     return PreservedAnalyses::all();

   PreservedAnalyses PA;
-  PA.preserveSet<CFGAnalyses>();
+  // PA.preserveSet<CFGAnalyses>();
   return PA;
 }
diff --git a/llvm/test/Transforms/SLPVectorizer/AArch64/memory-runtime-checks.ll b/llvm/test/Transforms/SLPVectorizer/AArch64/memory-runtime-checks.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/AArch64/memory-runtime-checks.ll
@@ -0,0 +1,71 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -slp-vectorizer -mtriple=arm64-apple-darwin -S %s | FileCheck %s
+
+define void @needs_versioning(i32* %dst, i32* %src) {
+; CHECK-LABEL: @needs_versioning(
+; CHECK-NEXT:  entry.slpmemcheck:
+; CHECK-NEXT:    [[DST8:%.*]] = bitcast i32* [[DST:%.*]] to i8*
+; CHECK-NEXT:    [[SCEVGEP:%.*]] = getelementptr i32, i32* [[SRC:%.*]], i64 1
+; CHECK-NEXT:    [[SCEVGEP9:%.*]] = bitcast i32* [[SCEVGEP]] to i8*
+; CHECK-NEXT:    [[BOUND0:%.*]] = icmp ult i8* [[DST8]], [[SCEVGEP9]]
+; CHECK-NEXT:    [[BOUND1:%.*]] = icmp ult i8* [[SCEVGEP9]], [[DST8]]
+; CHECK-NEXT:    [[FOUND_CONFLICT:%.*]] = and i1 [[BOUND0]], [[BOUND1]]
+; CHECK-NEXT:    br i1 [[FOUND_CONFLICT]], label [[ENTRY_SLPVERSIONED:%.*]], label [[ENTRY_SCALAR:%.*]]
+; CHECK:       entry.slpversioned:
+; CHECK-NEXT:    [[SRC_GEP_1:%.*]] = getelementptr inbounds i32, i32* [[SRC]], i64 1
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[SRC]] to <2 x i32>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i32>, <2 x i32>* [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <2 x i32> [[TMP1]], <i32 16, i32 16>
+; CHECK-NEXT:    [[DST_GEP_1:%.*]] = getelementptr inbounds i32, i32* [[DST]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast i32* [[DST]] to <2 x i32>*
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], <2 x i32>* [[TMP3]], align 4
+; CHECK-NEXT:    br label [[ENTRY_MERGE:%.*]]
+; CHECK:       entry.merge:
+; CHECK-NEXT:    ret void
+; CHECK:       entry.scalar:
+; CHECK-NEXT:    [[SRC_GEP_12:%.*]] = getelementptr inbounds i32, i32* [[SRC]], i64 1
+; CHECK-NEXT:    [[SRC_13:%.*]] = load i32, i32* [[SRC_GEP_12]], align 4
+; CHECK-NEXT:    [[SRC_04:%.*]] = load i32, i32* [[SRC]], align 4
+; CHECK-NEXT:    [[R_15:%.*]] = ashr i32 [[SRC_13]], 16
+; CHECK-NEXT:    [[R_06:%.*]] = ashr i32 [[SRC_04]], 16
+; CHECK-NEXT:    [[DST_GEP_17:%.*]] = getelementptr inbounds i32, i32* [[DST]], i64 1
+; CHECK-NEXT:    store i32 [[R_15]], i32* [[DST_GEP_17]], align 4
+; CHECK-NEXT:    store i32 [[R_06]], i32* [[DST]], align 4
+; CHECK-NEXT:    br label [[ENTRY_MERGE]]
+;
+entry:
+  %src.0 = load i32, i32* %src, align 4
+  %r.0 = ashr i32 %src.0, 16
+  store i32 %r.0, i32* %dst, align 4
+  %src.gep.1 = getelementptr inbounds i32, i32* %src, i64 1
+  %src.1 = load i32, i32* %src.gep.1, align 4
+  %r.1 = ashr i32 %src.1, 16
+  %dst.gep.1 = getelementptr inbounds i32, i32* %dst, i64 1
+  store i32 %r.1, i32* %dst.gep.1, align 4
+  ret void
+}
+
+define void @no_version(i32* nocapture %dst, i32* nocapture readonly %src) {
+; CHECK-LABEL: @no_version(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SRC_GEP_1:%.*]] = getelementptr inbounds i32, i32* [[SRC:%.*]], i64 1
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[SRC]] to <2 x i32>*
+; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i32>, <2 x i32>* [[TMP0]], align 4
+; CHECK-NEXT:    [[TMP2:%.*]] = ashr <2 x i32> [[TMP1]], <i32 16, i32 16>
+; CHECK-NEXT:    [[DST_GEP_1:%.*]] = getelementptr inbounds i32, i32* [[DST:%.*]], i64 1
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast i32* [[DST]] to <2 x i32>*
+; CHECK-NEXT:    store <2 x i32> [[TMP2]], <2 x i32>* [[TMP3]], align 4
+; CHECK-NEXT:    ret void
+;
+entry:
+  %src.0 = load i32, i32* %src, align 4
+  %src.gep.1 = getelementptr inbounds i32, i32* %src, i64 1
+  %src.1 = load i32, i32* %src.gep.1, align 4
+  %r.0 = ashr i32 %src.0, 16
+  %r.1 = ashr i32 %src.1, 16
+  %dst.gep.1 = getelementptr inbounds i32, i32* %dst, i64 1
+  store i32 %r.0, i32* %dst, align 4
+  store i32 %r.1, i32* %dst.gep.1, align 4
+  ret void
+}
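Note on the emitted check: the found.conflict value generated above is the usual interval-overlap test over the expanded [Start, End) byte ranges, the same shape LoopUtils' addRuntimeChecks produces. A minimal standalone C++ sketch of that predicate over plain integer addresses (illustration only; rangesConflict is a hypothetical helper, not part of this patch):

    #include <cstdint>

    // Byte ranges [AStart, AEnd) and [BStart, BEnd) conflict iff they overlap.
    // The patch emits the same logic as icmp ult + and over i8* bounds produced
    // by SCEVExpander; this integer version only illustrates the predicate.
    static bool rangesConflict(uint64_t AStart, uint64_t AEnd, uint64_t BStart,
                               uint64_t BEnd) {
      bool Bound0 = BStart < AEnd; // B starts before A ends.
      bool Bound1 = AStart < BEnd; // A starts before B ends.
      return Bound0 && Bound1;     // Disjoint iff at least one bound fails.
    }

For example, rangesConflict(0, 8, 8, 16) is false (back-to-back ranges are disjoint), while rangesConflict(0, 8, 4, 12) is true.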
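Similarly, the MemBounds update in calculateDependencies keeps, per underlying object, the smallest and largest SCEV offset seen so far, tightening a bound only when ScalarEvolution can prove the ULT/UGT relation. A rough stand-in using plain integers (BoundsMap and updateBounds are hypothetical names; it assumes addresses are directly comparable, which the real code cannot always prove):

    #include <algorithm>
    #include <cstdint>
    #include <map>

    // Per-object [lowest, highest] accessed address, analogous to the
    // {start SCEV, end SCEV} pair stored in MemBounds for each object.
    using BoundsMap = std::map<const void *, std::pair<uint64_t, uint64_t>>;

    static void updateBounds(BoundsMap &Bounds, const void *Obj, uint64_t Addr) {
      // The first access to Obj seeds both bounds with Addr; later accesses
      // only widen the interval.
      auto It = Bounds.insert({Obj, {Addr, Addr}}).first;
      It->second.first = std::min(It->second.first, Addr);
      It->second.second = std::max(It->second.second, Addr);
    }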