diff --git a/llvm/include/llvm/IR/Attributes.h b/llvm/include/llvm/IR/Attributes.h
--- a/llvm/include/llvm/IR/Attributes.h
+++ b/llvm/include/llvm/IR/Attributes.h
@@ -107,6 +107,13 @@
                                         const Optional<unsigned> &NumElemsArg);
   static Attribute getWithByValType(LLVMContext &Context, Type *Ty);
 
+  /// Return the attribute kind whose textual name is \p AttrName, or
+  /// Attribute::None if no attribute kind has that name.
+  static Attribute::AttrKind getAttrKindFromName(StringRef AttrName);
+
+  /// Return the textual name of \p AttrKind ("none" for Attribute::None).
+  static StringRef getNameFromAttrKind(Attribute::AttrKind AttrKind);
+
   //===--------------------------------------------------------------------===//
   // Attribute Accessors
   //===--------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/KnowledgeRetention.h
@@ -0,0 +1,41 @@
+//===- KnowledgeRetention.h - Utilities to preserve information -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tools to preserve information. They should be used before
+// performing a transformation that moves or deletes instructions, as such
+// transformations may remove or worsen information that can be derived from
+// the IR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_UTILS_KNOWLEDGERETENTION_H
+#define LLVM_TRANSFORMS_UTILS_KNOWLEDGERETENTION_H
+
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// Build a call to llvm.assume to preserve information that can be derived
+/// from the given instruction.
+/// If no information can be derived from \p I, this call returns null.
+/// The returned instruction is not inserted anywhere.
+CallInst *BuildAssumeFromInst(const Instruction *I, Module *M);
+inline CallInst *BuildAssumeFromInst(Instruction *I) {
+  return BuildAssumeFromInst(I, I->getModule());
+}
+
+/// This pass will try to build an llvm.assume for every instruction in the
+/// function. Its main purpose is testing.
+struct AssumeBuilderPass : public PassInfoMixin<AssumeBuilderPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp
--- a/llvm/lib/IR/Attributes.cpp
+++ b/llvm/lib/IR/Attributes.cpp
@@ -176,6 +176,32 @@
   return get(Context, AllocSize, packAllocSizeArgs(ElemSizeArg, NumElemsArg));
 }
 
+// Map a textual attribute name to its kind using the name table generated in
+// Attributes.inc; names that are not in the table map to Attribute::None.
+Attribute::AttrKind Attribute::getAttrKindFromName(StringRef AttrName) {
+  return StringSwitch<Attribute::AttrKind>(AttrName)
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
+#include "llvm/IR/Attributes.inc"
+      .Default(Attribute::None);
+}
+
+// Map an attribute kind back to its textual name; Attribute::None is "none".
+StringRef Attribute::getNameFromAttrKind(Attribute::AttrKind Kind) {
+  switch (Kind) {
+#define GET_ATTR_NAMES
+#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
+  case Attribute::ENUM_NAME:                                                   \
+    return #DISPLAY_NAME;
+#include "llvm/IR/Attributes.inc"
+  case Attribute::None:
+    return "none";
+  default:
+    llvm_unreachable("invalid Kind");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Attribute Accessor Methods
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -127,17 +127,8 @@
   return LLVMGetMDKindIDInContext(LLVMGetGlobalContext(), Name, SLen);
 }
 
-static Attribute::AttrKind getAttrKindFromName(StringRef AttrName) {
-  return StringSwitch<Attribute::AttrKind>(AttrName)
-#define GET_ATTR_NAMES
-#define ATTRIBUTE_ENUM(ENUM_NAME, DISPLAY_NAME)                                \
-  .Case(#DISPLAY_NAME, Attribute::ENUM_NAME)
-#include "llvm/IR/Attributes.inc"
-
.Default(Attribute::None);
-}
-
 unsigned LLVMGetEnumAttributeKindForName(const char *Name, size_t SLen) {
-  return getAttrKindFromName(StringRef(Name, SLen));
+  return Attribute::getAttrKindFromName(StringRef(Name, SLen));
 }
 
 unsigned LLVMGetLastEnumAttributeKind(void) {
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -173,6 +173,7 @@
 #include "llvm/Transforms/Utils/CanonicalizeAliases.h"
 #include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
 #include "llvm/Transforms/Utils/InjectTLIMappings.h"
+#include "llvm/Transforms/Utils/KnowledgeRetention.h"
 #include "llvm/Transforms/Utils/LCSSA.h"
 #include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
 #include "llvm/Transforms/Utils/LoopSimplify.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -160,6 +160,7 @@
 FUNCTION_PASS("adce", ADCEPass())
 FUNCTION_PASS("add-discriminators", AddDiscriminatorsPass())
 FUNCTION_PASS("aggressive-instcombine", AggressiveInstCombinePass())
 FUNCTION_PASS("alignment-from-assumptions", AlignmentFromAssumptionsPass())
+FUNCTION_PASS("assume-builder", AssumeBuilderPass())
 FUNCTION_PASS("bdce", BDCEPass())
 FUNCTION_PASS("bounds-checking", BoundsCheckingPass())
diff --git a/llvm/lib/Transforms/Utils/CMakeLists.txt b/llvm/lib/Transforms/Utils/CMakeLists.txt
--- a/llvm/lib/Transforms/Utils/CMakeLists.txt
+++ b/llvm/lib/Transforms/Utils/CMakeLists.txt
@@ -28,6 +28,7 @@
   InjectTLIMappings.cpp
   InstructionNamer.cpp
   IntegerDivision.cpp
+  KnowledgeRetention.cpp
   LCSSA.cpp
   LibCallsShrinkWrap.cpp
   Local.cpp
diff --git a/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Utils/KnowledgeRetention.cpp
@@ -0,0 +1,186 @@
+//===- KnowledgeRetention.cpp - Utilities to preserve information ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/KnowledgeRetention.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+namespace {
+
+cl::opt<bool> ShouldPreserveAllAttributes(
+    "assume-preserve-all", cl::init(false), cl::Hidden,
+    cl::desc("enable preservation of all attributes, even those that are "
+             "unlikely to be useful"));
+
+/// One piece of knowledge that becomes an operand bundle of the llvm.assume.
+struct AssumedKnowledge {
+  /// Name of the attribute, used as the tag of the operand bundle.
+  StringRef Name;
+  /// Integer argument of the attribute (e.g. N in dereferenceable(N)), or null
+  /// if the attribute has none.
+  Value *Argument;
+  enum {
+    None,
+    Empty,
+    Tombstone,
+  };
+  /// The value the attribute applies to (null for function attributes), and a
+  /// flag that is only used to build the DenseMapInfo empty/tombstone keys.
+  llvm::PointerIntPair<Value *, 2> WasOn;
+};
+
+} // namespace
+
+namespace llvm {
+
+template <> struct DenseMapInfo<AssumedKnowledge> {
+  static AssumedKnowledge getEmptyKey() {
+    return {StringRef(), nullptr, {nullptr, AssumedKnowledge::Empty}};
+  }
+  static AssumedKnowledge getTombstoneKey() {
+    return {StringRef(), nullptr, {nullptr, AssumedKnowledge::Tombstone}};
+  }
+  static unsigned getHashValue(const AssumedKnowledge &AK) {
+    return hash_combine(AK.Name, AK.Argument, AK.WasOn.getPointer());
+  }
+  static bool isEqual(const AssumedKnowledge &LHS,
+                      const AssumedKnowledge &RHS) {
+    return LHS.WasOn == RHS.WasOn && LHS.Name == RHS.Name &&
+           LHS.Argument == RHS.Argument;
+  }
+};
+
+} // namespace llvm
+
+namespace {
+
+/// Compare OperandBundleDefs to order the bundles of the built llvm.assume.
+/// The ordering is:
+/// - by the name of the attribute, (doesn't change)
+/// - then by the Value of the argument, (doesn't change)
+/// - lastly by the Name of the current Value it WasOn. (may change)
+/// This allows looking for the right kind of attribute with binary search.
+/// However finding the right WasOn needs to be done via linear search because
+/// values can get replaced. Bundles whose keys compare equal (e.g. two WasOn
+/// values without a name) are not ordered by this function, so callers must
+/// use a stable sort on deterministically ordered input.
+bool isLowerOpBundle(const OperandBundleDef &LHS, const OperandBundleDef &RHS) {
+  auto getTuple = [](const OperandBundleDef &Op) {
+    return std::make_tuple(
+        Op.getTag(),
+        Op.input_size() < 2
+            ? 0
+            : cast<ConstantInt>(*std::next(Op.input_begin()))->getZExtValue(),
+        Op.input_size() < 1 ? StringRef("") : (*Op.input_begin())->getName());
+  };
+  return getTuple(LHS) < getTuple(RHS);
+}
+
+/// This class contains all knowledge that has been gathered while building an
+/// llvm.assume and the functions to manipulate it.
+struct AssumeBuilderState {
+  Module *M;
+
+  /// Knowledge gathered so far. Iteration follows insertion order, which keeps
+  /// the order of the emitted operand bundles deterministic.
+  SmallSetVector<AssumedKnowledge, 8> AssumedKnowledgeSet;
+
+  AssumeBuilderState(Module *M) : M(M) {}
+
+  /// Record \p Attr as knowledge about \p WasOn (null for function attributes).
+  void addAttribute(Attribute Attr, Value *WasOn) {
+    // String attributes are only kept when explicitly requested.
+    if (Attr.isStringAttribute() && !ShouldPreserveAllAttributes)
+      return;
+    StringRef Name = Attr.isStringAttribute()
+                         ? Attr.getKindAsString()
+                         : Attribute::getNameFromAttrKind(Attr.getKindAsEnum());
+    Value *AttrArg = nullptr;
+    if (Attr.isIntAttribute())
+      AttrArg = ConstantInt::get(Type::getInt64Ty(M->getContext()),
+                                 Attr.getValueAsInt());
+    AssumedKnowledgeSet.insert(
+        {Name, AttrArg, {WasOn, AssumedKnowledge::None}});
+  }
+
+  /// Record the argument attributes of \p Call and of its callee (if known),
+  /// plus their function attributes when -assume-preserve-all is set.
+  void addCall(const CallBase *Call) {
+    auto addAttrList = [&](AttributeList AttrList) {
+      for (unsigned Idx = AttributeList::FirstArgIndex;
+           Idx < AttrList.getNumAttrSets(); Idx++)
+        for (Attribute Attr : AttrList.getAttributes(Idx))
+          addAttribute(Attr, Call->getArgOperand(Idx - 1));
+      if (ShouldPreserveAllAttributes)
+        for (Attribute Attr : AttrList.getFnAttributes())
+          addAttribute(Attr, nullptr);
+    };
+    addAttrList(Call->getAttributes());
+    if (Function *Fn = Call->getCalledFunction())
+      addAttrList(Fn->getAttributes());
+  }
+
+  /// Build the llvm.assume call, or return null if nothing was gathered.
+  CallInst *build() {
+    if (AssumedKnowledgeSet.empty())
+      return nullptr;
+    Function *FnAssume = Intrinsic::getDeclaration(M, Intrinsic::assume);
+    LLVMContext &C = M->getContext();
+    SmallVector<OperandBundleDef, 8> OpBundle;
+    for (const AssumedKnowledge &Elem : AssumedKnowledgeSet) {
+      SmallVector<Value *, 2> Args;
+      if (Elem.WasOn.getPointer())
+        Args.push_back(Elem.WasOn.getPointer());
+      if (Elem.Argument)
+        Args.push_back(Elem.Argument);
+      OpBundle.push_back(OperandBundleDef(Elem.Name.str(), Args));
+    }
+    // A stable sort keeps ties in the (deterministic) insertion order.
+    llvm::stable_sort(OpBundle, isLowerOpBundle);
+    return CallInst::Create(
+        FnAssume, ArrayRef<Value *>({ConstantInt::getTrue(C)}), OpBundle);
+  }
+
+  void addInstruction(const Instruction *I) {
+    if (auto *Call = dyn_cast<CallBase>(I))
+      addCall(Call);
+    // TODO: Add support for the other Instructions.
+    // TODO: Maybe we should look around and merge with other llvm.assume.
+  }
+};
+
+} // namespace
+
+CallInst *llvm::BuildAssumeFromInst(const Instruction *I, Module *M) {
+  AssumeBuilderState Builder(M);
+  Builder.addInstruction(I);
+  return Builder.build();
+}
+
+PreservedAnalyses AssumeBuilderPass::run(Function &F,
+                                         FunctionAnalysisManager &AM) {
+  bool Changed = false;
+  for (Instruction &I : instructions(F))
+    if (Instruction *Assume = BuildAssumeFromInst(&I)) {
+      Assume->insertBefore(&I);
+      Changed = true;
+    }
+  if (!Changed)
+    return PreservedAnalyses::all();
+  // New instructions were inserted, so not every analysis is preserved; the
+  // CFG itself is left untouched.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
diff --git a/llvm/test/Transforms/Util/assume-builder.ll b/llvm/test/Transforms/Util/assume-builder.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/Util/assume-builder.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -passes='assume-builder' -S %s | FileCheck %s --check-prefixes=BASIC
+; RUN: opt -passes='assume-builder' --assume-preserve-all -S %s | FileCheck %s --check-prefixes=ALL
+
+declare void @func(i32*, i32*)
+declare void @func_cold(i32*) cold
+declare void @func_strbool(i32*) "no-jump-tables"
+declare void @func_many(i32*) "no-jump-tables" nounwind "less-precise-fpmad" willreturn norecurse
+declare void @func_argattr(i32* align 8, i32* nonnull) nounwind
+
+define void @test(i32* %P, i32* %P1, i32* %P2, i32* %P3) {
+; BASIC-LABEL: @test(
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P:%.*]], i64 16), "nonnull"(i32*
[[P]]) ]
+; BASIC-NEXT:    call void @func(i32* nonnull dereferenceable(16) [[P]], i32* null)
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1:%.*]], i64 12), "nonnull"(i32* [[P]]) ]
+; BASIC-NEXT:    call void @func(i32* dereferenceable(12) [[P1]], i32* nonnull [[P]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12) ]
+; BASIC-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]]) #0
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1]], i64 12) ]
+; BASIC-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]])
+; BASIC-NEXT:    call void @func(i32* [[P1]], i32* [[P]])
+; BASIC-NEXT:    call void @func_strbool(i32* [[P1]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P]], i64 8), "dereferenceable"(i32* [[P]], i64 16) ]
+; BASIC-NEXT:    call void @func(i32* dereferenceable(16) [[P]], i32* dereferenceable(8) [[P]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P1]], i64 8) ]
+; BASIC-NEXT:    call void @func_many(i32* align 8 [[P1]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P2:%.*]], i64 8), "nonnull"(i32* [[P3:%.*]]) ]
+; BASIC-NEXT:    call void @func_argattr(i32* [[P2]], i32* [[P3]])
+; BASIC-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P]]), "nonnull"(i32* [[P1]]) ]
+; BASIC-NEXT:    call void @func(i32* nonnull [[P1]], i32* nonnull [[P]])
+; BASIC-NEXT:    ret void
+;
+; ALL-LABEL: @test(
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P:%.*]], i64 16), "nonnull"(i32* [[P]]) ]
+; ALL-NEXT:    call void @func(i32* nonnull dereferenceable(16) [[P]], i32* null)
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P1:%.*]], i64 12), "nonnull"(i32* [[P]]) ]
+; ALL-NEXT:    call void @func(i32* dereferenceable(12) [[P1]], i32* nonnull [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "cold"(), "dereferenceable"(i32* [[P1]], i64 12) ]
+; ALL-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]]) #0
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "cold"(), "dereferenceable"(i32* [[P1]], i64 12) ]
+; ALL-NEXT:    call void @func_cold(i32* dereferenceable(12) [[P1]])
+; ALL-NEXT:    call void @func(i32* [[P1]], i32* [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "no-jump-tables"() ]
+; ALL-NEXT:    call void @func_strbool(i32* [[P1]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "dereferenceable"(i32* [[P]], i64 8), "dereferenceable"(i32* [[P]], i64 16) ]
+; ALL-NEXT:    call void @func(i32* dereferenceable(16) [[P]], i32* dereferenceable(8) [[P]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P1]], i64 8), "less-precise-fpmad"(), "no-jump-tables"(), "norecurse"(), "nounwind"(), "willreturn"() ]
+; ALL-NEXT:    call void @func_many(i32* align 8 [[P1]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "align"(i32* [[P2:%.*]], i64 8), "nonnull"(i32* [[P3:%.*]]), "nounwind"() ]
+; ALL-NEXT:    call void @func_argattr(i32* [[P2]], i32* [[P3]])
+; ALL-NEXT:    call void @llvm.assume(i1 true) [ "nonnull"(i32* [[P]]), "nonnull"(i32* [[P1]]) ]
+; ALL-NEXT:    call void @func(i32* nonnull [[P1]], i32* nonnull [[P]])
+; ALL-NEXT:    ret void
+;
+  call void @func(i32* nonnull dereferenceable(16) %P, i32* null) ; call-site argument attributes become bundles
+  call void @func(i32* dereferenceable(12) %P1, i32* nonnull %P) ; attributes on two different arguments
+  call void @func_cold(i32* dereferenceable(12) %P1) cold ; cold on both call site and callee gives one "cold" bundle
+  call void @func_cold(i32* dereferenceable(12) %P1) ; cold only on the callee declaration
+  call void @func(i32* %P1, i32* %P) ; no attributes, so no llvm.assume is built
+  call void @func_strbool(i32* %P1) ; string function attribute, kept only with --assume-preserve-all
+  call void @func(i32* dereferenceable(16) %P, i32* dereferenceable(8) %P) ; same value with two dereferenceable sizes, both kept
+  call void @func_many(i32* align 8 %P1) ; mix of enum and string function attributes on the callee
+  call void @func_argattr(i32* %P2, i32* %P3) ; argument attributes taken from the callee declaration
+  call void @func(i32* nonnull %P1, i32* nonnull %P) ; same attribute on two values, bundles ordered by value name
+  ret void
+}