[RewriteStatepointsForGC] Handle extractelement as a base defining value

extractelement is now treated like phi and select by the base pointer
search. isKnownBaseResult no longer treats it as a known base. The
worklist visits its vector operand, and the meet step takes the state of
that operand. A parallel extractelement ("<name>.base" or "base_ee") is
inserted to pull the matching lane out of the base vector.

When the scalar base of the extracted element can be found directly by
looking through the vector, that base is still returned immediately. The
comment in the patch notes this peephole is currently required for
correctness, because the general algorithm does not yet handle
insertelement. insertelement and shufflevector BDVs still reach
llvm_unreachable.

The TraceLSP dump of BDV states now prints each value rather than its
pointer. Adds test/Transforms/RewriteStatepointsForGC/base-vector.ll.

Index: lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
===================================================================
--- lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -377,8 +377,9 @@
 static bool isKnownBaseResult(Value *V);
 
 /// Helper function for findBasePointer - Will return a value which either a)
-/// defines the base pointer for the input or b) blocks the simple search
-/// (i.e. a PHI or Select of two derived pointers)
+/// defines the base pointer for the input, b) blocks the simple search
+/// (i.e. a PHI or Select of two derived pointers), or c) involves a change
+/// from pointer to vector type or back.
 static Value *findBaseDefiningValue(Value *I) {
   if (I->getType()->isVectorTy())
     return findBaseDefiningValueOfVector(I).first;
@@ -386,48 +387,6 @@
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
-  // This case is a bit of a hack - it only handles extracts from vectors which
-  // trivially contain only base pointers or cases where we can directly match
-  // the index of the original extract element to an insertion into the vector.
-  // See note inside the function for how to improve this.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    std::pair<Value *, bool> pair =
-      findBaseDefiningValueOfVector(VectorOperand, Index);
-    Value *VectorBase = pair.first;
-    if (VectorBase->getType()->isPointerTy())
-      // We found a BDV for this specific element with the vector. This is an
-      // optimization, but in practice it covers most of the useful cases
-      // created via scalarization.
-      return VectorBase;
-    else {
-      assert(VectorBase->getType()->isVectorTy());
-      if (pair.second)
-        // If the entire vector returned is known to be entirely base pointers,
-        // then the extractelement is valid base for this value.
-        return EEI;
-      else {
-        // Otherwise, we have an instruction which potentially produces a
-        // derived pointer and we need findBasePointers to clone code for us
-        // such that we can create an instruction which produces the
-        // accompanying base pointer.
-        // Note: This code is currently rather incomplete. We don't currently
-        // support the general form of shufflevector of insertelement.
-        // Conceptually, these are just 'base defining values' of the same
-        // variety as phi or select instructions. We need to update the
-        // findBasePointers algorithm to insert new 'base-only' versions of the
-        // original instructions. This is relative straight forward to do, but
-        // the case which would motivate the work hasn't shown up in real
-        // workloads yet.
-        assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) &&
-               "need to extend findBasePointers for generic vector"
-               "instruction cases");
-        return VectorBase;
-      }
-    }
-  }
-
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
     // We should have never reached here if this argument isn't an gc value
@@ -532,6 +491,33 @@
   assert(!isa<InsertValueInst>(I) &&
          "Base pointer for a struct is meaningless");
 
+  // An extractelement produces a base result exactly when its input does.
+  // We may need to insert a parallel instruction to extract the appropriate
+  // element out of the base vector corresponding to the input. Given this,
+  // it's analogous to the phi and select case even though it's not a merge.
+  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
+    Value *VectorOperand = EEI->getVectorOperand();
+    Value *Index = EEI->getIndexOperand();
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element within the vector. This is an
+      // optimization, but in practice it covers most of the useful cases
+      // created via scalarization. Note: The peephole optimization here is
+      // currently needed for correctness since the general algorithm doesn't
+      // yet handle insertelements. That will change shortly.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      // Otherwise, we have an instruction which potentially produces a
+      // derived pointer and we need findBasePointers to clone code for us
+      // such that we can create an instruction which produces the
+      // accompanying base pointer.
+      return EEI;
+    }
+  }
+
   // The last two cases here don't return a base pointer. Instead, they
   // return a value which dynamically selects from amoung several base
   // derived pointers (each with it's own base potentially). It's the job of
@@ -569,7 +555,7 @@
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer? Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -722,7 +708,7 @@
 
 #ifndef NDEBUG
   auto isExpectedBDVType = [](Value *BDV) {
-    return isa<PHINode>(BDV) || isa<SelectInst>(BDV);
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
   };
 #endif
 
@@ -754,10 +740,16 @@
       if (PHINode *Phi = dyn_cast<PHINode>(Current)) {
         for (Value *InVal : Phi->incoming_values())
           visitIncomingValue(InVal);
-      } else {
-        SelectInst *Sel = cast<SelectInst>(Current);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(Current)) {
         visitIncomingValue(Sel->getTrueValue());
         visitIncomingValue(Sel->getFalseValue());
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
+        visitIncomingValue(EE->getVectorOperand());
+      } else {
+        // There are two classes of instructions we know we don't handle.
+        assert(isa<InsertElementInst>(Current) ||
+               isa<ShuffleVectorInst>(Current));
+        llvm_unreachable("unimplemented instruction case");
       }
     }
     // The frontier of visited instructions are the ones we might need to
@@ -771,7 +763,7 @@
   if (TraceLSP) {
     errs() << "States after initialization:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // TODO: come back and revisit the state transitions around inputs which
@@ -809,9 +801,16 @@
       if (SelectInst *select = dyn_cast<SelectInst>(v)) {
         calculateMeet.meetWith(getStateForInput(select->getTrueValue()));
         calculateMeet.meetWith(getStateForInput(select->getFalseValue()));
-      } else
-        for (Value *Val : cast<PHINode>(v)->incoming_values())
+      } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
+        for (Value *Val : Phi->incoming_values())
           calculateMeet.meetWith(getStateForInput(Val));
+      } else {
+        // The 'meet' for an extractelement is slightly trivial, but it's still
+        // useful in that it drives us to conflict if our input is.
+        auto *EE = cast<ExtractElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      }
+
 
       BDVState oldState = states[v];
       BDVState newState = calculateMeet.getResult();
@@ -828,7 +827,7 @@
   if (TraceLSP) {
     errs() << "States after meet iteration:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // Insert Phis for all conflicts
@@ -848,6 +847,24 @@
     BDVState State = states[I];
     assert(!isKnownBaseResult(I) && "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+
+    // extractelement instructions are a bit special in that we may need to
+    // insert an extract even when we know an exact base for the instruction.
+    // The problem is that we need to convert from a vector base to a scalar
+    // base for the particular index we're interested in.
+    if (State.isBase() && isa<ExtractElementInst>(I) &&
+        isa<VectorType>(State.getBase()->getType())) {
+      auto *EE = cast<ExtractElementInst>(I);
+      // TODO: In many cases, the new instruction is just EE itself. We should
+      // exploit this, but can't do it here since it would break the invariant
+      // about the BDV not being known to be a base.
+      auto *BaseInst = ExtractElementInst::Create(State.getBase(),
+                                                  EE->getIndexOperand(),
+                                                  "base_ee", EE);
+      BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+      states[I] = BDVState(BDVState::Base, BaseInst);
+    }
+
     if (!State.isConflict())
       continue;
 
@@ -861,14 +878,21 @@
         std::string Name = I->hasName() ?
            (I->getName() + ".base").str() : "base_phi";
         return PHINode::Create(I->getType(), NumPreds, Name, I);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(I)) {
+        // The undef will be replaced later
+        UndefValue *Undef = UndefValue::get(Sel->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_select";
+        return SelectInst::Create(Sel->getCondition(), Undef,
+                                  Undef, Name, Sel);
+      } else {
+        auto *EE = cast<ExtractElementInst>(I);
+        UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ee";
+        return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
+                                          EE);
       }
-      SelectInst *Sel = cast<SelectInst>(I);
-      // The undef will be replaced later
-      UndefValue *Undef = UndefValue::get(Sel->getType());
-      std::string Name = I->hasName() ?
-        (I->getName() + ".base").str() : "base_select";
-      return SelectInst::Create(Sel->getCondition(), Undef,
-                                Undef, Name, Sel);
     };
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
@@ -947,8 +971,7 @@
         basephi->addIncoming(base, InBB);
       }
       assert(basephi->getNumIncomingValues() == NumPHIValues);
-    } else {
-      SelectInst *basesel = cast<SelectInst>(state.getBase());
+    } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) {
       SelectInst *sel = cast<SelectInst>(v);
       // Operand 1 & 2 are true, false path respectively. TODO: refactor to
       // something more safe and less hacky.
@@ -971,6 +994,18 @@
         }
         basesel->setOperand(i, base);
       }
+    } else {
+      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+      Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
+      Value *Base = findBaseOrBDV(InVal, cache);
+      if (!isKnownBaseResult(Base)) {
+        // Either conflict or base.
+        assert(states.count(Base));
+        Base = states[Base].getBase();
+        assert(Base != nullptr && "unknown BDVState!");
+      }
+      assert(Base && "can't be null");
+      BaseEE->setOperand(0, Base);
     }
   }
 
Index: test/Transforms/RewriteStatepointsForGC/base-vector.ll
===================================================================
--- test/Transforms/RewriteStatepointsForGC/base-vector.ll
+++ test/Transforms/RewriteStatepointsForGC/base-vector.ll
@@ -0,0 +1,88 @@
+; RUN: opt %s -rewrite-statepoints-for-gc -S | FileCheck %s
+
+define i64 addrspace(1)* @test(<2 x i64 addrspace(1)*> %vec, i32 %idx) gc "statepoint-example" {
+; CHECK-LABEL: @test
+; CHECK: extractelement
+; CHECK: extractelement
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%base_ee, %base_ee)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%base_ee, %obj)
+; Note that the second extractelement is actually redundant here. A correct output would
+; be to reuse the existing obj as a base since it is actually a base pointer.
+entry:
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+
+  ret i64 addrspace(1)* %obj
+}
+
+define i64 addrspace(1)* @test2(<2 x i64 addrspace(1)*>* %ptr, i1 %cnd, i32 %idx1, i32 %idx2)
+  gc "statepoint-example" {
+; CHECK-LABEL: test2
+entry:
+  br i1 %cnd, label %taken, label %untaken
+taken:
+  %obja = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+untaken:
+  %objb = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+merge:
+  %vec = phi <2 x i64 addrspace(1)*> [%obja, %taken], [%objb, %untaken]
+  br i1 %cnd, label %taken2, label %untaken2
+taken2:
+  %obj0 = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx1
+  br label %merge2
+untaken2:
+  %obj1 = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx2
+  br label %merge2
+merge2:
+; CHECK-LABEL: merge2:
+; CHECK: %obj.base = phi i64 addrspace(1)*
+; CHECK: %obj = phi i64 addrspace(1)*
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%obj.base, %obj)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%obj.base, %obj.base)
+  %obj = phi i64 addrspace(1)* [%obj0, %taken2], [%obj1, %untaken2]
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+
+define i64 addrspace(1)* @test3(i64 addrspace(1)* %ptr)
+  gc "statepoint-example" {
+; CHECK-LABEL: test3
+entry:
+  %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %ptr, i32 0
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
+; CHECK: insertelement
+; CHECK: extractelement
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %obj)
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+define i64 addrspace(1)* @test4(i64 addrspace(1)* %ptr)
+  gc "statepoint-example" {
+; CHECK-LABEL: test4
+entry:
+  %derived = getelementptr i64, i64 addrspace(1)* %ptr, i64 16
+  %veca = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %derived, i32 0
+  %vec = insertelement <2 x i64 addrspace(1)*> %veca, i64 addrspace(1)* %ptr, i32 1
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %obj)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %ptr)
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+
+declare void @do_safepoint()
+
+declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)