Index: llvm/trunk/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
===================================================================
--- llvm/trunk/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ llvm/trunk/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -294,12 +294,17 @@
 
 static Value *findBaseDefiningValue(Value *I);
 
-/// If we can trivially determine that the index specified in the given vector
-/// is a base pointer, return it. In cases where the entire vector is known to
-/// consist of base pointers, the entire vector will be returned. This
-/// indicates that the relevant extractelement is a valid base pointer and
-/// should be used directly.
-static Value *findBaseOfVector(Value *I, Value *Index) {
+/// Return a base defining value for the 'Index' element of the given vector
+/// instruction 'I'. If Index is null, returns a BDV for the entire vector
+/// 'I'. As an optimization, this method will try to determine when the
+/// element is known to already be a base pointer. If this can be established,
+/// the second value in the returned pair will be true. Note that either a
+/// vector or a pointer typed value can be returned. For the former, the
+/// vector returned is a BDV (and possibly a base) of the entire vector 'I'.
+/// If the latter, the returned pointer is a BDV (or possibly a base) for the
+/// particular element in 'I'.
+static std::pair<Value *, bool>
+findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
   assert(I->getType()->isVectorTy() &&
          cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
@@ -309,7 +314,7 @@
 
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
-    return I;
+    return std::make_pair(I, true);
 
   // We shouldn't see the address of a global as a vector value?
   assert(!isa<GlobalVariable>(I) &&
@@ -320,7 +325,7 @@
   if (isa<UndefValue>(I))
     // utterly meaningless, but useful for dealing with partially optimized
     // code.
-    return I;
+    return std::make_pair(I, true);
 
   // Due to inheritance, this must be _after_ the global variable and undef
   // checks
@@ -328,38 +333,56 @@
     assert(!isa<GlobalVariable>(I) && !isa<UndefValue>(I) &&
            "order of checks wrong!");
     assert(Con->isNullValue() && "null is the only case which makes sense");
-    return Con;
+    return std::make_pair(Con, true);
   }
-  
+
   if (isa<LoadInst>(I))
-    return I;
-  
+    return std::make_pair(I, true);
+
   // For an insert element, we might be able to look through it if we know
-  // something about the indexes, but if the indices are arbitrary values, we
-  // can't without much more extensive scalarization.
+  // something about the indexes.
   if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
-    Value *InsertIndex = IEI->getOperand(2);
-    // This index is inserting the value, look for it's base
-    if (InsertIndex == Index)
-      return findBaseDefiningValue(IEI->getOperand(1));
-    // Both constant, and can't be equal per above. This insert is definitely
-    // not relevant, look back at the rest of the vector and keep trying.
-    if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
-      return findBaseOfVector(IEI->getOperand(0), Index);
-  }
-
-  // Note: This code is currently rather incomplete. We are essentially only
-  // handling cases where the vector element is trivially a base pointer. We
-  // need to update the entire base pointer construction algorithm to know how
-  // to track vector elements and potentially scalarize, but the case which
-  // would motivate the work hasn't shown up in real workloads yet.
-  llvm_unreachable("no base found for vector element");
+    if (Index) {
+      Value *InsertIndex = IEI->getOperand(2);
+      // This index is inserting the value, look for its BDV
+      if (InsertIndex == Index)
+        return std::make_pair(findBaseDefiningValue(IEI->getOperand(1)), false);
+      // Both constant, and can't be equal per above. This insert is definitely
+      // not relevant, look back at the rest of the vector and keep trying.
+      if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
+        return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
+    }
+
+    // We don't know whether this vector contains entirely base pointers or
+    // not. To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    return std::make_pair(IEI, false);
+  }
+
+  if (isa<ShuffleVectorInst>(I))
+    // We don't know whether this vector contains entirely base pointers or
+    // not. To be conservatively correct, we treat it as a BDV and will
+    // duplicate code as needed to construct a parallel vector of bases.
+    // TODO: There are a number of local optimizations which could be applied
+    // here for particular shufflevector patterns.
+    return std::make_pair(I, false);
+
+  // A PHI or Select is a base defining value. The outer findBasePointer
+  // algorithm is responsible for constructing a base value for this BDV.
+  assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
+         "unknown vector instruction - no base found for vector element");
+  return std::make_pair(I, false);
 }
 
+static bool isKnownBaseResult(Value *V);
+
 /// Helper function for findBasePointer - Will return a value which either a)
 /// defines the base pointer for the input or b) blocks the simple search
 /// (i.e. a PHI or Select of two derived pointers)
 static Value *findBaseDefiningValue(Value *I) {
+  if (I->getType()->isVectorTy())
+    return findBaseDefiningValueOfVector(I).first;
+
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
@@ -370,16 +393,39 @@
   if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
     Value *VectorOperand = EEI->getVectorOperand();
     Value *Index = EEI->getIndexOperand();
-    Value *VectorBase = findBaseOfVector(VectorOperand, Index);
-    // If the result returned is a vector, we know the entire vector must
-    // contain base pointers. In that case, the extractelement is a valid base
-    // for this value.
-    if (VectorBase->getType()->isVectorTy())
-      return EEI;
-    // Otherwise, we needed to look through the vector to find the base for
-    // this particular element.
-    assert(VectorBase->getType()->isPointerTy());
-    return VectorBase;
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element within the vector. This is
+      // an optimization, but in practice it covers most of the useful cases
+      // created via scalarization.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      if (pair.second)
+        // If the entire vector returned is known to be entirely base pointers,
+        // then the extractelement is a valid base for this value.
+        return EEI;
+      else {
+        // Otherwise, we have an instruction which potentially produces a
+        // derived pointer and we need findBasePointers to clone code for us
+        // such that we can create an instruction which produces the
+        // accompanying base pointer.
+        // Note: This code is currently rather incomplete. We don't currently
+        // support the general form of shufflevector or insertelement.
+        // Conceptually, these are just 'base defining values' of the same
+        // variety as phi or select instructions. We need to update the
+        // findBasePointers algorithm to insert new 'base-only' versions of the
+        // original instructions. This is relatively straightforward to do, but
+        // the case which would motivate the work hasn't shown up in real
+        // workloads yet.
+        assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) &&
+               "need to extend findBasePointers for generic vector "
+               "instruction cases");
+        return VectorBase;
+      }
+    }
   }
 
   if (isa<Argument>(I))
@@ -1712,7 +1758,11 @@
 /// slightly non-trivial since it requires a format change. Given how rare
 /// such cases are (for the moment?) scalarizing is an acceptable comprimise.
+/// PointerToBase is updated so that each resulting element maps to the
+/// corresponding element of its (also scalarized) base vector.
 static void splitVectorValues(Instruction *StatepointInst,
-                              StatepointLiveSetTy &LiveSet, DominatorTree &DT) {
+                              StatepointLiveSetTy &LiveSet,
+                              DenseMap<Value *, Value *> &PointerToBase,
+                              DominatorTree &DT) {
   SmallVector<Value *, 16> ToSplit;
   for (Value *V : LiveSet)
     if (isa<VectorType>(V->getType()))
@@ -1721,14 +1771,14 @@
   if (ToSplit.empty())
     return;
 
+  DenseMap<Value *, SmallVector<Value *, 16>> ElementMapping;
+
   Function &F = *(StatepointInst->getParent()->getParent());
 
   DenseMap<Value *, AllocaInst *> AllocaMap;
   // First is normal return, second is exceptional return (invoke only)
   DenseMap<Value *, std::pair<Value *, Value *>> Replacements;
   for (Value *V : ToSplit) {
-    LiveSet.erase(V);
-
     AllocaInst *Alloca =
         new AllocaInst(V->getType(), "", F.getEntryBlock().getFirstNonPHI());
     AllocaMap[V] = Alloca;
@@ -1738,7 +1788,7 @@
     SmallVector<Value *, 16> Elements;
     for (unsigned i = 0; i < VT->getNumElements(); i++)
       Elements.push_back(Builder.CreateExtractElement(V, Builder.getInt32(i)));
-    LiveSet.insert(Elements.begin(), Elements.end());
+    ElementMapping[V] = Elements;
 
     auto InsertVectorReform = [&](Instruction *IP) {
       Builder.SetInsertPoint(IP);
@@ -1771,6 +1821,7 @@
       Replacements[V].second = InsertVectorReform(IP);
     }
   }
+
   for (Value *V : ToSplit) {
     AllocaInst *Alloca = AllocaMap[V];
 
@@ -1814,6 +1865,31 @@
   for (Value *V : ToSplit)
     Allocas.push_back(AllocaMap[V]);
   PromoteMemToReg(Allocas, DT);
+
+  // Update our tracking of live pointers and base mappings to account for the
+  // changes we just made.
+  for (Value *V : ToSplit) {
+    auto &Elements = ElementMapping[V];
+
+    LiveSet.erase(V);
+    LiveSet.insert(Elements.begin(), Elements.end());
+    // We need to update the base mapping as well.
+    assert(PointerToBase.count(V));
+    Value *OldBase = PointerToBase[V];
+    // Use find() rather than operator[]: inserting a missing key could grow
+    // ElementMapping and invalidate the 'Elements' reference above. The base
+    // vector is expected to have been scalarized along with V.
+    auto BaseIt = ElementMapping.find(OldBase);
+    assert(BaseIt != ElementMapping.end() &&
+           "base of a split vector must also have been split");
+    auto &BaseElements = BaseIt->second;
+    PointerToBase.erase(V);
+    assert(Elements.size() == BaseElements.size());
+    for (unsigned i = 0; i < Elements.size(); i++) {
+      Value *Elem = Elements[i];
+      PointerToBase[Elem] = BaseElements[i];
+    }
+  }
 }
 
 // Helper function for the "rematerializeLiveValues". It walks use chain
@@ -2075,17 +2151,6 @@
   // site.
   findLiveReferences(F, DT, P, toUpdate, records);
 
-  // Do a limited scalarization of any live at safepoint vector values which
-  // contain pointers. This enables this pass to run after vectorization at
-  // the cost of some possible performance loss. TODO: it would be nice to
-  // natively support vectors all the way through the backend so we don't need
-  // to scalarize here.
-  for (size_t i = 0; i < records.size(); i++) {
-    struct PartiallyConstructedSafepointRecord &info = records[i];
-    Instruction *statepoint = toUpdate[i].getInstruction();
-    splitVectorValues(cast<Instruction>(statepoint), info.liveset, DT);
-  }
-
   // B) Find the base pointers for each live pointer
   /* scope for caching */ {
     // Cache the 'defining value' relation used in the computation and
@@ -2146,6 +2211,18 @@
   }
   holders.clear();
 
+  // Do a limited scalarization of any live at safepoint vector values which
+  // contain pointers. This enables this pass to run after vectorization at
+  // the cost of some possible performance loss. TODO: it would be nice to
+  // natively support vectors all the way through the backend so we don't need
+  // to scalarize here.
+  for (size_t i = 0; i < records.size(); i++) {
+    struct PartiallyConstructedSafepointRecord &info = records[i];
+    Instruction *statepoint = toUpdate[i].getInstruction();
+    splitVectorValues(cast<Instruction>(statepoint), info.liveset,
+                      info.PointerToBase, DT);
+  }
+
   // In order to reduce live set of statepoint we might choose to rematerialize
   // some values instead of relocating them. This is purelly an optimization and
   // does not influence correctness.
Index: llvm/trunk/test/Transforms/RewriteStatepointsForGC/live-vector.ll
===================================================================
--- llvm/trunk/test/Transforms/RewriteStatepointsForGC/live-vector.ll
+++ llvm/trunk/test/Transforms/RewriteStatepointsForGC/live-vector.ll
@@ -105,8 +105,6 @@
 ; CHECK-NEXT: bitcast
 ; CHECK-NEXT: gc.relocate
 ; CHECK-NEXT: bitcast
-; CHECK-NEXT: gc.relocate
-; CHECK-NEXT: bitcast
 ; CHECK-NEXT: insertelement
 ; CHECK-NEXT: insertelement
 ; CHECK-NEXT: ret <2 x i64 addrspace(1)*> %7
@@ -116,6 +114,48 @@
   ret <2 x i64 addrspace(1)*> %vec
 }
 
+
+; Phi of two loaded base vectors: expect a base phi, then both vectors split
+define <2 x i64 addrspace(1)*> @test6(i1 %cnd, <2 x i64 addrspace(1)*>* %ptr)
+  gc "statepoint-example" {
+; CHECK-LABEL: test6
+; CHECK-LABEL: merge:
+; CHECK-NEXT: = phi
+; CHECK-NEXT: = phi
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: extractelement
+; CHECK-NEXT: gc.statepoint
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: gc.relocate
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: insertelement
+; CHECK-NEXT: ret <2 x i64 addrspace(1)*>
+entry:
+  br i1 %cnd, label %taken, label %untaken
+taken:
+  %obja = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+untaken:
+  %objb = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+
+merge:
+  %obj = phi <2 x i64 addrspace(1)*> [%obja, %taken], [%objb, %untaken]
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret <2 x i64 addrspace(1)*> %obj
+}
+
+
 declare void @do_safepoint()
 
 declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)