Index: lib/Target/NVPTX/NVPTXTargetTransformInfo.h =================================================================== --- lib/Target/NVPTX/NVPTXTargetTransformInfo.h +++ lib/Target/NVPTX/NVPTXTargetTransformInfo.h @@ -56,6 +56,13 @@ // calls are particularly expensive in NVPTX. unsigned getInliningThresholdMultiplier() { return 5; } + unsigned getNumberOfRegisters(bool /*Vector*/) const { return 1; } + unsigned getRegisterBitWidth(bool Vector) const { return Vector ? 128 : 64; } + + int getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, + unsigned AddressSpace); + int getVectorInstrCost(unsigned Opcode, Type *Ty, unsigned Index); + int getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue, Index: lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp =================================================================== --- lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -118,6 +118,43 @@ } } +int NVPTXTTIImpl::getMemoryOpCost(unsigned Opcode, Type *Src, + unsigned Alignment, unsigned AddressSpace) { + int Cost = BaseT::getMemoryOpCost(Opcode, Src, Alignment, AddressSpace); + + // Model vector loads and stores (of vector types that ptx supports) as half + // the cost of the corresponding set of scalar loads and stores. This is a + // bit optimistic, but it encourages the SLP optimizer to use vectorized loads + // and stores, which we want. + // + // FIXME: We ignore the Alignment arg, even though PTX can only handle vector + // loads/stores that are aligned to the vector's width, because the SLP + // vectorizer queries us with an alignment of 1. 
+ if (Src->isVectorTy()) { + int N = Src->getVectorNumElements(); + int SZ = Src->getScalarSizeInBits(); + if ((SZ <= 64 && N == 2) || (SZ <= 32 && N == 4)) { + return Cost / 2; + } + } + return Cost; +} + +int NVPTXTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Ty, + unsigned Index) { + switch (Opcode) { + case Instruction::InsertElement: + case Instruction::ExtractElement: + // Model vector insertions and extractions as free. PTX only supports + // vector loads and stores, and in those you can specify a list of + // general-purpose registers, {a, b, c, d}. So vector + // insertions/extractions get optimized away when we lower to PTX. + return 0; + default: + return BaseT::getVectorInstrCost(Opcode, Ty, Index); + } +} + void NVPTXTTIImpl::getUnrollingPreferences(Loop *L, TTI::UnrollingPreferences &UP) { BaseT::getUnrollingPreferences(L, UP); Index: lib/Transforms/Vectorize/SLPVectorizer.cpp =================================================================== --- lib/Transforms/Vectorize/SLPVectorizer.cpp +++ lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -1797,8 +1797,8 @@ } bool BoUpSLP::isFullyVectorizableTinyTree() { - DEBUG(dbgs() << "SLP: Check whether the tree with height " << - VectorizableTree.size() << " is fully vectorizable .\n"); + DEBUG(dbgs() << "SLP: Check whether the tree with height " + << VectorizableTree.size() << " is fully vectorizable.\n"); // We only handle trees of height 2. if (VectorizableTree.size() != 2) @@ -1810,9 +1810,10 @@ isSplat(VectorizableTree[1].Scalars))) return true; - // Gathering cost would be too much for tiny trees. - if (VectorizableTree[0].NeedToGather || VectorizableTree[1].NeedToGather) - return false; + // Gathering cost would be too much for tiny trees, unless gathers are free. 
+ for (const TreeEntry &TE : VectorizableTree) + if (TE.NeedToGather && getGatherCost(TE.Scalars) > 0) + return false; return true; } Index: test/Transforms/SLPVectorizer/NVPTX/lit.local.cfg =================================================================== --- /dev/null +++ test/Transforms/SLPVectorizer/NVPTX/lit.local.cfg @@ -0,0 +1,3 @@ +if not 'NVPTX' in config.root.targets: + config.unsupported = True + Index: test/Transforms/SLPVectorizer/NVPTX/simple.ll =================================================================== --- /dev/null +++ test/Transforms/SLPVectorizer/NVPTX/simple.ll @@ -0,0 +1,94 @@ +; RUN: opt < %s -basicaa -slp-vectorizer -S | FileCheck %s + +; Check that the SLP vectorizer emits vector loads/stores when targeting NVPTX. + +target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64" +target triple = "nvptx64-nvidia-cuda" + +declare float @llvm.nvvm.ex2.approx.ftz.f(float) readnone norecurse nounwind +declare float @llvm.nvvm.lg2.approx.ftz.f(float) readnone norecurse nounwind +declare <4 x float> @llvm.nvvm.ldg.global.f.v4f32.p0v4f32(<4 x float>*, i32) readonly argmemonly norecurse nounwind + +; Check that we vectorize loads and stores in a trivial function. 
+; CHECK-LABEL: @small_fn define void @small_fn(float* %in, float* %out) { ; Four adjacent float loads, a per-element fadd, and four adjacent float stores. + %p1 = getelementptr inbounds float, float* %in, i64 0 + %in1 = load float, float* %p1, align 16 ; Only element 0 is marked align 16, matching the align 16 expected on the vector load below. + %p2 = getelementptr inbounds float, float* %in, i64 1 + %in2 = load float, float* %p2, align 4 + %p3 = getelementptr inbounds float, float* %in, i64 2 + %in3 = load float, float* %p3, align 8 + %p4 = getelementptr inbounds float, float* %in, i64 3 + %in4 = load float, float* %p4, align 4 + ; CHECK: load <4 x float>, <4 x float>* %{{[0-9]+}}, align 16 + + %t1 = fadd float %in1, 1.0 + %t2 = fadd float %in2, 2.0 + %t3 = fadd float %in3, 3.0 + %t4 = fadd float %in4, 4.0 + + %o1 = getelementptr inbounds float, float* %out, i64 0 + store float %t1, float* %o1, align 16 ; As with the loads, only element 0 is marked align 16. + %o2 = getelementptr inbounds float, float* %out, i64 1 + store float %t2, float* %o2, align 4 + %o3 = getelementptr inbounds float, float* %out, i64 2 + store float %t3, float* %o3, align 8 + %o4 = getelementptr inbounds float, float* %out, i64 3 + ; CHECK: store <4 x float> %{{[0-9]+}}, <4 x float>* %{{[0-9]+}}, align 16 + store float %t4, float* %o4, align 4 + ret void +} + +; Check that we vectorize the stores in a bigger function. We don't currently +; vectorize the loads here: each loaded value reaches the stores only through +; a call to @llvm.nvvm.ex2.approx.ftz.f, which is not vectorizable.
+; +; CHECK-LABEL: @big_fn define void @big_fn(float* %in1, i64 %in1_idx, <4 x float>* %in2, + float* %out, i64 %out_idx) { ; NOTE(review): %in1_idx is never used in the body. + %1 = getelementptr inbounds float, float* %in1, i64 0 + %2 = load float, float* %1, align 16 + %p2 = getelementptr inbounds float, float* %in1, i64 1 + %3 = load float, float* %p2, align 4 + %p3 = getelementptr inbounds float, float* %in1, i64 2 + %4 = load float, float* %p3, align 8 + %p4 = getelementptr inbounds float, float* %in1, i64 3 + %5 = load float, float* %p4, align 4 + + %6 = fmul float %3, 0x3FF7154760000000 ; 0x3FF7154760000000 is log2(e) rounded to float, so ex2(x * log2(e)) approximates exp(x). + %7 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %6) + %8 = fmul float %4, 0x3FF7154760000000 + %9 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %8) + %10 = fmul float %5, 0x3FF7154760000000 + %11 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %10) + %12 = fmul float %2, 0x3FF7154760000000 + %13 = tail call float @llvm.nvvm.ex2.approx.ftz.f(float %12) + + %14 = tail call <4 x float> @llvm.nvvm.ldg.global.f.v4f32.p0v4f32(<4 x float>* %in2, i32 16) ; Already a vector value; its lanes are extracted one by one below. + %15 = extractelement <4 x float> %14, i32 0 + %16 = extractelement <4 x float> %14, i32 1 + %17 = extractelement <4 x float> %14, i32 2 + %18 = extractelement <4 x float> %14, i32 3 + %19 = tail call float @llvm.nvvm.lg2.approx.ftz.f(float %16) + %20 = fmul float %19, 0x3FE62E4300000000 ; 0x3FE62E4300000000 is ln(2) rounded to float, so lg2(x) * ln(2) approximates ln(x). 
+ %21 = tail call float @llvm.nvvm.lg2.approx.ftz.f(float %17) + %22 = fmul float %21, 0x3FE62E4300000000 + %23 = tail call float @llvm.nvvm.lg2.approx.ftz.f(float %18) + %24 = fmul float %23, 0x3FE62E4300000000 + %25 = tail call float @llvm.nvvm.lg2.approx.ftz.f(float %15) + %26 = fmul float %25, 0x3FE62E4300000000 + %27 = fadd float %7, %20 ; In store order, the sums pair element k of %in1 with lane k of %14 for k = 1, 2, 3, 0. + %28 = fadd float %9, %22 + %29 = fadd float %11, %24 + %30 = fadd float %13, %26 + %31 = getelementptr inbounds float, float* %out, i64 %out_idx ; The four stores write %out[%out_idx] .. %out[%out_idx + 3]. + store float %27, float* %31, align 16 + %32 = getelementptr inbounds float, float* %31, i64 1 + store float %28, float* %32, align 4 + %33 = getelementptr inbounds float, float* %31, i64 2 +
store float %29, float* %33, align 8 + %34 = getelementptr inbounds float, float* %31, i64 3 + ; CHECK: store <4 x float> %{{[0-9]+}}, <4 x float>* %{{[0-9]+}}, align 16 + store float %30, float* %34, align 4 + ret void +}