Index: lib/Target/X86/CMakeLists.txt
===================================================================
--- lib/Target/X86/CMakeLists.txt
+++ lib/Target/X86/CMakeLists.txt
@@ -59,6 +59,7 @@
   X86TargetObjectFile.cpp
   X86TargetTransformInfo.cpp
   X86VZeroUpper.cpp
+  X86VectorWidthInfer.cpp
   X86WinAllocaExpander.cpp
   X86WinEHState.cpp
   )
Index: lib/Target/X86/X86.h
===================================================================
--- lib/Target/X86/X86.h
+++ lib/Target/X86/X86.h
@@ -119,6 +119,12 @@
 
 void initializeEvexToVexInstPassPass(PassRegistry &);
 
+/// This pass infers the vector width a function requires and records it in
+/// the require-vector-width function attribute.
+FunctionPass *createX86VectorWidthInferPass();
+
+void initializeX86VectorWidthInferPass(PassRegistry &);
+
 } // End llvm namespace
 
 #endif
Index: lib/Target/X86/X86TargetMachine.cpp
===================================================================
--- lib/Target/X86/X86TargetMachine.cpp
+++ lib/Target/X86/X86TargetMachine.cpp
@@ -80,6 +80,7 @@
   initializeX86CmovConverterPassPass(PR);
   initializeX86ExecutionDomainFixPass(PR);
   initializeX86DomainReassignmentPass(PR);
+  initializeX86VectorWidthInferPass(PR);
 }
 
 static std::unique_ptr<TargetLoweringObjectFile> createTLOF(const Triple &TT) {
@@ -436,6 +437,8 @@
 }
 
 bool X86PassConfig::addPreISel() {
+  addPass(createX86VectorWidthInferPass());
+
   // Only add this pass for 32-bit x86 Windows.
   const Triple &TT = TM->getTargetTriple();
   if (TT.isOSWindows() && TT.getArch() == Triple::x86)
Index: lib/Target/X86/X86VectorWidthInfer.cpp
===================================================================
--- /dev/null
+++ lib/Target/X86/X86VectorWidthInfer.cpp
@@ -0,0 +1,127 @@
+//===- X86VectorWidthInfer.cpp - Infer require-vector-width attribute -----===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This pass infers the vector width a function requires and records it in
+/// the "require-vector-width" function attribute. The width is derived from
+/// vector types in the function's signature and in calls that must match the
+/// ABI, including X86-specific intrinsics. Target-independent intrinsics are
+/// ignored since type legalization can split them. An existing attribute
+/// value above 256 is left untouched.
+//===----------------------------------------------------------------------===//
+
+#include "X86TargetMachine.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Pass.h"
+#include <algorithm>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "x86-vector-width-infer"
+
+namespace {
+
+class X86VectorWidthInfer : public FunctionPass {
+public:
+  static char ID; // Pass ID
+
+  X86VectorWidthInfer() : FunctionPass(ID) {
+    initializeX86VectorWidthInferPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetPassConfig>();
+    AU.setPreservesCFG();
+  }
+
+  bool runOnFunction(Function &F) override;
+};
+
+} // end anonymous namespace
+
+char X86VectorWidthInfer::ID = 0;
+
+INITIALIZE_PASS_BEGIN(X86VectorWidthInfer, DEBUG_TYPE,
+                      "X86 Vector Width Infer", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(X86VectorWidthInfer, DEBUG_TYPE,
+                    "X86 Vector Width Infer", false, false)
+
+FunctionPass *llvm::createX86VectorWidthInferPass() {
+  return new X86VectorWidthInfer();
+}
+
+/// Raise \p Width to the size in bits of \p Ty if \p Ty is a vector type.
+static void accumulateVectorWidth(unsigned &Width, Type *Ty) {
+  if (Ty->isVectorTy())
+    Width = std::max(Width, Ty->getPrimitiveSizeInBits());
+}
+
+bool X86VectorWidthInfer::runOnFunction(Function &F) {
+  TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
+  const X86Subtarget *ST =
+      TPC.getTM<X86TargetMachine>().getSubtargetImpl(F);
+
+  // Nothing to do unless the subtarget has AVX512VL and prefers vectors
+  // narrower than 512 bits.
+  // TODO: Support this for 256 vs 128 as well?
+  if (!ST->hasVLX() || ST->getPreferVectorWidth() >= 512)
+    return false;
+
+  unsigned RequiredWidth = 0;
+
+  // An existing attribute above 256 bits already requests 512-bit vectors,
+  // so there is nothing to infer. Otherwise start from its value so the
+  // result never drops below it.
+  StringRef OldWidth;
+  if (F.hasFnAttribute("require-vector-width")) {
+    OldWidth = F.getFnAttribute("require-vector-width").getValueAsString();
+    unsigned Width;
+    if (!OldWidth.getAsInteger(0, Width)) {
+      if (Width > 256)
+        return false;
+      RequiredWidth = Width;
+    }
+  }
+
+  // Vector types in the signature must be passed in full-width registers.
+  accumulateVectorWidth(RequiredWidth, F.getReturnType());
+  for (const Argument &A : F.args())
+    accumulateVectorWidth(RequiredWidth, A.getType());
+
+  // Scan for calls that need wide registers to match the ABI; this also
+  // covers target-specific intrinsics.
+  // TODO: Call sites reached through InvokeInst are not inspected.
+  for (BasicBlock &BB : F) {
+    for (Instruction &I : BB) {
+      auto *CI = dyn_cast<CallInst>(&I);
+      if (!CI)
+        continue;
+      // Target-independent intrinsics are handled by type legalization, so
+      // only X86-specific intrinsics constrain the width.
+      if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
+        if (!II->getCalledFunction()->getName().startswith("llvm.x86."))
+          continue;
+      }
+      accumulateVectorWidth(RequiredWidth, CI->getType());
+      for (Value *Op : CI->arg_operands())
+        accumulateVectorWidth(RequiredWidth, Op->getType());
+    }
+  }
+
+  // Record the result; only 256 vs 512 is distinguished for now.
+  // TODO: This should be more generic.
+  StringRef NewWidth = RequiredWidth > 256 ? "512" : "256";
+  if (OldWidth == NewWidth)
+    return false; // Attribute already has the inferred value.
+
+  F.removeFnAttr("require-vector-width");
+  F.addFnAttr("require-vector-width", NewWidth);
+  return true; // Function attributes were modified.
+}