diff --git a/llvm/include/llvm/Transforms/IPO/ImplementsAttrResolver.h b/llvm/include/llvm/Transforms/IPO/ImplementsAttrResolver.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/IPO/ImplementsAttrResolver.h
@@ -0,0 +1,33 @@
+//===- ImplementsAttrResolver.h - Repl. specification with impls. - C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the ImplementsAttrResolverPass module pass. For every
+// function that carries an "implements"="<name>" function attribute, the pass
+// redirects all uses of the function called <name> (the "specification") to
+// the attributed function (the "implementation").
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_IPO_IMPLEMENTSATTRRESOLVER_H
+#define LLVM_TRANSFORMS_IPO_IMPLEMENTSATTRRESOLVER_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// Module pass that redirects all uses of a "specification" function to the
+/// function whose "implements" attribute names it.
+struct ImplementsAttrResolverPass
+    : public PassInfoMixin<ImplementsAttrResolverPass> {
+  /// Runs the resolver over \p M and reports which analyses remain valid.
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_IPO_IMPLEMENTSATTRRESOLVER_H
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -106,6 +106,7 @@
 #include "llvm/Transforms/IPO/InferFunctionAttrs.h"
 #include "llvm/Transforms/IPO/Inliner.h"
 #include "llvm/Transforms/IPO/Internalize.h"
+#include "llvm/Transforms/IPO/ImplementsAttrResolver.h"
 #include "llvm/Transforms/IPO/LoopExtractor.h"
 #include "llvm/Transforms/IPO/LowerTypeTests.h"
 #include "llvm/Transforms/IPO/MergeFunctions.h"
@@ -278,6 +279,10 @@
     "enable-npm-O3-nontrivial-unswitch", cl::init(true), cl::Hidden,
     cl::ZeroOrMore, cl::desc("Enable non-trivial loop unswitching for -O3"));
 
+static cl::opt<bool> EnableImplementsAttrResolver(
+    "enable-implements-attr-resolver", cl::init(false), cl::Hidden,
+    cl::ZeroOrMore, cl::desc("Enable the implements attribute resolver"));
+
PipelineTuningOptions::PipelineTuningOptions() { LoopInterleaving = true; LoopVectorization = true; @@ -1131,6 +1136,12 @@ // globals. MPM.addPass(DeadArgumentEliminationPass()); + // Replace calls to function specifications with their "implementation". + // See the `implements` and `specification` clang attributes and the + // `implements` LLVM-IR attribute. + if (EnableImplementsAttrResolver) + MPM.addPass(ImplementsAttrResolverPass()); + // Create a small function pass pipeline to cleanup after all the global // optimizations. FunctionPassManager GlobalCleanupPM(DebugLogging); @@ -1659,6 +1670,12 @@ // is fixed. MPM.addPass(WholeProgramDevirtPass(ExportSummary, nullptr)); + // Replace calls to function specifications with their "implementation". + // See the `implements` and `specification` clang attributes and the + // `implements` LLVM-IR attribute. + if (EnableImplementsAttrResolver) + MPM.addPass(ImplementsAttrResolverPass()); + // Stop here at -O1. if (Level == OptimizationLevel::O1) { // The LowerTypeTestsPass needs to run to lower type metadata and the @@ -1892,6 +1909,12 @@ MPM.addPass(AlwaysInlinerPass( /*InsertLifetimeIntrinsics=*/PTO.Coroutines)); + // Replace calls to function specifications with their "implementation". + // See the `implements` and `specification` clang attributes and the + // `implements` LLVM-IR attribute. 
+  if (EnableImplementsAttrResolver)
+    MPM.addPass(ImplementsAttrResolverPass());
+
   if (PTO.MergeFunctions)
     MPM.addPass(MergeFunctionsPass());
 
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -71,6 +71,7 @@
 MODULE_PASS("instrprof", InstrProfiling())
 MODULE_PASS("internalize", InternalizePass())
 MODULE_PASS("invalidate", InvalidateAllAnalysesPass())
+MODULE_PASS("implements-attr-resolver", ImplementsAttrResolverPass())
 MODULE_PASS("ipsccp", IPSCCPPass())
 MODULE_PASS("iroutliner", IROutlinerPass())
 MODULE_PASS("print-ir-similarity", IRSimilarityAnalysisPrinterPass(dbgs()))
diff --git a/llvm/lib/Transforms/IPO/CMakeLists.txt b/llvm/lib/Transforms/IPO/CMakeLists.txt
--- a/llvm/lib/Transforms/IPO/CMakeLists.txt
+++ b/llvm/lib/Transforms/IPO/CMakeLists.txt
@@ -25,6 +25,7 @@
   InlineSimple.cpp
   Inliner.cpp
   Internalize.cpp
+  ImplementsAttrResolver.cpp
   LoopExtractor.cpp
   LowerTypeTests.cpp
   MergeFunctions.cpp
diff --git a/llvm/lib/Transforms/IPO/ImplementsAttrResolver.cpp b/llvm/lib/Transforms/IPO/ImplementsAttrResolver.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/IPO/ImplementsAttrResolver.cpp
@@ -0,0 +1,73 @@
+//===- ImplementsAttrResolver.cpp - Repl. specs with implementations ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements ImplementsAttrResolverPass. For every function in the
+// module that carries an "implements"="<name>" function attribute, all uses of
+// the function called <name> (the "specification") are replaced with the
+// attributed function (the "implementation").
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/ImplementsAttrResolver.h"
+
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassInstrumentation.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+
+#define DEBUG_TYPE "implements-attr-resolver"
+
+using namespace llvm;
+
+/// Redirect every use of a specification function to its implementation.
+///
+/// A function is an implementation when it carries the string function
+/// attribute "implements"; the attribute value names the specification, which
+/// is looked up in \p M.
+///
+/// \returns PreservedAnalyses::none() if any use was rewritten, otherwise
+/// PreservedAnalyses::all().
+PreservedAnalyses ImplementsAttrResolverPass::run(Module &M,
+                                                  ModuleAnalysisManager &) {
+  bool Changed = false;
+  for (Function &Impl : M) {
+    const Attribute A = Impl.getFnAttribute("implements");
+    if (!A.isValid())
+      continue;
+
+    const StringRef SpecificationName = A.getValueAsString();
+    Function *Specification = M.getFunction(SpecificationName);
+    if (!Specification) {
+      LLVM_DEBUG(dbgs() << "Found implementation '" << Impl.getName()
+                        << "' but no matching specification with name '"
+                        << SpecificationName
+                        << "', potentially inlined and/or eliminated.\n");
+      continue;
+    }
+
+    // Replacing a value with itself is invalid, and a specification without
+    // uses leaves nothing to rewrite; skip both so the module is only reported
+    // as changed when a use was actually redirected.
+    if (Specification == &Impl || Specification->use_empty())
+      continue;
+
+    LLVM_DEBUG(dbgs() << "Replace specification '" << Specification->getName()
+                      << "' with implementation '" << Impl.getName() << "'\n");
+    // NOTE(review): the bitcast hides a signature mismatch between the two
+    // functions instead of rejecting it -- confirm this is intended.
+    Specification->replaceAllUsesWith(
+        ConstantExpr::getBitCast(&Impl, Specification->getType()));
+    Changed = true;
+  }
+
+  // The module was mutated if any use was rewritten, so analyses computed
+  // before this pass must not be reported as preserved in that case.
+  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
+}
diff --git a/llvm/test/Transforms/ImplementsAttrResolver/intrinsic_implementation.ll b/llvm/test/Transforms/ImplementsAttrResolver/intrinsic_implementation.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/ImplementsAttrResolver/intrinsic_implementation.ll
@@ -0,0 +1,38 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=implements-attr-resolver -S | FileCheck %s
+
+@llvm.compiler.used = appending global [2 x i8*] [i8* bitcast (float (float)* @my_sin to i8*), i8* bitcast (float (float)* @my_cos_type_mismatch to i8*)], section "llvm.metadata"
+
+define internal float @my_sin(float %d) "implements"="llvm.sin.f32" {
+;
CHECK-LABEL: @my_sin( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float [[D:%.*]] +; +entry: + ret float %d +} + +define internal float @my_cos_type_mismatch(float %d) "implements"="llvm.cos.f64" { +; CHECK-LABEL: @my_cos_type_mismatch( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret float [[D:%.*]] +; +entry: + ret float %d +} + +define double @foo(double %d, float %f) { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call fast float @my_sin(float [[F:%.*]]) +; CHECK-NEXT: [[TMP1:%.*]] = tail call fast double bitcast (float (float)* @my_cos_type_mismatch to double (double)*)(double [[D:%.*]]) +; CHECK-NEXT: ret double [[TMP1]] +; +entry: + %0 = tail call fast float @llvm.sin.f32(float %f) + %1 = tail call fast double @llvm.cos.f64(double %d) + ret double %1 +} + +declare float @llvm.sin.f32(float) +declare double @llvm.cos.f64(double)