diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -158,6 +158,8 @@ Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); Instruction *visitReturnInst(ReturnInst &RI); + Instruction * + foldAggregateConstructionIntoAggregateReuse(InsertValueInst &OrigIVI); Instruction *visitInsertValueInst(InsertValueInst &IV); Instruction *visitInsertElementInst(InsertElementInst &IE); Instruction *visitExtractElementInst(ExtractElementInst &EI); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" @@ -46,6 +47,10 @@ #define DEBUG_TYPE "instcombine" +STATISTIC(NumAggregateReconstructionsSimplified, + "Number of aggregate reconstructions turned into reuse of the " + "original aggregate"); + /// Return true if the value is cheaper to scalarize than it is to leave as a /// vector operation. IsConstantExtractIndex indicates whether we are extracting /// one known element from a vector constant.
@@ -694,6 +699,240 @@ return std::make_pair(V, nullptr); } +/// Look for chain of insertvalue's that fully define an aggregate, and trace +/// back the values inserted, see if they were all extractvalue'd from +/// the same source aggregate from the exact same element indexes. +/// If they were, just reuse the source aggregate. +/// This potentially deals with PHI indirections. +Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( + InsertValueInst &OrigIVI) { + BasicBlock *UseBB = OrigIVI.getParent(); + Type *AggTy = OrigIVI.getType(); + unsigned NumAggElts; + switch (AggTy->getTypeID()) { + case Type::StructTyID: + NumAggElts = AggTy->getStructNumElements(); + break; + case Type::ArrayTyID: + NumAggElts = AggTy->getArrayNumElements(); + break; + default: + llvm_unreachable("Unhandled aggregate type?"); + } + + // Arbitrary aggregate size cut-off. Motivation for limit of 2 is to be able + // to handle clang C++ exception struct (which is hardcoded as {i8*, i32}), + // FIXME: any interesting patterns to be caught with larger limit? + assert(NumAggElts > 0 && "Aggregate should have elements."); + if (NumAggElts > 2) + return nullptr; + + // Try to find a value of each element of an aggregate. + // FIXME: deal with more complex, not one-dimensional, aggregate types + SmallVector<Optional<Value *>, 2> AggElts(NumAggElts, None /*Unknown value*/); + + // Do we know values for each element of the aggregate? + auto KnowAllElts = [&AggElts]() { + return all_of(AggElts, [](Optional<Value *> Elt) { return Elt.hasValue(); }); + }; + + int Depth = 0; + + // Arbitrary `insertvalue` visitation depth limit. Let's be okay with + // every element being overwritten twice, which should never happen. + const int DepthLimit = 2 * NumAggElts; // Not static: depends on AggTy. + + // Recurse up the chain of `insertvalue` aggregate operands until either we've + // reconstructed full initializer or can't visit any more `insertvalue`'s.
+ for (InsertValueInst *CurrIVI = &OrigIVI; + Depth < DepthLimit && CurrIVI && !KnowAllElts(); + CurrIVI = dyn_cast<InsertValueInst>(CurrIVI->getAggregateOperand()), + ++Depth) { + Value *InsertedValue = CurrIVI->getInsertedValueOperand(); + ArrayRef<unsigned int> Indices = CurrIVI->getIndices(); + + // Don't bother with more than single-level aggregates. + if (Indices.size() != 1) + return nullptr; // FIXME: deal with more complex aggregates? + + // Now, we may have already previously recorded the value for this element + // of an aggregate. If we did, that means the CurrIVI will later be + // overwritten with the already-recorded value. But if not, let's record it! + Optional<Value *> &Elt = AggElts[Indices.front()]; + Elt = Elt.getValueOr(InsertedValue); + + // FIXME: should we handle chain-terminating undef base operand? + } + + // Was that sufficient to deduce the full initializer for the aggregate? + if (!KnowAllElts()) + return nullptr; // Give up then. + + // We now want to find the source[s] of the aggregate elements we've found. + // And with "source" we mean the original aggregate[s] from which + // the inserted elements were extracted. This may require PHI translation. + + enum class AggregateDescription { + /// When analyzing the value that was inserted into an aggregate, we did + /// not manage to find defining `extractvalue` instruction to analyze. + NotFound, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction[s], and everything + /// matched perfectly - aggregate type, element insertion/extraction index. + Found, + /// When analyzing the value that was inserted into an aggregate, we did + /// manage to find defining `extractvalue` instruction, but there was + /// a mismatch: either the source type from which the extraction was didn't + /// match the aggregate type into which the insertion was, + /// or the extraction/insertion indices mismatched, + /// or different elements had different source aggregates.
+ FoundMismatch + }; + auto Describe = [](Optional<Value *> SourceAggregate) { + if (SourceAggregate == None) + return AggregateDescription::NotFound; + if (*SourceAggregate != nullptr) + return AggregateDescription::Found; + return AggregateDescription::FoundMismatch; + }; + + // Given the value \p Elt that was being inserted into element \p EltIdx of an + // aggregate AggTy, see if \p Elt was originally defined by an + // appropriate extractvalue (same element index, same aggregate type). + // If found, return the source aggregate from which the extraction was. + // If \p PredBB is provided, does PHI translation of an \p Elt first. + auto FindSourceAggregate = + [&](Value *Elt, unsigned EltIdx, + Optional<BasicBlock *> PredBB) -> Optional<Value *> { + // For now(?), only deal with, at most, a single level of PHI indirection. + if (PredBB) + Elt = Elt->DoPHITranslation(UseBB, *PredBB); + // FIXME: deal with multiple levels of PHI indirection? + + // Did we find an extraction? + auto *EVI = dyn_cast<ExtractValueInst>(Elt); + if (!EVI) + return None; // AggregateDescription::NotFound + + Value *SourceAggregate = EVI->getAggregateOperand(); + + // Is the extraction from the same type into which the insertion was? + if (SourceAggregate->getType() != AggTy) + return nullptr; // AggregateDescription::FoundMismatch + // And the element index doesn't change between extraction and insertion? + if (EVI->getNumIndices() != 1 || EltIdx != EVI->getIndices().front()) + return nullptr; // AggregateDescription::FoundMismatch + + return SourceAggregate; // AggregateDescription::Found + }; + + // Given elements AggElts that were constructing an aggregate OrigIVI, + // see if we can find appropriate source aggregate for each of the elements, + // and see it's the same aggregate for each element. If so, return it.
+ auto FindCommonSourceAggregate = + [&](Optional<BasicBlock *> PredBB) -> Optional<Value *> { + Optional<Value *> SourceAggregate; + + for (auto I : enumerate(AggElts)) { + assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch && + "We don't store nullptr in SourceAggregate!"); + assert((Describe(SourceAggregate) == AggregateDescription::Found) == + (I.index() != 0) && + "SourceAggregate should be valid after the the first element,"); + + // For this element, is there a plausible source aggregate? + // FIXME: we could special-case undef element, IFF we know that in the + // source aggregate said element isn't poison. + Optional<Value *> SourceAggregateForElement = + FindSourceAggregate(*I.value(), I.index(), PredBB); + + // Okay, what have we found? Does that correlate with previous findings? + switch (Describe(SourceAggregateForElement)) { + case AggregateDescription::NotFound: + case AggregateDescription::FoundMismatch: + // Regardless of whether or not we have previously found source + // aggregate for previous elements (if any), if we didn't find one for + // this element, passthrough whatever we have just found. + return SourceAggregateForElement; + case AggregateDescription::Found: + // Okay, we have found source aggregate for this element. + // Let's see what we already know from previous elements, if any. + switch (Describe(SourceAggregate)) { + case AggregateDescription::NotFound: + // This is apparently the first element that we have examined. + SourceAggregate = SourceAggregateForElement; // Record the aggregate! + continue; // Great, now look at next element. + case AggregateDescription::Found: + // We have previously already successfully examined other elements. + // Is this the same source aggregate we've found for other elements? + if (*SourceAggregateForElement != *SourceAggregate) + return nullptr; // AggregateDescription::FoundMismatch + continue; // Still the same aggregate, look at next element. + case AggregateDescription::FoundMismatch: + llvm_unreachable("Can't happen. We would have early-exited then.");
+ } + } + } + + assert(Describe(SourceAggregate) == AggregateDescription::Found && + "Must be a valid Value"); + return *SourceAggregate; + }; + + Optional<Value *> SourceAggregate; + + // Can we find the source aggregate without looking at predecessors? + SourceAggregate = FindCommonSourceAggregate(/*PredBB=*/None); + if (Describe(SourceAggregate) != AggregateDescription::NotFound) { + if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch) + return nullptr; // Conflicting source aggregates! + ++NumAggregateReconstructionsSimplified; + return replaceInstUsesWith(OrigIVI, *SourceAggregate); + } + + // If we didn't manage to find source aggregate without looking at + // predecessors, and there are no predecessors to look at, then we're done. + if (pred_empty(UseBB)) + return nullptr; + + // Okay, apparently we need to look at predecessors. + + // Arbitrary predecessor count limit. + // Don't bother if there are more than 64 predecessors. + if (UseBB->hasNPredecessorsOrMore(64 + 1)) + return nullptr; + + // For each predecessor, what is the source aggregate, + // from which all the elements were originally extracted? + // Note that the same block may appear as a predecessor more than once. + SmallMapVector<BasicBlock *, Value *, 4> SourceAggregates; + for (BasicBlock *Pred : predecessors(UseBB)) { + std::pair<decltype(SourceAggregates)::iterator, bool> IV = + SourceAggregates.insert({Pred, nullptr}); + // Did we already evaluate this predecessor? + if (!IV.second) + continue; + + // Let's hope that when coming from predecessor Pred, all elements of the + // aggregate produced by OrigIVI must have been originally extracted from + // the same aggregate. Is that so? Can we find said original aggregate? + SourceAggregate = FindCommonSourceAggregate(Pred); + if (Describe(SourceAggregate) != AggregateDescription::Found) + return nullptr; // Give up. + IV.first->second = *SourceAggregate; + } + + // All good! Now we just need to thread the source aggregates here.
+ // The PHI needs one incoming entry per CFG edge, not per unique block. + auto *PHI = PHINode::Create(AggTy, pred_size(UseBB), + OrigIVI.getName() + ".merged"); + for (BasicBlock *Pred : predecessors(UseBB)) + PHI->addIncoming(SourceAggregates[Pred], Pred); + + ++NumAggregateReconstructionsSimplified; + return PHI; +} + /// Try to find redundant insertvalue instructions, like the following ones: /// %0 = insertvalue { i8, i32 } undef, i8 %x, 0 /// %1 = insertvalue { i8, i32 } %0, i8 %y, 0 @@ -726,6 +965,10 @@ if (IsRedundant) return replaceInstUsesWith(I, I.getOperand(0)); + + if (Instruction *NewI = foldAggregateConstructionIntoAggregateReuse(I)) + return NewI; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -3643,10 +3643,14 @@ BasicBlock *InstParent = I->getParent(); BasicBlock::iterator InsertPos = I->getIterator(); - // If we replace a PHI with something that isn't a PHI, fix up the - // insertion point. - if (!isa<PHINode>(Result) && isa<PHINode>(InsertPos)) - InsertPos = InstParent->getFirstInsertionPt(); + // Are we replacing a PHI with something that isn't a PHI, or vice versa? + if (isa<PHINode>(Result) != isa<PHINode>(I)) { + // We need to fix up the insertion point. + if (isa<PHINode>(I)) // PHI -> Non-PHI + InsertPos = InstParent->getFirstInsertionPt(); + else // Non-PHI -> PHI + InsertPos = InstParent->getFirstNonPHI()->getIterator(); + } InstParent->getInstList().insert(InsertPos, Result); diff --git a/llvm/test/Transforms/InstCombine/aggregate-reconstruction.ll b/llvm/test/Transforms/InstCombine/aggregate-reconstruction.ll --- a/llvm/test/Transforms/InstCombine/aggregate-reconstruction.ll +++ b/llvm/test/Transforms/InstCombine/aggregate-reconstruction.ll @@ -13,11 +13,7 @@ ; We should just return the source aggregate.
define { i32, i32 } @test0({ i32, i32 } %srcagg) { ; CHECK-LABEL: @test0( -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[SRCAGG:%.*]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue { i32, i32 } [[SRCAGG]], 1 -; CHECK-NEXT: [[I2:%.*]] = insertvalue { i32, i32 } undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, i32 } [[I2]], i32 [[I1]], 1 -; CHECK-NEXT: ret { i32, i32 } [[I3]] +; CHECK-NEXT: ret { i32, i32 } [[SRCAGG:%.*]] ; %i0 = extractvalue { i32, i32 } %srcagg, 0 %i1 = extractvalue { i32, i32 } %srcagg, 1 @@ -29,11 +25,7 @@ ; Arrays are still aggregates define [2 x i32] @test1([2 x i32] %srcagg) { ; CHECK-LABEL: @test1( -; CHECK-NEXT: [[I0:%.*]] = extractvalue [2 x i32] [[SRCAGG:%.*]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue [2 x i32] [[SRCAGG]], 1 -; CHECK-NEXT: [[I2:%.*]] = insertvalue [2 x i32] undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I3:%.*]] = insertvalue [2 x i32] [[I2]], i32 [[I1]], 1 -; CHECK-NEXT: ret [2 x i32] [[I3]] +; CHECK-NEXT: ret [2 x i32] [[SRCAGG:%.*]] ; %i0 = extractvalue [2 x i32] %srcagg, 0 %i1 = extractvalue [2 x i32] %srcagg, 1 @@ -83,11 +75,7 @@ ; This is fine, however, all elements are on the same level define { i32, { i32 } } @test4({ i32, { i32 } } %srcagg) { ; CHECK-LABEL: @test4( -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, { i32 } } [[SRCAGG:%.*]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue { i32, { i32 } } [[SRCAGG]], 1 -; CHECK-NEXT: [[I2:%.*]] = insertvalue { i32, { i32 } } undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, { i32 } } [[I2]], { i32 } [[I1]], 1 -; CHECK-NEXT: ret { i32, { i32 } } [[I3]] +; CHECK-NEXT: ret { i32, { i32 } } [[SRCAGG:%.*]] ; %i0 = extractvalue { i32, { i32 } } %srcagg, 0 %i1 = extractvalue { i32, { i32 } } %srcagg, 1 @@ -216,8 +204,7 @@ ; CHECK-NEXT: call void @usei32(i32 [[I1]]) ; CHECK-NEXT: [[I2:%.*]] = insertvalue { i32, i32 } undef, i32 [[I0]], 0 ; CHECK-NEXT: call void @usei32i32agg({ i32, i32 } [[I2]]) -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, i32 
} [[I2]], i32 [[I1]], 1 -; CHECK-NEXT: ret { i32, i32 } [[I3]] +; CHECK-NEXT: ret { i32, i32 } [[SRCAGG]] ; %i0 = extractvalue { i32, i32 } %srcagg, 0 call void @usei32(i32 %i0) @@ -233,11 +220,7 @@ ; overwritten with %i0, so all is fine. define { i32, i32 } @test13({ i32, i32 } %srcagg) { ; CHECK-LABEL: @test13( -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[SRCAGG:%.*]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue { i32, i32 } [[SRCAGG]], 1 -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, i32 } undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I4:%.*]] = insertvalue { i32, i32 } [[I3]], i32 [[I1]], 1 -; CHECK-NEXT: ret { i32, i32 } [[I4]] +; CHECK-NEXT: ret { i32, i32 } [[SRCAGG:%.*]] ; %i0 = extractvalue { i32, i32 } %srcagg, 0 %i1 = extractvalue { i32, i32 } %srcagg, 1 @@ -283,11 +266,7 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: end: -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[SRCAGG:%.*]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue { i32, i32 } [[SRCAGG]], 1 -; CHECK-NEXT: [[I2:%.*]] = insertvalue { i32, i32 } undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, i32 } [[I2]], i32 [[I1]], 1 -; CHECK-NEXT: ret { i32, i32 } [[I3]] +; CHECK-NEXT: ret { i32, i32 } [[SRCAGG:%.*]] ; entry: br label %end @@ -308,11 +287,7 @@ ; CHECK-NEXT: br label [[END]] ; CHECK: end: ; CHECK-NEXT: [[SRCAGG_PHI:%.*]] = phi { i32, i32 } [ [[SRCAGG0:%.*]], [[ENTRY:%.*]] ], [ [[SRCAGG1:%.*]], [[INTERMEDIATE]] ] -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[SRCAGG_PHI]], 0 -; CHECK-NEXT: [[I1:%.*]] = extractvalue { i32, i32 } [[SRCAGG_PHI]], 1 -; CHECK-NEXT: [[I2:%.*]] = insertvalue { i32, i32 } undef, i32 [[I0]], 0 -; CHECK-NEXT: [[I3:%.*]] = insertvalue { i32, i32 } [[I2]], i32 [[I1]], 1 -; CHECK-NEXT: ret { i32, i32 } [[I3]] +; CHECK-NEXT: ret { i32, i32 } [[SRCAGG_PHI]] ; entry: br i1 %c, label %intermediate, label %end diff --git a/llvm/test/Transforms/InstCombine/phi-aware-aggregate-reconstruction.ll 
b/llvm/test/Transforms/InstCombine/phi-aware-aggregate-reconstruction.ll --- a/llvm/test/Transforms/InstCombine/phi-aware-aggregate-reconstruction.ll +++ b/llvm/test/Transforms/InstCombine/phi-aware-aggregate-reconstruction.ll @@ -17,21 +17,14 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: br i1 [[C:%.*]], label [[LEFT:%.*]], label [[RIGHT:%.*]] ; CHECK: left: -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[AGG_LEFT:%.*]], 0 -; CHECK-NEXT: [[I2:%.*]] = extractvalue { i32, i32 } [[AGG_LEFT]], 1 ; CHECK-NEXT: call void @foo() ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: right: -; CHECK-NEXT: [[I3:%.*]] = extractvalue { i32, i32 } [[AGG_RIGHT:%.*]], 0 -; CHECK-NEXT: [[I4:%.*]] = extractvalue { i32, i32 } [[AGG_RIGHT]], 1 ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: br label [[END]] ; CHECK: end: -; CHECK-NEXT: [[I5:%.*]] = phi i32 [ [[I0]], [[LEFT]] ], [ [[I3]], [[RIGHT]] ] -; CHECK-NEXT: [[I6:%.*]] = phi i32 [ [[I2]], [[LEFT]] ], [ [[I4]], [[RIGHT]] ] +; CHECK-NEXT: [[I8:%.*]] = phi { i32, i32 } [ [[AGG_RIGHT:%.*]], [[RIGHT]] ], [ [[AGG_LEFT:%.*]], [[LEFT]] ] ; CHECK-NEXT: call void @baz() -; CHECK-NEXT: [[I7:%.*]] = insertvalue { i32, i32 } undef, i32 [[I5]], 0 -; CHECK-NEXT: [[I8:%.*]] = insertvalue { i32, i32 } [[I7]], i32 [[I6]], 1 ; CHECK-NEXT: ret { i32, i32 } [[I8]] ; entry: @@ -278,24 +271,17 @@ ; CHECK-NEXT: entry: ; CHECK-NEXT: br i1 [[C0:%.*]], label [[LEFT:%.*]], label [[RIGHT:%.*]] ; CHECK: left: -; CHECK-NEXT: [[I0:%.*]] = extractvalue { i32, i32 } [[AGG_LEFT:%.*]], 0 -; CHECK-NEXT: [[I2:%.*]] = extractvalue { i32, i32 } [[AGG_LEFT]], 1 ; CHECK-NEXT: call void @foo() ; CHECK-NEXT: br label [[MIDDLE:%.*]] ; CHECK: right: -; CHECK-NEXT: [[I3:%.*]] = extractvalue { i32, i32 } [[AGG_RIGHT:%.*]], 0 -; CHECK-NEXT: [[I4:%.*]] = extractvalue { i32, i32 } [[AGG_RIGHT]], 1 ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: br label [[MIDDLE]] ; CHECK: middle: -; CHECK-NEXT: [[I5:%.*]] = phi i32 [ [[I0]], [[LEFT]] ], [ [[I3]], [[RIGHT]] ], [ [[I5]], [[MIDDLE]] ] -; 
CHECK-NEXT: [[I6:%.*]] = phi i32 [ [[I2]], [[LEFT]] ], [ [[I4]], [[RIGHT]] ], [ [[I6]], [[MIDDLE]] ] +; CHECK-NEXT: [[I8:%.*]] = phi { i32, i32 } [ [[I8]], [[MIDDLE]] ], [ [[AGG_RIGHT:%.*]], [[RIGHT]] ], [ [[AGG_LEFT:%.*]], [[LEFT]] ] ; CHECK-NEXT: call void @baz() ; CHECK-NEXT: [[C1:%.*]] = call i1 @geni1() ; CHECK-NEXT: br i1 [[C1]], label [[END:%.*]], label [[MIDDLE]] ; CHECK: end: -; CHECK-NEXT: [[I7:%.*]] = insertvalue { i32, i32 } undef, i32 [[I5]], 0 -; CHECK-NEXT: [[I8:%.*]] = insertvalue { i32, i32 } [[I7]], i32 [[I6]], 1 ; CHECK-NEXT: ret { i32, i32 } [[I8]] ; entry: