RewriteStatepointsForGC: check scalar/vector agreement in isKnownBaseResult

isKnownBaseResult() gains an optional Input argument. When Input is supplied
and V is not one of the values the function returns true for up front ("no
recursion possible"), the function returns false if exactly one of V and
Input has a vector type, before the "is_base_value" metadata check is
reached. The call sites touched below pass the value being looked up, or the
base value recorded in its BDVState, as Input.

The post-processing loop over States that used to handle only extractelement
is generalized. For each Base state whose base value has a vector type, an
extractelement still gets a new "base_ee" extract as its base, and any other
instruction of scalar type is marked as a Conflict; a later loop over States
assigns bases to conflicts. This replaces the old special case that only
marked scalar users of an extractelement (sharing its state) as conflicts.

The new test @widget in scalar-base-vector.ll covers scalar uses (GEPs and
phis) of an extractelement of a vector.

NOTE(review): this copy of the patch has lost its line breaks and the template
arguments of isa/cast/dyn_cast (e.g. `isa(V)`, `cast(Pair.first)`), so it will
neither apply nor compile as-is; regenerate it from the original commit.

diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp --- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -387,7 +387,10 @@ Result.LiveSet = LiveSet; } -static bool isKnownBaseResult(Value *V); +// Returns true if V is a knownBaseResult. If Input is passed in, we also check +// that Input and V do not have conflicting types, i.e. one being a scalar type +// and the other a vector type. +static bool isKnownBaseResult(Value *V, Value *Input = nullptr); namespace { @@ -635,13 +638,18 @@ /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV, /// is it known to be a base pointer? Or do we need to continue searching. -static bool isKnownBaseResult(Value *V) { +static bool isKnownBaseResult(Value *V, Value *Input) { if (!isa(V) && !isa(V) && !isa(V) && !isa(V) && !isa(V)) { // no recursion possible return true; } + // If we have the Input available, check that V and Input are both scalar or + // both vectors. + if (Input && + isa(V->getType()) != isa(Input->getType())) + return false; if (isa(V) && cast(V)->getMetadata("is_base_value")) { // This is a previously inserted base phi or select. We know @@ -762,7 +770,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) { Value *Def = findBaseOrBDV(I, Cache); - if (isKnownBaseResult(Def)) + if (isKnownBaseResult(Def, I)) return Def; // Here's the rough algorithm: @@ -814,7 +822,7 @@ auto visitIncomingValue = [&](Value *InVal) { Value *Base = findBaseOrBDV(InVal, Cache); - if (isKnownBaseResult(Base)) + if (isKnownBaseResult(Base, InVal)) // Known bases won't need new instructions introduced and can be // ignored safely return; @@ -853,10 +861,10 @@ // Return a phi state for a base defining value. We'll generate a new // base state for known bases and expect to find a cached state otherwise.
- auto getStateForBDV = [&](Value *baseValue) { - if (isKnownBaseResult(baseValue)) - return BDVState(baseValue); - auto I = States.find(baseValue); + auto GetStateForBDV = [&](Value *BaseValue, Value *Input) { + if (isKnownBaseResult(BaseValue, Input)) + return BDVState(BaseValue); + auto I = States.find(BaseValue); assert(I != States.end() && "lookup failed!"); return I->second; }; @@ -873,13 +881,14 @@ // much faster. for (auto Pair : States) { Value *BDV = Pair.first; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + assert(!isKnownBaseResult(BDV, Pair.second.getBaseValue()) && + "why did it get added?"); // Given an input value for the current instruction, return a BDVState // instance which represents the BDV of that value. auto getStateForInput = [&](Value *V) mutable { Value *BDV = findBaseOrBDV(V, Cache); - return getStateForBDV(BDV); + return GetStateForBDV(BDV, V); }; BDVState NewState; @@ -926,41 +935,37 @@ } #endif - // Handle extractelement instructions and their uses. + // Handle all instructions that have a vector BDV, but the instruction itself + // is of scalar type. for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(I) && "why did it get added?"); + auto *BaseValue = State.getBaseValue(); + assert(!isKnownBaseResult(I, BaseValue) && "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); + if (!State.isBase() || !isa(BaseValue->getType())) + continue; // extractelement instructions are a bit special in that we may need to // insert an extract even when we know an exact base for the instruction. // The problem is that we need to convert from a vector base to a scalar // base for the particular indice we're interested in. - if (!State.isBase() || !isa(I) || - !isa(State.getBaseValue()->getType())) - continue; - auto *EE = cast(I); - // TODO: In many cases, the new instruction is just EE itself.
We should - // exploit this, but can't do it here since it would break the invariant - // about the BDV not being known to be a base. - auto *BaseInst = ExtractElementInst::Create( - State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); - BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); - States[I] = BDVState(BDVState::Base, BaseInst); - - // We need to handle uses of the extractelement that have the same vector - // base as well but the use is a scalar type. Since we cannot reuse the - // same BaseInst above (may not satisfy property that base pointer should - // always dominate derived pointer), we conservatively set this as conflict. - // Setting the base value for these conflicts is handled in the next loop - // which traverses States. - for (User *U : I->users()) { - auto *UseI = dyn_cast(U); - if (!UseI || !States.count(UseI)) - continue; - if (!isa(UseI->getType()) && States[UseI] == State) - States[UseI] = BDVState(BDVState::Conflict); + if (isa(I)) { + auto *EE = cast(I); + // TODO: In many cases, the new instruction is just EE itself. We should
+ // exploit this, but can't do it here since it would break the invariant + // about the BDV not being known to be a base. + auto *BaseInst = ExtractElementInst::Create( + State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE); + BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {})); + States[I] = BDVState(BDVState::Base, BaseInst); + } else if (!isa(I->getType())) { + // We need to handle cases that have a vector base but the instruction is + // a scalar type (these could be phis or selects or any instructions that + // are of scalar type, but the base can be a vector type). We + // conservatively set this as conflict. Setting the base value for these + // conflicts is handled in the next loop which traverses States. + States[I] = BDVState(BDVState::Conflict); } } @@ -969,7 +974,8 @@ for (auto Pair : States) { Instruction *I = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(I) && "why did it get added?"); + assert(!isKnownBaseResult(I, State.getBaseValue()) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); // Since we're joining a vector and scalar base, they can never be the @@ -1030,7 +1036,7 @@ auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) { Value *BDV = findBaseOrBDV(Input, Cache); Value *Base = nullptr; - if (isKnownBaseResult(BDV)) { + if (isKnownBaseResult(BDV, Input)) { Base = BDV; } else { // Either conflict or base. @@ -1051,7 +1057,8 @@ Instruction *BDV = cast(Pair.first); BDVState State = Pair.second; - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + assert(!isKnownBaseResult(BDV, State.getBaseValue()) && + "why did it get added?"); assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); if (!State.isConflict()) continue; @@ -1141,7 +1148,7 @@ auto *BDV = Pair.first; Value *Base = Pair.second.getBaseValue(); assert(BDV && Base); - assert(!isKnownBaseResult(BDV) && "why did it get added?"); + assert(!isKnownBaseResult(BDV, Base) && "why did it get added?"); LLVM_DEBUG( dbgs() << "Updating base value cache" diff --git a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll --- a/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll +++ b/llvm/test/Transforms/RewriteStatepointsForGC/scalar-base-vector.ll @@ -192,5 +192,75 @@ br label %header } +; Uses of extractelement that are of scalar type should not have the BDV +; incorrectly identified as a vector type.
+define void @widget() gc "statepoint-example" { +; CHECK-LABEL: @widget( +; CHECK-NEXT: bb6: +; CHECK-NEXT: [[BASE_EE:%.*]] = extractelement <2 x i8 addrspace(1)*> zeroinitializer, i32 1, !is_base_value !0 +; CHECK-NEXT: [[TMP:%.*]] = extractelement <2 x i8 addrspace(1)*> undef, i32 1 +; CHECK-NEXT: br i1 undef, label [[BB7:%.*]], label [[BB9:%.*]] +; CHECK: bb7: +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 +; CHECK-NEXT: br label [[BB11:%.*]] +; CHECK: bb9: +; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds i8, i8 addrspace(1)* [[TMP]], i64 12 +; CHECK-NEXT: br i1 undef, label [[BB11]], label [[BB15:%.*]] +; CHECK: bb11: +; CHECK-NEXT: [[TMP12_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB7]] ], [ [[BASE_EE]], [[BB9]] ], !is_base_value !0 +; CHECK-NEXT: [[TMP12:%.*]] = phi i8 addrspace(1)* [ [[TMP8]], [[BB7]] ], [ [[TMP10]], [[BB9]] ] +; CHECK-NEXT: [[STATEPOINT_TOKEN:%.*]] = call token (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP12_BASE]], i8 addrspace(1)* [[TMP12]]) +; CHECK-NEXT: [[TMP12_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 8) +; CHECK-NEXT: [[TMP12_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN]], i32 8, i32 9) +; CHECK-NEXT: br label [[BB15]] +; CHECK: bb15: +; CHECK-NEXT: [[TMP16_BASE:%.*]] = phi i8 addrspace(1)* [ [[BASE_EE]], [[BB9]] ], [ [[TMP12_BASE_RELOCATED]], [[BB11]] ], !is_base_value !0 +; CHECK-NEXT: [[TMP16:%.*]] = phi i8 addrspace(1)* [ [[TMP10]], [[BB9]] ], [ [[TMP12_RELOCATED]], [[BB11]] ] +; CHECK-NEXT: br i1 undef, label [[BB17:%.*]], label [[BB20:%.*]] +; CHECK: bb17: +; CHECK-NEXT: [[STATEPOINT_TOKEN1:%.*]] = call token (i64, i32, void ()*, i32, i32, ...)
@llvm.experimental.gc.statepoint.p0f_isVoidf(i64 2882400000, i32 0, void ()* @snork, i32 0, i32 0, i32 0, i32 1, i32 undef, i8 addrspace(1)* [[TMP16_BASE]], i8 addrspace(1)* [[TMP16]]) +; CHECK-NEXT: [[TMP16_BASE_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 8) +; CHECK-NEXT: [[TMP16_RELOCATED:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN1]], i32 8, i32 9) +; CHECK-NEXT: br label [[BB20]] +; CHECK: bb20: +; CHECK-NEXT: [[DOT05:%.*]] = phi i8 addrspace(1)* [ [[TMP16_BASE_RELOCATED]], [[BB17]] ], [ [[TMP16_BASE]], [[BB15]] ] +; CHECK-NEXT: [[DOT0:%.*]] = phi i8 addrspace(1)* [ [[TMP16_RELOCATED]], [[BB17]] ], [ [[TMP16]], [[BB15]] ] +; CHECK-NEXT: [[STATEPOINT_TOKEN2:%.*]] = call token (i64, i32, void (i8 addrspace(1)*)*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidp1i8f(i64 2882400000, i32 0, void (i8 addrspace(1)*)* @foo, i32 1, i32 0, i8 addrspace(1)* [[DOT0]], i32 0, i32 0, i8 addrspace(1)* [[DOT05]], i8 addrspace(1)* [[DOT0]]) +; CHECK-NEXT: [[TMP16_BASE_RELOCATED3:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 8) +; CHECK-NEXT: [[TMP16_RELOCATED4:%.*]] = call coldcc i8 addrspace(1)* @llvm.experimental.gc.relocate.p1i8(token [[STATEPOINT_TOKEN2]], i32 8, i32 9) +; CHECK-NEXT: ret void +; +bb6: + %tmp = extractelement <2 x i8 addrspace(1)*> undef, i32 1 + br i1 undef, label %bb7, label %bb9 + +bb7: ; preds = %bb6 + %tmp8 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 + br label %bb11 + +bb9: ; preds = %bb6 + %tmp10 = getelementptr inbounds i8, i8 addrspace(1)* %tmp, i64 12 + br i1 undef, label %bb11, label %bb15 + +bb11: ; preds = %bb9, %bb7 + %tmp12 = phi i8 addrspace(1)* [ %tmp8, %bb7 ], [ %tmp10, %bb9 ] + call void @snork() [ "deopt"(i32 undef) ] + br label %bb15 + +bb15: ; preds = %bb11, %bb9 + %tmp16 = phi i8
addrspace(1)* [ %tmp10, %bb9 ], [ %tmp12, %bb11 ] + br i1 undef, label %bb17, label %bb20 + +bb17: ; preds = %bb15 + call void @snork() [ "deopt"(i32 undef) ] + br label %bb20 + +bb20: ; preds = %bb17, %bb15 + call void @foo(i8 addrspace(1)* %tmp16) + ret void +} + +declare void @snork() +declare void @foo(i8 addrspace(1)*) declare void @spam() declare <2 x i8 addrspace(1)*> @baz()