Index: llvm/include/llvm/IR/Attributes.h
===================================================================
--- llvm/include/llvm/IR/Attributes.h
+++ llvm/include/llvm/IR/Attributes.h
@@ -107,6 +107,13 @@
                                         const Optional<unsigned> &NumElemsArg);
   static Attribute getWithByValType(LLVMContext &Context, Type *Ty);
 
+  /// Return the enum attribute kind whose textual name is \p AttrName, or
+  /// Attribute::None if \p AttrName does not name an enum attribute.
+  static Attribute::AttrKind getAttrKindFromName(StringRef AttrName);
+
+  /// Return the textual name of \p Kind ("none" for Attribute::None).
+  static StringRef getNameFromAttrKind(Attribute::AttrKind Kind);
+
   //===--------------------------------------------------------------------===//
   // Attribute Accessors
   //===--------------------------------------------------------------------===//
Index: llvm/include/llvm/Transforms/Utils/AssumeBuilder.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Transforms/Utils/AssumeBuilder.h
@@ -0,0 +1,41 @@
+//===- AssumeBuilder.h - Build assume to preserve information ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
+#define LLVM_TRANSFORMS_UTILS_ASSUMEBUILDER_H
+
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// Build an llvm.assume whose operand bundles preserve the information that
+/// can be derived from \p I. Currently only the attributes of calls (and of
+/// their direct callees) are used. \p M is the module in which the
+/// llvm.assume declaration is looked up or created.
+/// Returns null if no information was found. The returned instruction is not
+/// inserted anywhere.
+Instruction *BuildAssumeFromInst(const Instruction *I, Module *M);
+
+/// Same as above, using the module that contains \p I. Defined inline because
+/// this header is included from several translation units.
+inline Instruction *BuildAssumeFromInst(Instruction *I) {
+  return BuildAssumeFromInst(I, I->getModule());
+}
+
+/// This pass will try to build an llvm.assume for every instruction in the
+/// function. Its main purpose is testing.
+struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif
Index: llvm/lib/IR/Attributes.cpp
===================================================================
--- llvm/lib/IR/Attributes.cpp
+++ llvm/lib/IR/Attributes.cpp
@@ -176,6 +176,29 @@
   return get(Context, AllocSize, packAllocSizeArgs(ElemSizeArg, NumElemsArg));
 }
 
+Attribute::AttrKind Attribute::getAttrKindFromName(StringRef AttrName) {
+  return StringSwitch<Attribute::AttrKind>(AttrName)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
+#include "llvm/IR/Attributes.inc"
+      .Default(Attribute::None);
+}
+
+StringRef Attribute::getNameFromAttrKind(Attribute::AttrKind Kind) {
+  switch (Kind) {
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  case Attribute::ENUM_NAME:                                                   \
+    return #DISPLAY_NAME;
+#include "llvm/IR/Attributes.inc"
+  case Attribute::None:
+    return "none";
+  default:
+    llvm_unreachable("invalid Kind");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Attribute Accessor Methods
 //===----------------------------------------------------------------------===//
Index: llvm/lib/IR/Core.cpp
===================================================================
--- llvm/lib/IR/Core.cpp
+++ llvm/lib/IR/Core.cpp
@@ -127,17 +127,8 @@
   return LLVMGetMDKindIDInContext(LLVMGetGlobalContext(), Name, SLen);
 }
 
-static Attribute::AttrKind getAttrKindFromName(StringRef AttrName) {
-  return StringSwitch<Attribute::AttrKind>(AttrName)
-#define GET_ATTR_NAMES
-#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
-  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
-#include "llvm/IR/Attributes.inc"
-      .Default(Attribute::None);
-}
-
 unsigned LLVMGetEnumAttributeKindForName(const char *Name, size_t SLen) {
-  return getAttrKindFromName(StringRef(Name, SLen));
+  return Attribute::getAttrKindFromName(StringRef(Name, SLen));
 }
 
 unsigned LLVMGetLastEnumAttributeKind(void) {
Index: llvm/lib/Passes/PassBuilder.cpp
===================================================================
--- llvm/lib/Passes/PassBuilder.cpp
+++ llvm/lib/Passes/PassBuilder.cpp
@@ -168,6 +168,7 @@
 #include "llvm/Transforms/Scalar/TailRecursionElimination.h"
 #include "llvm/Transforms/Scalar/WarnMissedTransforms.h"
 #include "llvm/Transforms/Utils/AddDiscriminators.h"
+#include "llvm/Transforms/Utils/AssumeBuilder.h"
 #include "llvm/Transforms/Utils/BreakCriticalEdges.h"
 #include "llvm/Transforms/Utils/CanonicalizeAliases.h"
 #include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -159,6 +159,7 @@
 FUNCTION_PASS("adce", ADCEPass())
 FUNCTION_PASS("add-discriminators", AddDiscriminatorsPass())
 FUNCTION_PASS("aggressive-instcombine", AggressiveInstCombinePass())
 FUNCTION_PASS("alignment-from-assumptions", AlignmentFromAssumptionsPass())
+FUNCTION_PASS("assume-builder", AssumeBuilderPass())
 FUNCTION_PASS("bdce", BDCEPass())
 FUNCTION_PASS("bounds-checking", BoundsCheckingPass())
Index: llvm/lib/Transforms/Utils/AssumeBuilder.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/Utils/AssumeBuilder.cpp
@@ -0,0 +1,149 @@
+//===- AssumeBuilder.cpp - Build assume to preserve information -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/AssumeBuilder.h"
+#include "llvm/IR/AssemblyAnnotationWriter.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormattedStream.h"
+
+using namespace llvm;
+
+namespace {
+
+cl::opt<bool> ShouldPreserveAllAttributes(
+    "assume-preserve-all", cl::init(false), cl::Hidden,
+    cl::desc("enable preservation of all attributes, even those that are "
+             "unlikely to be useful"));
+
+/// This class contains all knowledge that has been gathered while building an
+/// llvm.assume and the functions to manipulate it.
+struct AssumeBuilderState {
+  Module *M;
+
+  /// One piece of knowledge, emitted as the operand bundle
+  /// "Name"(WasOn[, Argument]).
+  struct Knowledge {
+    // Names are compared by content, not by address, so that identical string
+    // attribute kinds coming from different attribute lists are merged.
+    using Tuple = std::tuple<StringRef, Value *, Value *>;
+    StringRef Name;
+    Value *Argument;
+    Value *WasOn;
+
+    explicit operator Tuple() const {
+      return std::make_tuple(Name, Argument, WasOn);
+    }
+    bool operator==(Knowledge Other) const {
+      return static_cast<Tuple>(*this) == static_cast<Tuple>(Other);
+    }
+    bool operator<(Knowledge Other) const {
+      return static_cast<Tuple>(*this) < static_cast<Tuple>(Other);
+    }
+  };
+
+  SmallSet<Knowledge, 8> KnowledgeSet;
+
+  explicit AssumeBuilderState(Module *M) : M(M) {}
+
+  /// Record \p Attr as knowledge about \p WasOn. String attributes are only
+  /// recorded with -assume-preserve-all, and only their kind is kept, not
+  /// their value.
+  void addAttribute(Attribute Attr, Value *WasOn) {
+    StringRef Name;
+    Value *AttrArg = nullptr;
+    if (Attr.isStringAttribute()) {
+      if (!ShouldPreserveAllAttributes)
+        return;
+      Name = Attr.getKindAsString();
+    } else {
+      Name = Attribute::getNameFromAttrKind(Attr.getKindAsEnum());
+    }
+    if (Attr.isIntAttribute())
+      AttrArg = ConstantInt::get(Type::getInt64Ty(M->getContext()),
+                                 Attr.getValueAsInt());
+    KnowledgeSet.insert({Name, AttrArg, WasOn});
+  }
+
+  /// Record the attributes of \p Call and, for direct calls, those declared on
+  /// the callee.
+  void addCall(const CallBase *Call) {
+    Function *Callee = Call->getCalledFunction();
+    auto addAttrList = [&](AttributeList AttrList) {
+      // Bound the walk by the call's own argument count so that we never ask
+      // for an operand the call does not have.
+      for (unsigned ArgNo = 0, E = Call->arg_size(); ArgNo < E; ++ArgNo)
+        for (Attribute Attr : AttrList.getParamAttributes(ArgNo))
+          addAttribute(Attr, Call->getArgOperand(ArgNo));
+      // Function attributes are recorded on the callee, which is only known
+      // for direct calls; skip them otherwise instead of using a null operand.
+      if (ShouldPreserveAllAttributes && Callee)
+        for (Attribute Attr : AttrList.getFnAttributes())
+          addAttribute(Attr, Callee);
+    };
+    addAttrList(Call->getAttributes());
+    if (Callee)
+      addAttrList(Callee->getAttributes());
+  }
+
+  /// Materialize the gathered knowledge as a call to llvm.assume(i1 true)
+  /// carrying one operand bundle per piece of knowledge. Returns null if
+  /// nothing was gathered. The call is not inserted anywhere.
+  Instruction *build() {
+    if (KnowledgeSet.empty())
+      return nullptr;
+    Function *FnAssume = Intrinsic::getDeclaration(M, Intrinsic::assume);
+    LLVMContext &C = M->getContext();
+    SmallVector<OperandBundleDef, 8> OpBundle;
+    for (const Knowledge &Elem : KnowledgeSet) {
+      SmallVector<Value *, 2> Args;
+      Args.push_back(Elem.WasOn);
+      if (Elem.Argument)
+        Args.push_back(Elem.Argument);
+      OpBundle.push_back(OperandBundleDefT<Value *>(Elem.Name.str(), Args));
+    }
+    return CallInst::Create(
+        FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}), OpBundle);
+  }
+
+  /// Gather knowledge from \p I. Only calls are handled so far.
+  void addInstruction(const Instruction *I) {
+    if (auto *Call = dyn_cast<CallBase>(I))
+      addCall(Call);
+    // TODO: Add support for the other Instruction.
+    // TODO: Maybe we should look around and merge with other llvm.assume.
+  }
+};
+
+} // namespace
+
+Instruction *llvm::BuildAssumeFromInst(const Instruction *I, Module *M) {
+  AssumeBuilderState Builder(M);
+  Builder.addInstruction(I);
+  return Builder.build();
+}
+
+PreservedAnalyses AssumeBuilderPass::run(Function &F,
+                                         FunctionAnalysisManager &AM) {
+  bool Changed = false;
+  for (Instruction &I : instructions(F))
+    if (Instruction *Assume = BuildAssumeFromInst(&I)) {
+      Assume->insertBefore(&I);
+      Changed = true;
+    }
+  if (!Changed)
+    return PreservedAnalyses::all();
+  // Only new instructions were inserted; no block or edge was changed.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
Index: llvm/lib/Transforms/Utils/CMakeLists.txt
===================================================================
--- llvm/lib/Transforms/Utils/CMakeLists.txt
+++ llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_component_library(LLVMTransformUtils
   ASanStackFrameLayout.cpp
   AddDiscriminators.cpp
+  AssumeBuilder.cpp
   BasicBlockUtils.cpp
   BreakCriticalEdges.cpp
   BuildLibCalls.cpp
Index: llvm/test/Transforms/Util/assume-builder.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/Util/assume-builder.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='assume-builder' -S %s | FileCheck %s --check-prefixes=BASIC
+; RUN: opt -passes='assume-builder' --assume-preserve-all -S %s | FileCheck %s --check-prefixes=ALL
+
+declare void @func(i32*, i32*)
+declare void @func_cold(i32*) cold
+declare void @func_strbool(i32*) "no-jump-tables"
+declare void @func_many(i32*) "no-jump-tables" nounwind "less-precise-fpmad" willreturn norecurse
+declare void @func_argattr(i32* align 8, i32* nonnull) nounwind
+
+define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {
+; BASIC-LABEL: @test(
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P:%.*]]), "dereferenceable"(i32* [[P]], i64 16) ]
+; BASIC-NEXT:    call void @func(i32* nonnull dereferenceable(16) [[P]], i32* null)
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1:%.*]], i64 12), "nonnull"(i32* [[P]]) ]
+; BASIC-NEXT:    call void @func(i32* dereferenceable(12) [[P1]], i32* nonnull [[P]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12) ]
+; BASIC-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]]) #0
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12) ]
+; BASIC-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]])
+; BASIC-NEXT:    call void @func(i32* [[P1]], i32* [[P]])
+; BASIC-NEXT:    call void @func_strbool(i32* [[P1]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P]], i64 16), "dereferenceable"(i32* [[P]], i64 8) ]
+; BASIC-NEXT:    call void @func(i32* dereferenceable(16) [[P]], i32* dereferenceable(8) [[P]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P1]], i64 8) ]
+; BASIC-NEXT:    call void @func_many(i32* align 8 [[P1]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P2:%.*]], i64 8), "nonnull"(i32* [[P3:%.*]]) ]
+; BASIC-NEXT:    call void @func_argattr(i32* [[P2]], i32* [[P3]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P1]]), "nonnull"(i32* [[P]]) ]
+; BASIC-NEXT:    call void @func(i32* nonnull [[P1]], i32* nonnull [[P]])
+; BASIC-NEXT:    ret void
+;
+; ALL-LABEL: @test(
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P:%.*]]), "dereferenceable"(i32* [[P]], i64 16) ]
+; ALL-NEXT:    call void @func(i32* nonnull dereferenceable(16) [[P]], i32* null)
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1:%.*]], i64 12), "nonnull"(i32* [[P]]) ]
+; ALL-NEXT:    call void @func(i32* dereferenceable(12) [[P1]], i32* nonnull [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12), "cold"(void (i32*)* @func_cold) ]
+; ALL-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]]) #0
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12), "cold"(void (i32*)* @func_cold) ]
+; ALL-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]])
+; ALL-NEXT:    call void @func(i32* [[P1]], i32* [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "no-jump-tables"(void (i32*)* @func_strbool) ]
+; ALL-NEXT:    call void @func_strbool(i32* [[P1]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P]], i64 16), "dereferenceable"(i32* [[P]], i64 8) ]
+; ALL-NEXT:    call void @func(i32* dereferenceable(16) [[P]], i32* dereferenceable(8) [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P1]], i64 8), "norecurse"(void (i32*)* @func_many), "nounwind"(void (i32*)* @func_many), "willreturn"(void (i32*)* @func_many), "less-precise-fpmad"(void (i32*)* @func_many), "no-jump-tables"(void (i32*)* @func_many) ]
+; ALL-NEXT:    call void @func_many(i32* align 8 [[P1]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P2:%.*]], i64 8), "nonnull"(i32* [[P3:%.*]]), "nounwind"(void (i32*, i32*)* @func_argattr) ]
+; ALL-NEXT:    call void @func_argattr(i32* [[P2]], i32* [[P3]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P1]]), "nonnull"(i32* [[P]]) ]
+; ALL-NEXT:    call void @func(i32* nonnull [[P1]], i32* nonnull [[P]])
+; ALL-NEXT:    ret void
+;
+  call void @func(i32* nonnull dereferenceable(16) %P, i32* null)
+  call void @func(i32* dereferenceable(12) %P1, i32* nonnull %P)
+  call void @func_cold(i32* dereferenceable(12) %P1) cold
+  call void @func_cold(i32* dereferenceable(12) %P1)
+  call void @func(i32* %P1, i32* %P)
+  call void @func_strbool(i32* %P1)
+  call void @func(i32* dereferenceable(16) %P, i32* dereferenceable(8) %P)
+  call void @func_many(i32* align 8 %P1)
+  call void @func_argattr(i32* %P2, i32* %P3)
+  call void @func(i32* nonnull %P1, i32* nonnull %P)
+  ret void
+}