diff --git a/llvm/include/llvm/Transforms/IPO/Attributor.h b/llvm/include/llvm/Transforms/IPO/Attributor.h --- a/llvm/include/llvm/Transforms/IPO/Attributor.h +++ b/llvm/include/llvm/Transforms/IPO/Attributor.h @@ -591,6 +591,13 @@ return AG.getAnalysis(F); } + /// Return true if \p Arg belongs to a function that is involved in a + /// must-tail call, i.e., one that contains such a call or is its callee. + bool isInvolvedInMustTailCall(const Argument &Arg) const { + return FunctionsCalledViaMustTail.count(Arg.getParent()) || + FunctionsWithMustTailCall.count(Arg.getParent()); + } + /// Return the analysis result from a pass \p AP for function \p F. template typename AP::Result *getAnalysisResultForFunction(const Function &F) { @@ -621,6 +628,12 @@ /// A map from functions to their instructions that may read or write memory. FuncRWInstsMapTy FuncRWInstsMap; + /// Functions called by a `musttail` call. + SmallPtrSet FunctionsCalledViaMustTail; + + /// Functions containing a `musttail` call. + SmallPtrSet FunctionsWithMustTailCall; + /// The datalayout used in the module. const DataLayout &DL; diff --git a/llvm/lib/Transforms/IPO/Attributor.cpp b/llvm/lib/Transforms/IPO/Attributor.cpp --- a/llvm/lib/Transforms/IPO/Attributor.cpp +++ b/llvm/lib/Transforms/IPO/Attributor.cpp @@ -4062,10 +4062,20 @@ struct AAAlignArgument final : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext { - AAAlignArgument(const IRPosition &IRP) - : AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext( - IRP) {} + using Base = + AAArgumentFromCallSiteArgumentsAndMustBeExecutedContext; + AAAlignArgument(const IRPosition &IRP) : Base(IRP) {} + + /// See AbstractAttribute::initialize(...). + void initialize(Attributor &A) override { + Base::initialize(A); + // If the associated argument is involved in a must-tail call we give up + // because we would need to keep the argument alignments of caller and + // callee in-sync. Just does not seem worth the trouble right now. 
+ if (A.getInfoCache().isInvolvedInMustTailCall(*getAssociatedArgument())) + indicatePessimisticFixpoint(); + } /// See AbstractAttribute::trackStatistics() void trackStatistics() const override { STATS_DECLTRACK_ARG_ATTR(aligned) } @@ -8308,11 +8318,17 @@ "New call site/base instruction type needs to be known in the " "Attributor."); break; - case Instruction::Load: - // The alignment of a pointer is interesting for loads. - case Instruction::Store: - // The alignment of a pointer is interesting for stores. case Instruction::Call: + if (cast(I).isMustTailCall()) { + InfoCache.FunctionsWithMustTailCall.insert(&F); + // An indirect must-tail call has no known callee function to record. + if (const Function *Callee = cast(I).getCalledFunction()) + InfoCache.FunctionsCalledViaMustTail.insert(Callee); + } + // Calls are interesting themselves; say so explicitly instead of falling + // through the remaining case labels implicitly. + IsInterestingOpcode = true; + break; case Instruction::CallBr: case Instruction::Invoke: case Instruction::CleanupRet: @@ -8322,6 +8338,10 @@ case Instruction::Br: case Instruction::Resume: case Instruction::Ret: + case Instruction::Load: + // The alignment of a pointer is interesting for loads. + case Instruction::Store: + // The alignment of a pointer is interesting for stores. IsInterestingOpcode = true; } if (IsInterestingOpcode) @@ -8356,6 +8376,18 @@ if (F.isDeclaration()) return; + // In non-module runs we need to look at the call sites of a function to + // determine if it is part of a must-tail call edge. This will influence what + // attributes we can derive. + if (!isModulePass() && !InfoCache.FunctionsCalledViaMustTail.count(&F)) + for (const Use &U : F.uses()) + if (ImmutableCallSite ICS = ImmutableCallSite(U.getUser())) + if (ICS.isCallee(&U) && ICS.isMustTailCall()) { + InfoCache.FunctionsCalledViaMustTail.insert(&F); + // One must-tail call site is enough, stop scanning the uses. + break; + } + IRPosition FPos = IRPosition::function(F); // Check for dead BasicBlocks in every function. 
diff --git a/llvm/test/Transforms/Attributor/ArgumentPromotion/musttail.ll b/llvm/test/Transforms/Attributor/ArgumentPromotion/musttail.ll --- a/llvm/test/Transforms/Attributor/ArgumentPromotion/musttail.ll +++ b/llvm/test/Transforms/Attributor/ArgumentPromotion/musttail.ll @@ -1,5 +1,5 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --scrub-attributes -; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=1 < %s | FileCheck %s +; RUN: opt -S -passes=attributor -aa-pipeline='basic-aa' -attributor-disable=false -attributor-max-iterations-verify -attributor-max-iterations=2 < %s | FileCheck %s ; PR36543 ; Don't promote arguments of musttail callee diff --git a/llvm/test/Transforms/Attributor/align.ll b/llvm/test/Transforms/Attributor/align.ll --- a/llvm/test/Transforms/Attributor/align.ll +++ b/llvm/test/Transforms/Attributor/align.ll @@ -552,6 +552,23 @@ store i8 0, i8* %bc ret void } + +; Make sure we do not annotate the callee of a must-tail call with an alignment +; we cannot also put on the caller. +@cnd = external global i1 +define i32 @musttail_callee_1(i32* %p) { + %v = load i32, i32* %p, align 32 + ret i32 %v +} +define i32 @musttail_caller_1(i32* %p) { + %c = load i1, i1* @cnd + br i1 %c, label %mt, label %exit +mt: + %v = musttail call i32 @musttail_callee_1(i32* %p) + ret i32 %v +exit: + ret i32 0 +} ; UTC_ARGS: --disable attributes #0 = { nounwind uwtable noinline }