MemCpyOpt: look further up the block for memcpy-to-memcpy forwarding.

processMemCpyMemCpyDependence() now sorts the (MDep, M) pair into three cases.
Only case 2 is implemented: M and the in-range instructions it uses are hoisted
in front of MDep. Cases 1 and 3 log a TODO and return false.
processMemCpy() no longer relies on a clobber result; it walks the dependency
chain of M's source within the block until it finds a memcpy writing to it.

Index: lib/Transforms/Scalar/MemCpyOptimizer.cpp
===================================================================
--- lib/Transforms/Scalar/MemCpyOptimizer.cpp
+++ lib/Transforms/Scalar/MemCpyOptimizer.cpp
@@ -992,8 +992,92 @@
   MemDepResult SourceDep =
       MD->getPointerDependencyFrom(MemoryLocation::getForSource(MDep), false,
                                    M->getIterator(), M->getParent());
-  if (!SourceDep.isClobber() || SourceDep.getInst() != MDep)
+  MemDepResult MSourceDep = MD->getPointerDependencyFrom(
+      MemoryLocation::getForSource(M), false, M->getIterator(), M->getParent());
+  MemDepResult DestDep =
+      MD->getPointerDependencyFrom(MemoryLocation::getForDest(M), false,
+                                   M->getIterator(), M->getParent(), M);
+  DominatorTree &DT = LookupDomTree();
+
+  // Three cases:
+  // Case 1:
+  //   memcpy(b <- a); ...; *b = 42; ...; memcpy(a <- b);
+  // => if a is never mod/refed in between the two memcpys
+  //   ...; *a = 42; ...; memcpy(b <- a);
+  if (M->getDest() == MDep->getSource() && DestDep.getInst() == MDep) {
+    // TODO: figure out how to replace uses within a basic block range
+    DEBUG(dbgs() << "TODO: for case 1, figure out how to replace uses within "
+                    "bb range\n");
     return false;
+  }
+
+  // Case 2:
+  //   memcpy(b <- a); ...; memcpy(c <- b);
+  // => if "..." doesn't mod/ref either c or b
+  //   memcpy(c <- a); memcpy(b <- a); *a = 42;
+  else if (MSourceDep.getInst() == MDep &&
+           (!DestDep.getInst() || DestDep.getInst() == MDep ||
+            DT.dominates(DestDep.getInst(), MDep))) {
+    DEBUG(dbgs() << "case 2: " << *MDep << "\n");
+    // hoist our memcpy (and the in-range values it uses) to just before MDep
+    DenseSet<Instruction *> inrange, visited;
+    for (Instruction &i : make_range(MDep->getIterator(), M->getIterator())) {
+      inrange.insert(&i);
+    }
+    // identify dependencies of the memcpy that also need to be moved upwards.
+    SmallVector<Instruction *, 8> tomove, stack{M};
+    while (!stack.empty()) {
+      SmallVector<Instruction *, 8> next;
+      Instruction *cur = stack.back();
+      for (Use &op : cur->operands()) {
+        if (Instruction *i = dyn_cast<Instruction>(op.get())) {
+          if (inrange.find(i) != inrange.end() &&
+              visited.find(i) == visited.end()) {
+            next.push_back(i);
+          }
+        }
+      }
+      if (next.empty()) {
+        // no pending operands; record cur once even if it was pushed twice
+        if (visited.insert(cur).second)
+          tomove.push_back(cur);
+        stack.pop_back();
+      } else {
+        stack.append(next.begin(), next.end());
+      }
+    }
+
+    for (auto i : tomove) {
+      i->moveBefore(MDep);
+      // refresh MemDep cache
+      MD->removeInstruction(i);
+    }
+  }
+
+  // TODO: Case 3:
+  //   memcpy(b <- a); ...; memcpy(c <- b)
+  // => if "..." doesn't mod/ref b or a
+  //   ...; memcpy(b <- a); memcpy(c <- b)
+  else if (MSourceDep.getInst() == MDep &&
+           (!SourceDep.getInst() || SourceDep.getInst() == MDep ||
+            DT.dominates(SourceDep.getInst(), MDep))) {
+    DEBUG(dbgs() << "TODO: case 3.\n");
+    return false;
+  } else {
+    // none of the cases match; ignore.
+    DEBUG(dbgs() << "No matching case. Ignoring. " << *M << "\n"
+                 << *MDep << "\n");
+    return false;
+  }
+
+  // Bail early if `memcpy(a <- b); memcpy(b <- a)`
+  if (AA.isMustAlias(MemoryLocation::getForDest(M),
+                     MemoryLocation::getForSource(MDep))) {
+    MD->removeInstruction(M);
+    M->eraseFromParent();
+    ++NumMemCpyInstr;
+    return true;
+  }
 
   // If the dest of the second might alias the source of the first, then the
   // source and dest might overlap.  We still want to eliminate the intermediate
@@ -1193,10 +1277,7 @@
   MemDepResult SrcDepInfo = MD->getPointerDependencyFrom(
       SrcLoc, true, M->getIterator(), M->getParent());
 
-  if (SrcDepInfo.isClobber()) {
-    if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst()))
-      return processMemCpyMemCpyDependence(M, MDep);
-  } else if (SrcDepInfo.isDef()) {
+  if (SrcDepInfo.isDef()) {
     Instruction *I = SrcDepInfo.getInst();
     bool hasUndefContents = false;
 
@@ -1226,6 +1307,19 @@
     return true;
   }
 
+  // search upwards within bb for possible memcpy-memcpy dep
+  for (MemDepResult d = MD->getPointerDependencyFrom(
+           SrcLoc, false, M->getIterator(), M->getParent());
+       !d.isNonLocal() && d.getInst();
+       d = MD->getPointerDependencyFrom(
+           SrcLoc, false, d.getInst()->getIterator(), M->getParent(), M)) {
+    if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(d.getInst())) {
+      if (MDep->getDest() == M->getSource()) {
+        return processMemCpyMemCpyDependence(M, MDep);
+      }
+    }
+  }
+
   return false;
 }
 