Index: llvm/include/llvm/IR/Attributes.h
===================================================================
--- llvm/include/llvm/IR/Attributes.h
+++ llvm/include/llvm/IR/Attributes.h
@@ -107,6 +107,14 @@
                                         const Optional<unsigned> &NumElemsArg);
   static Attribute getWithByValType(LLVMContext &Context, Type *Ty);
 
+  /// Return the enum attribute kind named \p AttrName (e.g. "nonnull"), or
+  /// Attribute::None if no enum attribute has that name.
+  static Attribute::AttrKind getAttrKindFromName(StringRef AttrName);
+
+  /// Return the name of the enum attribute kind \p AttrKind (e.g. "nonnull");
+  /// Attribute::None maps to "none".
+  static StringRef getNameFromAttrKind(Attribute::AttrKind AttrKind);
+
   //===--------------------------------------------------------------------===//
   // Attribute Accessors
   //===--------------------------------------------------------------------===//
Index: llvm/include/llvm/Transforms/Utils/AssumeBuilder.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Transforms/Utils/AssumeBuilder.h
@@ -0,0 +1,43 @@
+//===- AssumeBuilder.h - Build assume to preserve information ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Utilities and a pass that build llvm.assume calls whose operand bundles
+// preserve information (currently call attributes) derived from instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
+#define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
+
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// Build a call to llvm.assume whose operand bundles preserve the information
+/// that can be derived from \p I (currently only call attributes). \p M is the
+/// module in which the llvm.assume declaration is looked up or created.
+/// Returns null if no information was found. The returned instruction is not
+/// inserted into any basic block.
+Instruction *BuildAssumeFromInst(const Instruction *I, Module *M);
+
+/// Same as BuildAssumeFromInst(I, I->getModule()).
+inline Instruction *BuildAssumeFromInst(Instruction *I) {
+  return BuildAssumeFromInst(I, I->getModule());
+}
+
+/// This pass will try to build an llvm.assume for every instruction in the
+/// function and insert it right before that instruction.
+class AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
Index: llvm/lib/IR/Attributes.cpp
===================================================================
--- llvm/lib/IR/Attributes.cpp
+++ llvm/lib/IR/Attributes.cpp
@@ -176,6 +176,29 @@
   return get(Context, AllocSize, packAllocSizeArgs(ElemSizeArg, NumElemsArg));
 }
 
+Attribute::AttrKind Attribute::getAttrKindFromName(StringRef AttrName) {
+  return StringSwitch<Attribute::AttrKind>(AttrName)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
+#include "llvm/IR/Attributes.inc"
+      .Default(Attribute::None);
+}
+
+StringRef Attribute::getNameFromAttrKind(Attribute::AttrKind AttrKind) {
+  switch (AttrKind) {
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  case Attribute::ENUM_NAME:                                                   \
+    return #DISPLAY_NAME;
+#include "llvm/IR/Attributes.inc"
+  case Attribute::None:
+    return "none";
+  default:
+    llvm_unreachable("invalid Kind");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Attribute Accessor Methods
 //===----------------------------------------------------------------------===//
Index: llvm/lib/IR/Core.cpp
===================================================================
--- llvm/lib/IR/Core.cpp
+++ llvm/lib/IR/Core.cpp
@@ -127,17 +127,8 @@
   return LLVMGetMDKindIDInContext(LLVMGetGlobalContext(), Name, SLen);
 }
 
-static Attribute::AttrKind getAttrKindFromName(StringRef AttrName) {
-  return StringSwitch<Attribute::AttrKind>(AttrName)
-#define GET_ATTR_NAMES
-#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
-  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
-#include "llvm/IR/Attributes.inc"
-      .Default(Attribute::None);
-}
-
 unsigned LLVMGetEnumAttributeKindForName(const char *Name, size_t SLen) {
-  return getAttrKindFromName(StringRef(Name, SLen));
+  return Attribute::getAttrKindFromName(StringRef(Name, SLen));
 }
 
 unsigned LLVMGetLastEnumAttributeKind(void) {
Index: llvm/lib/Passes/PassBuilder.cpp
===================================================================
--- llvm/lib/Passes/PassBuilder.cpp
+++ llvm/lib/Passes/PassBuilder.cpp
@@ -168,6 +168,7 @@
 #include "llvm/Transforms/Scalar/TailRecursionElimination.h"
 #include "llvm/Transforms/Scalar/WarnMissedTransforms.h"
 #include "llvm/Transforms/Utils/AddDiscriminators.h"
+#include "llvm/Transforms/Utils/AssumeBuilder.h"
 #include "llvm/Transforms/Utils/BreakCriticalEdges.h"
 #include "llvm/Transforms/Utils/CanonicalizeAliases.h"
 #include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -159,6 +159,7 @@
 FUNCTION_PASS("adce", ADCEPass())
 FUNCTION_PASS("add-discriminators", AddDiscriminatorsPass())
 FUNCTION_PASS("aggressive-instcombine", AggressiveInstCombinePass())
+FUNCTION_PASS("assume-builder", AssumeBuilderPass())
 FUNCTION_PASS("alignment-from-assumptions", AlignmentFromAssumptionsPass())
 FUNCTION_PASS("bdce", BDCEPass())
 FUNCTION_PASS("bounds-checking", BoundsCheckingPass())
Index: llvm/lib/Transforms/Utils/AssumeBuilder.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/Utils/AssumeBuilder.cpp
@@ -0,0 +1,132 @@
+//===- AssumeBuilder.cpp - Build assume to preserve information -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/AssumeBuilder.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include <tuple>
+
+using namespace llvm;
+
+namespace {
+
+/// Accumulates (attribute name, optional integer argument, value the
+/// attribute holds on) entries and turns them into the operand bundles of a
+/// single llvm.assume call.
+struct AssumeBuilderState {
+  Module *M;
+
+  using BundleKey = std::tuple<StringRef, Value *, Value *>;
+
+  /// Used to drop duplicate entries.
+  SmallSet<BundleKey, 8> BundleSet;
+  /// Unique entries in insertion order. SmallSet iterates in std::set order
+  /// once it grows past its inline size, which would make the order of the
+  /// emitted bundles depend on pointer values.
+  SmallVector<BundleKey, 8> Bundles;
+
+  explicit AssumeBuilderState(Module *M) : M(M) {}
+
+  /// Record that \p Attr holds on \p WasOn.
+  void addAttribute(Attribute Attr, Value *WasOn) {
+    StringRef Name;
+    Value *AttrArg = nullptr;
+    if (Attr.isStringAttribute())
+      Name = Attr.getKindAsString();
+    else
+      Name = Attribute::getNameFromAttrKind(Attr.getKindAsEnum());
+    if (Attr.isIntAttribute())
+      AttrArg = ConstantInt::get(Type::getInt64Ty(M->getContext()),
+                                 Attr.getValueAsInt());
+    BundleKey Key(Name, AttrArg, WasOn);
+    if (BundleSet.insert(Key).second)
+      Bundles.push_back(Key);
+  }
+
+  /// Record the parameter attributes in \p AttrList on the matching arguments
+  /// of \p Call, and its function attributes on \p Callee when it is known.
+  void addAttributeList(AttributeList AttrList, const CallBase *Call,
+                        Value *Callee) {
+    for (unsigned ArgNo = 0, E = Call->arg_size(); ArgNo != E; ++ArgNo)
+      for (Attribute Attr : AttrList.getParamAttributes(ArgNo))
+        addAttribute(Attr, Call->getArgOperand(ArgNo));
+    if (Callee)
+      for (Attribute Attr : AttrList.getFnAttributes())
+        addAttribute(Attr, Callee);
+  }
+
+  /// Record the attributes of the call site and, for direct calls, those of
+  /// the called function.
+  void addCall(const CallBase *Call) {
+    // Null when the callee is not a Function (e.g. an indirect call); then
+    // only the call-site parameter attributes are recorded.
+    Function *Callee = Call->getCalledFunction();
+    addAttributeList(Call->getAttributes(), Call, Callee);
+    if (Callee)
+      addAttributeList(Callee->getAttributes(), Call, Callee);
+  }
+
+  /// Create an llvm.assume call carrying every recorded entry as an operand
+  /// bundle, or return null if nothing was recorded. The call is not inserted
+  /// into any basic block.
+  Instruction *build() {
+    if (Bundles.empty())
+      return nullptr;
+    Function *FnAssume = Intrinsic::getDeclaration(M, Intrinsic::assume);
+    LLVMContext &C = M->getContext();
+    SmallVector<OperandBundleDef, 8> OpBundle;
+    for (const BundleKey &Elem : Bundles) {
+      SmallVector<Value *, 2> Args;
+      if (Value *AttrArg = std::get<1>(Elem))
+        Args.push_back(AttrArg);
+      Args.push_back(std::get<2>(Elem));
+      OpBundle.push_back(OperandBundleDef(std::get<0>(Elem).str(), Args));
+    }
+    return CallInst::Create(FnAssume,
+                            ArrayRef<Value *>({ConstantInt::getTrue(C)}),
+                            OpBundle);
+  }
+
+  /// Record what can be derived from \p I; only calls are handled so far.
+  void addInstruction(const Instruction *I) {
+    if (auto *Call = dyn_cast<CallBase>(I))
+      addCall(Call);
+    // TODO: Add support for the other Instruction.
+  }
+};
+
+} // namespace
+
+Instruction *llvm::BuildAssumeFromInst(const Instruction *I, Module *M) {
+  AssumeBuilderState Builder(M);
+  Builder.addInstruction(I);
+  return Builder.build();
+}
+
+PreservedAnalyses AssumeBuilderPass::run(Function &F,
+                                         FunctionAnalysisManager &AM) {
+  bool Changed = false;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (Instruction *Assume = BuildAssumeFromInst(&I)) {
+        Assume->insertBefore(&I);
+        Changed = true;
+      }
+  if (!Changed)
+    return PreservedAnalyses::all();
+  // Only calls are inserted; no block or edge is added or removed.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
Index: llvm/lib/Transforms/Utils/CMakeLists.txt
===================================================================
--- llvm/lib/Transforms/Utils/CMakeLists.txt
+++ llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_component_library(LLVMTransformUtils
   ASanStackFrameLayout.cpp
   AddDiscriminators.cpp
+  AssumeBuilder.cpp
   BasicBlockUtils.cpp
   BreakCriticalEdges.cpp
   BuildLibCalls.cpp
Index: llvm/test/Transforms/Util/assume-builder.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/Util/assume-builder.ll
@@ -0,0 +1,51 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='assume-builder' -S %s | FileCheck %s
+
+declare void @func(i32*, i32*)
+declare void @func_cold(i32*) cold
+declare void @func_strbool(i32*) "no-jump-tables"
+declare void @func_many(i32*) "no-jump-tables" nounwind "less-precise-fpmad" willreturn norecurse
+declare void @func_argattr(i32* align 8, i32* nonnull) nounwind
+
+define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {
+; CHECK-LABEL: @test(
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P:%.*]]), "dereferenceable"(i64 16, i32* [[P]]) ]
+; CHECK-NEXT:    call void @func(i32* nonnull dereferenceable(16) [[P]], i32* null)
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i64 12, i32* [[P1:%.*]]), "nonnull"(i32* [[P]]) ]
+; CHECK-NEXT:    call void @func(i32* dereferenceable(12) [[P1]], i32* nonnull [[P]])
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i64 12, i32* [[P1]]), "cold"(void (i32*)* @func_cold) ]
+; CHECK-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]]) #0
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i64 12, i32* [[P1]]), "cold"(void (i32*)* @func_cold) ]
+; CHECK-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]])
+; CHECK-NEXT:    call void @func(i32* [[P1]], i32* [[P]])
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "no-jump-tables"(void (i32*)* @func_strbool) ]
+; CHECK-NEXT:    call void @func_strbool(i32* [[P1]])
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i64 16, i32* [[P]]), "dereferenceable"(i64 8, i32* [[P]]) ]
+; CHECK-NEXT:    call void @func(i32* dereferenceable(16) [[P]], i32* dereferenceable(8) [[P]])
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(i64 8, i32* [[P1]]), "norecurse"(void (i32*)* @func_many), "nounwind"(void (i32*)* @func_many), "willreturn"(void (i32*)* @func_many), "less-precise-fpmad"(void (i32*)* @func_many), "no-jump-tables"(void (i32*)* @func_many) ]
+; CHECK-NEXT:    call void @func_many(i32* align 8 [[P1]])
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(i64 8, i32* [[P2:%.*]]), "nonnull"(i32* [[P3:%.*]]), "nounwind"(void (i32*, i32*)* @func_argattr) ]
+; CHECK-NEXT:    call void @func_argattr(i32* [[P2]], i32* [[P3]])
+; CHECK-NEXT:    ret void
+;
+  call void @func(i32* nonnull dereferenceable(16) %P, i32* null)
+  call void @func(i32* dereferenceable(12) %P1, i32* nonnull %P)
+  call void @func_cold(i32* dereferenceable(12) %P1) cold
+  call void @func_cold(i32* dereferenceable(12) %P1)
+  call void @func(i32* %P1, i32* %P)
+  call void @func_strbool(i32* %P1)
+  call void @func(i32* dereferenceable(16) %P, i32* dereferenceable(8) %P)
+  call void @func_many(i32* align 8 %P1)
+  call void @func_argattr(i32* %P2, i32* %P3)
+  ret void
+}
+
+define void @test_indirect(void (i32*)* %F, i32* %P) {
+; CHECK-LABEL: @test_indirect(
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P:%.*]]) ]
+; CHECK-NEXT:    call void [[F:%.*]](i32* nonnull [[P]])
+; CHECK-NEXT:    ret void
+;
+  call void %F(i32* nonnull %P)
+  ret void
+}