diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -387,8 +387,13 @@ Result.LiveSet = LiveSet; } +// Returns true if V is a knownBaseResult. static bool isKnownBaseResult(Value *V); +// Returns true if V is a BaseResult that already exists in the IR, i.e. it is +// not created by the findBasePointers algorithm. +static bool isOriginalBaseResult(Value *V); + namespace { /// A single base defining value - An immediate base defining value for an @@ -633,15 +638,20 @@ return Def; } +/// This value is a base pointer that is not generated by RS4GC, i.e. it already +/// exists in the code. +static bool isOriginalBaseResult(Value *V) { + // no recursion possible + return !isa(V) && !isa(V) && + !isa(V) && !isa(V) && + !isa(V); +} + /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV, /// is it known to be a base pointer? Or do we need to continue searching. static bool isKnownBaseResult(Value *V) { - if (!isa(V) && !isa(V) && - !isa(V) && !isa(V) && - !isa(V)) { - // no recursion possible + if (isOriginalBaseResult(V)) return true; - } if (isa(V) && cast(V)->getMetadata("is_base_value")) { // This is a previously inserted base phi or select. We know @@ -653,6 +663,12 @@ return false; } +// Returns true if First and Second values are both scalar or both vector. 
+static bool isCorrectType(Value *First, Value *Second) { + return isa(First->getType()) == + isa(Second->getType()); +} + namespace { /// Models the state of a single base defining value in the findBasePointer @@ -762,7 +778,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Value *Def = findBaseOrBDV(I, Cache); - if (isKnownBaseResult(Def)) + if (isKnownBaseResult(Def) && isCorrectType(Def, I)) return Def; // Here's the rough algorithm: @@ -810,13 +826,16 @@ States.insert({Def, BDVState()}); while (!Worklist.empty()) { Value *Current = Worklist.pop_back_val(); - assert(!isKnownBaseResult(Current) && "why did it get added?"); + assert(!isOriginalBaseResult(Current) && "why did it get added?"); auto visitIncomingValue = [&](Value *InVal) { Value *Base = findBaseOrBDV(InVal, Cache); - if (isKnownBaseResult(Base)) + if (isKnownBaseResult(Base) && isCorrectType(Base, InVal)) // Known bases won't need new instructions introduced and can be - // ignored safely + // ignored safely. However, this can only be done when InVal and Base + // are both scalar or both vector. Otherwise, we need to find a + // correct BDV for InVal, by creating an entry in the lattice + // (States). return; assert(isExpectedBDVType(Base) && "the only non-base values " "we see should be base defining values"); @@ -853,10 +872,10 @@ // Return a phi state for a base defining value. We'll generate a new // base state for known bases and expect to find a cached state otherwise. - auto getStateForBDV = [&](Value *baseValue) { - if (isKnownBaseResult(baseValue)) - return BDVState(baseValue); - auto I = States.find(baseValue); + auto GetStateForBDV = [&](Value *BaseValue, Value *Input) { + if (isKnownBaseResult(BaseValue) && isCorrectType(BaseValue, Input)) + return BDVState(BaseValue); + auto I = States.find(BaseValue); assert(I != States.end() && "lookup failed!"); return I->second; }; @@ -873,13 +892,18 @@ // much faster. 
for (auto Pair : States) { Value *BDV = Pair.first; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(BDV) || + !isCorrectType(BDV, Pair.second.getBaseValue())) && + "why did it get added?"); // Given an input value for the current instruction, return a BDVState // instance which represents the BDV of that value. auto getStateForInput = [&](Value *V) mutable { Value *BDV = findBaseOrBDV(V, Cache); - return getStateForBDV(BDV); + return GetStateForBDV(BDV, V); }; BDVState NewState; @@ -926,41 +950,41 @@ } #endif - // Handle extractelement instructions and their uses. + // Handle all instructions that have a vector BDV, but the instruction itself + // is of scalar type. for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(I) && "why did it get added?"); + auto *BaseValue = State.getBaseValue(); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(I) || !isCorrectType(I, BaseValue)) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); + if (!State.isBase() || !isa(BaseValue->getType())) + continue; // extractelement instructions are a bit special in that we may need to // insert an extract even when we know an exact base for the instruction. // The problem is that we need to convert from a vector base to a scalar // base for the particular indice we're interested in. - if (!State.isBase() || !isa(I) || - !isa(State.getBaseValue()->getType())) - continue; - auto *EE = cast(I); - // TODO: In many cases, the new instruction is just EE itself. 
We should - // exploit this, but can't do it here since it would break the invariant - // about the BDV not being known to be a base. - auto *BaseInst = ExtractElementInst::Create( - State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); - BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); - States[I] = BDVState(BDVState::Base, BaseInst); - - // We need to handle uses of the extractelement that have the same vector - // base as well but the use is a scalar type. Since we cannot reuse the - // same BaseInst above (may not satisfy property that base pointer should - // always dominate derived pointer), we conservatively set this as conflict. - // Setting the base value for these conflicts is handled in the next loop - // which traverses States. - for (User *U : I->users()) { - auto *UseI = dyn_cast(U); - if (!UseI || !States.count(UseI)) - continue; - if (!isa(UseI->getType()) && States[UseI] == State) - States[UseI] = BDVState(BDVState::Conflict); + if (isa(I)) { + auto *EE = cast(I); + // TODO: In many cases, the new instruction is just EE itself. We should + // exploit this, but can't do it here since it would break the invariant + // about the BDV not being known to be a base. + auto *BaseInst = ExtractElementInst::Create( + State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); + BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); + States[I] = BDVState(BDVState::Base, BaseInst); + } else if (!isa(I->getType())) { + // We need to handle cases that have a vector base but the instruction is + // a scalar type (these could be phis or selects or any instruction that + // is of scalar type, but the base can be a vector type). We + // conservatively set this as conflict. Setting the base value for these + // conflicts is handled in the next loop which traverses States. 
+ States[I] = BDVState(BDVState::Conflict); } } @@ -969,7 +993,11 @@ for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(I) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(I) || !isCorrectType(I, State.getBaseValue())) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); // Since we're joining a vector and scalar base, they can never be the @@ -1030,7 +1058,7 @@ auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { Value *BDV = findBaseOrBDV(Input, Cache); Value *Base = nullptr; - if (isKnownBaseResult(BDV)) { + if (isKnownBaseResult(BDV) && isCorrectType(BDV, Input)) { Base = BDV; } else { // Either conflict or base. @@ -1051,7 +1079,12 @@ Instruction *BDV = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. + assert((!isKnownBaseResult(BDV) || + !isCorrectType(BDV, State.getBaseValue())) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); if (!State.isConflict()) continue; @@ -1141,7 +1174,11 @@ auto *BDV = Pair.first; Value *Base = Pair.second.getBaseValue(); assert(BDV && Base); - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + // Only values that do not have known bases or those that have differing + // type (scalar versus vector) from a possible known base should be in the + // lattice. 
+ assert((!isKnownBaseResult(BDV) || !isCorrectType(BDV, Base)) && + "why did it get added?"); LLVM_DEBUG( dbgs() << "Updating base value cache" diff --git a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll --- a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll +++ b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll @@ -192,5 +192,75 @@ br label %header } +; Uses of extractelement that are of scalar type should not have the BDV +; incorrectly identified as a vector type. +define void @widget() gc "statepoint-example" { +; CHECK-LABEL: @widget( +; CHECK-NEXT: bb6: +; CHECK-NEXT: [[BASE_EE:%.*]] = extractelement <2 x i8 addrspace(1)*> zeroinitializer, i32 1, !is_base_value !0 +; CHECK-NEXT: [[TMP:%.*]] = extractelement <2 x i8 addrspace(1)*> undef, i32 1 +; CHECK-NEXT: br i1 undef, label [[BB7:%.*]], label [[BB9:%.*]] +; CHECK: bb7: +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 +; CHECK-NEXT: br label [[BB11:%.*]] +; CHECK: bb9: +; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 +; CHECK-NEXT: br i1 undef, label [[BB11]], label [[BB15:%.*]] +; CHECK: bb11: +; CHECK-NEXT: [[TMP12_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB7]] ], [ [[BASE_EE]], [[BB9]] ], !is_base_value !0 +; CHECK-NEXT: [[TMP12:%.*]] = phi i8 addrspace(1)* [ [[TMP8]], [[BB7]] ], [ [[TMP10]], [[BB9]] ] +; CHECK-NEXT: [[STATEPOINT_TOKEN:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) 
@llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP12_BASE]], i8 addrspace(1)* [[TMP12]]) +; CHECK-NEXT: [[TMP12_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 8) +; CHECK-NEXT: [[TMP12_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 9) +; CHECK-NEXT: br label [[BB15]] +; CHECK: bb15: +; CHECK-NEXT: [[TMP16_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB9]] ], [ [[TMP12_BASE_RELOCATED]], [[BB11]] ], !is_base_value !0 +; CHECK-NEXT: [[TMP16:%.*]] = phi i8 addrspace(1)* [ [[TMP10]], [[BB9]] ], [ [[TMP12_RELOCATED]], [[BB11]] ] +; CHECK-NEXT: br i1 undef, label [[BB17:%.*]], label [[BB20:%.*]] +; CHECK: bb17: +; CHECK-NEXT: [[STATEPOINT_TOKEN1:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP16_BASE]], i8 addrspace(1)* [[TMP16]]) +; CHECK-NEXT: [[TMP16_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 8) +; CHECK-NEXT: [[TMP16_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 9) +; CHECK-NEXT: br label [[BB20]] +; CHECK: bb20: +; CHECK-NEXT: [[DOT05:%.*]] = phi i8 addrspace(1)* [ [[TMP16_BASE_RELOCATED]], [[BB17]] ], [ [[TMP16_BASE]], [[BB15]] ] +; CHECK-NEXT: [[DOT0:%.*]] = phi i8 addrspace(1)* [ [[TMP16_RELOCATED]], [[BB17]] ], [ [[TMP16]], [[BB15]] ] +; CHECK-NEXT: [[STATEPOINT_TOKEN2:%.*]] = call token (i64, i32, void (i8 addrspace(1)*)*, i32, i32, ...) 
@llvm.experimental.gc.statepoint.p0f_isVoidp1i8f(i64 2882400000, i32 0, void (i8 addrspace(1)*)* @foo, i32 1, i32 0, i8 addrspace(1)* [[DOT0]], i32 0, i32 0, i8 addrspace(1)* [[DOT05]], i8 addrspace(1)* [[DOT0]]) +; CHECK-NEXT: [[TMP16_BASE_RELOCATED3:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 8) +; CHECK-NEXT: [[TMP16_RELOCATED4:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 9) +; CHECK-NEXT: ret void +; +bb6: ; preds = %bb3 + %tmp = extractelement <2 x i8 addrspace(1)*> undef, i32 1 + br i1 undef, label %bb7, label %bb9 + +bb7: ; preds = %bb6 + %tmp8 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 + br label %bb11 + +bb9: ; preds = %bb6, %bb6 + %tmp10 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 + br i1 undef, label %bb11, label %bb15 + +bb11: ; preds = %bb9, %bb7 + %tmp12 = phi i8 addrspace(1)* [ %tmp8, %bb7 ], [ %tmp10, %bb9 ] + call void @snork() [ "deopt"(i32 undef) ] + br label %bb15 + +bb15: ; preds = %bb11, %bb9, %bb9 + %tmp16 = phi i8 addrspace(1)* [ %tmp10, %bb9 ], [ %tmp12, %bb11 ] + br i1 undef, label %bb17, label %bb20 + +bb17: ; preds = %bb15 + call void @snork() [ "deopt"(i32 undef) ] + br label %bb20 + +bb20: ; preds = %bb17, %bb15, %bb15 + call void @foo(i8 addrspace(1)* %tmp16) + ret void +} + +declare void @snork() +declare void @foo(i8 addrspace(1)*) declare void @spam() declare <2 x i8 addrspace(1)*> @baz()