Index: llvm/lib/Transforms/Utils/InlineFunction.cpp
===================================================================
--- llvm/lib/Transforms/Utils/InlineFunction.cpp
+++ llvm/lib/Transforms/Utils/InlineFunction.cpp
@@ -1633,6 +1633,52 @@
       CalledFunc->isDeclaration()) // call!
     return InlineResult::failure("external or indirect");
 
+  // If the current call has a Dereferenceable parameter at ArgNo,
+  // walk all of the uses of ActualArg (which must be the load LI) and validate
+  // that every call we find passes ActualArg to at least one parameter that
+  // is itself marked Dereferenceable. Returns true only if all calls qualify.
+  auto HaveQualifyingDerefLoads = [CalledFunc](Value *ActualArg, unsigned ArgNo,
+                                               LoadInst *LI) {
+    // Must have a param that is Dereferenceable as a use of this LI.
+    // This is our entry criteria to keep us from iterating on the uses of
+    // each load.
+    if (!CalledFunc->hasParamAttribute(ArgNo, Attribute::Dereferenceable))
+      return false;
+
+    // If the load already has MD_dereferenceable, there is nothing to do.
+    if (LI->getMetadata(LLVMContext::MD_dereferenceable))
+      return false;
+
+    // We want to place this attribute on the load iff all uses
+    // that are calls have Dereferenceable params.
+    for (User *U : LI->users()) {
+      // Users that are not calls are not analyzed.
+      auto *CI = dyn_cast<CallInst>(U);
+      if (!CI)
+        continue;
+
+      // An indirect call has no statically known callee whose parameter
+      // attributes we could inspect, so conservatively reject the load.
+      Function *CurFunc = CI->getCalledFunction();
+      if (!CurFunc)
+        return false;
+
+      bool FoundDereferenceableParam = false;
+      for (unsigned CurArgNo = 0, NumParams = CurFunc->arg_size();
+           CurArgNo != NumParams; ++CurArgNo) {
+        if (CI->getArgOperand(CurArgNo) == ActualArg &&
+            CurFunc->hasParamAttribute(CurArgNo, Attribute::Dereferenceable)) {
+          FoundDereferenceableParam = true;
+          break;
+        }
+      }
+      if (!FoundDereferenceableParam)
+        return false;
+    }
+
+    return true;
+  };
+
   // The inliner does not know how to inline through calls with operand bundles
   // in general ...
   if (CB.hasOperandBundles()) {
@@ -1791,6 +1837,18 @@
                                         CalledFunc->getParamAlignment(ArgNo));
         if (ActualArg != *AI)
           ByValInit.push_back(std::make_pair(ActualArg, (Value*) *AI));
+      } else if (auto *LI = dyn_cast<LoadInst>(ActualArg)) {
+        // Mark loads which have calls with dereferenceable parameters with
+        // their dereferenceable metadata.
+        if (HaveQualifyingDerefLoads(ActualArg, ArgNo, LI)) {
+          uint64_t DerefBytes = CalledFunc->getParamDereferenceableBytes(ArgNo);
+          auto &Ctx = LI->getContext();
+          MDBuilder MDB(Ctx);
+          LI->setMetadata(
+              LLVMContext::MD_dereferenceable,
+              MDNode::get(Ctx, MDB.createConstant(ConstantInt::get(
+                                   Type::getInt64Ty(Ctx), DerefBytes))));
+        }
       }
 
       VMap[&*I] = ActualArg;
Index: llvm/test/Transforms/Inline/inline-dereferenceable.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/Inline/inline-dereferenceable.ll
@@ -0,0 +1,25 @@
+; RUN: opt -S -inline < %s | FileCheck %s
+
+; Map dereferenceable on the load that provides the parameter to inner
+; via the dereferenceable parameter definition in the inliner.
+
+%struct._primitive_acceleration_structure_t = type opaque
+%struct.Header = type { i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i64, %struct._Box }
+%struct._Box = type { [3 x float], [3 x float] }
+
+@buffer_pointers.0 = external constant %struct._primitive_acceleration_structure_t*, section "buffer_bindings", align 8
+
+; CHECK: define i32 @inner(%struct.Header* nocapture readonly dereferenceable(96) %a)
+define i32 @inner(%struct.Header* nocapture readonly dereferenceable(96) %a) {
+  %innerNodeOffset.i = getelementptr inbounds %struct.Header, %struct.Header* %a, i32 0, i32 9
+  %x = load i32, i32* %innerNodeOffset.i, align 4
+  ret i32 %x
+}
+
+define i32 @outer() {
+; CHECK: %struct = load %struct.Header*, %struct.Header** bitcast (%struct._primitive_acceleration_structure_t** @buffer_pointers.0 to %struct.Header**), align 8, !dereferenceable !0
+  %struct = load %struct.Header*, %struct.Header** bitcast (%struct._primitive_acceleration_structure_t** @buffer_pointers.0 to %struct.Header**), align 8
+  %r = call i32 @inner(%struct.Header* %struct)
+  ret i32 %r
+}
+