Index: lib/Transforms/Scalar/Scalarizer.cpp =================================================================== --- lib/Transforms/Scalar/Scalarizer.cpp +++ lib/Transforms/Scalar/Scalarizer.cpp @@ -16,6 +16,7 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" #include "llvm/Pass.h" @@ -148,6 +149,7 @@ bool visitPHINode(PHINode &); bool visitLoadInst(LoadInst &); bool visitStoreInst(StoreInst &); + bool visitCallInst(CallInst &I); static void registerOptions() { // This is disabled by default because having separate loads and stores @@ -169,6 +171,8 @@ template bool splitBinary(Instruction &, const T &); + bool splitCall(CallInst &CI); + ScatterMap Scattered; GatherList Gathered; unsigned ParallelLoopAccessMDKind; @@ -394,6 +398,75 @@ return true; } +static bool isTriviallyScalariable(Intrinsic::ID ID) { + return isTriviallyVectorizable(ID); +} + +// Current scalarizable intrinsics have one mangled type; use Ty's scalar type. +static Function *getScalarIntrinsicDeclaration(Module *M, + Intrinsic::ID ID, + VectorType *Ty) { + return Intrinsic::getDeclaration(M, ID, { Ty->getScalarType() }); +} + +/// If CI calls a vector-typed intrinsic function, split it into one scalar call +/// per element if possible for the intrinsic. Returns true if CI was split.
+bool Scalarizer::splitCall(CallInst &CI) {
+  VectorType *VT = dyn_cast<VectorType>(CI.getType());
+  if (!VT)
+    return false;
+
+  Function *F = CI.getCalledFunction();
+  if (!F)
+    return false;
+
+  Intrinsic::ID ID = F->getIntrinsicID();
+  if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID))
+    return false;
+
+  unsigned NumElems = VT->getNumElements();
+  unsigned NumArgs = CI.getNumArgOperands();
+
+  ValueVector ScalarOperands(NumArgs);
+  // Per-operand scatterers; named to avoid shadowing the Scattered member.
+  SmallVector<Scatterer, 8> ArgScatter(NumArgs);
+
+  // Assumes that any vector-typed operand has the same number of elements as
+  // the return vector type (asserted below for every scattered operand).
+  for (unsigned I = 0; I != NumArgs; ++I) {
+    if (hasVectorInstrinsicScalarOpd(ID, I))
+      ScalarOperands[I] = CI.getOperand(I);
+    else {
+      ArgScatter[I] = scatter(&CI, CI.getOperand(I));
+      assert(ArgScatter[I].size() == NumElems && "mismatched call operands");
+    }
+  }
+
+  ValueVector Res(NumElems);
+  ValueVector ScalarCallOps(NumArgs);
+
+  Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, VT);
+  IRBuilder<> Builder(&CI);
+
+  // Perform actual scalarization, taking care to preserve any scalar operands.
+  for (unsigned Elem = 0; Elem < NumElems; ++Elem) {
+    ScalarCallOps.clear();
+
+    for (unsigned J = 0; J != NumArgs; ++J) {
+      if (hasVectorInstrinsicScalarOpd(ID, J))
+        ScalarCallOps.push_back(ScalarOperands[J]);
+      else
+        ScalarCallOps.push_back(ArgScatter[J][Elem]);
+    }
+
+    Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps,
+                                   CI.getName() + ".i" + Twine(Elem));
+  }
+
+  gather(&CI, Res);
+  return true;
+}
+
 bool Scalarizer::visitSelectInst(SelectInst &SI) {
   VectorType *VT = dyn_cast<VectorType>(SI.getType());
   if (!VT)
@@ -642,6 +715,10 @@
   return true;
 }
 
+bool Scalarizer::visitCallInst(CallInst &CI) {
+  return splitCall(CI);
+}
+
 // Delete the instructions that we scalarized. If a full vector result
 // is still needed, recreate it using InsertElements.
bool Scalarizer::finish() { Index: test/Transforms/Scalarizer/intrinsics.ll =================================================================== --- /dev/null +++ test/Transforms/Scalarizer/intrinsics.ll @@ -0,0 +1,85 @@ +; RUN: opt -S -scalarizer %s | FileCheck %s + +; Unary fp +declare <2 x float> @llvm.sqrt.v2f32(<2 x float>) + +; Binary fp +declare <2 x float> @llvm.minnum.v2f32(<2 x float>, <2 x float>) + +; Ternary fp +declare <2 x float> @llvm.fma.v2f32(<2 x float>, <2 x float>, <2 x float>) + +; Unary int +declare <2 x i32> @llvm.bswap.v2i32(<2 x i32>) + +; Unary int plus constant scalar operand +declare <2 x i32> @llvm.ctlz.v2i32(<2 x i32>, i1) + +; Unary fp plus any scalar operand +declare <2 x float> @llvm.powi.v2f32(<2 x float>, i32) + +; CHECK-LABEL: @scalarize_sqrt_v2f32( +; CHECK: %sqrt.i0 = call float @llvm.sqrt.f32(float %x.i0) +; CHECK: %sqrt.i1 = call float @llvm.sqrt.f32(float %x.i1) +; CHECK: %sqrt.upto0 = insertelement <2 x float> undef, float %sqrt.i0, i32 0 +; CHECK: %sqrt = insertelement <2 x float> %sqrt.upto0, float %sqrt.i1, i32 1 +; CHECK: ret <2 x float> %sqrt +define <2 x float> @scalarize_sqrt_v2f32(<2 x float> %x) #0 { + %sqrt = call <2 x float> @llvm.sqrt.v2f32(<2 x float> %x) + ret <2 x float> %sqrt +} + +; CHECK-LABEL: @scalarize_minnum_v2f32( +; CHECK: %minnum.i0 = call float @llvm.minnum.f32(float %x.i0, float %y.i0) +; CHECK: %minnum.i1 = call float @llvm.minnum.f32(float %x.i1, float %y.i1) +; CHECK: %minnum.upto0 = insertelement <2 x float> undef, float %minnum.i0, i32 0 +; CHECK: %minnum = insertelement <2 x float> %minnum.upto0, float %minnum.i1, i32 1 +; CHECK: ret <2 x float> %minnum +define <2 x float> @scalarize_minnum_v2f32(<2 x float> %x, <2 x float> %y) #0 { + %minnum = call <2 x float> @llvm.minnum.v2f32(<2 x float> %x, <2 x float> %y) + ret <2 x float> %minnum +} + +; CHECK-LABEL: @scalarize_fma_v2f32( +; CHECK: %fma.i0 = call float @llvm.fma.f32(float %x.i0, float %y.i0, float %z.i0) +; CHECK: %fma.i1 = call 
float @llvm.fma.f32(float %x.i1, float %y.i1, float %z.i1) +; CHECK: %fma.upto0 = insertelement <2 x float> undef, float %fma.i0, i32 0 +; CHECK: %fma = insertelement <2 x float> %fma.upto0, float %fma.i1, i32 1 +; CHECK: ret <2 x float> %fma +define <2 x float> @scalarize_fma_v2f32(<2 x float> %x, <2 x float> %y, <2 x float> %z) #0 { + %fma = call <2 x float> @llvm.fma.v2f32(<2 x float> %x, <2 x float> %y, <2 x float> %z) + ret <2 x float> %fma +} + +; CHECK-LABEL: @scalarize_bswap_v2i32( +; CHECK: %bswap.i0 = call i32 @llvm.bswap.i32(i32 %x.i0) +; CHECK: %bswap.i1 = call i32 @llvm.bswap.i32(i32 %x.i1) +; CHECK: %bswap.upto0 = insertelement <2 x i32> undef, i32 %bswap.i0, i32 0 +; CHECK: %bswap = insertelement <2 x i32> %bswap.upto0, i32 %bswap.i1, i32 1 +; CHECK: ret <2 x i32> %bswap +define <2 x i32> @scalarize_bswap_v2i32(<2 x i32> %x) #0 { + %bswap = call <2 x i32> @llvm.bswap.v2i32(<2 x i32> %x) + ret <2 x i32> %bswap +} + +; CHECK-LABEL: @scalarize_ctlz_v2i32( +; CHECK: %ctlz.i0 = call i32 @llvm.ctlz.i32(i32 %x.i0, i1 true) +; CHECK: %ctlz.i1 = call i32 @llvm.ctlz.i32(i32 %x.i1, i1 true) +; CHECK: %ctlz.upto0 = insertelement <2 x i32> undef, i32 %ctlz.i0, i32 0 +; CHECK: %ctlz = insertelement <2 x i32> %ctlz.upto0, i32 %ctlz.i1, i32 1 +; CHECK: ret <2 x i32> %ctlz +define <2 x i32> @scalarize_ctlz_v2i32(<2 x i32> %x) #0 { + %ctlz = call <2 x i32> @llvm.ctlz.v2i32(<2 x i32> %x, i1 true) + ret <2 x i32> %ctlz +} + +; CHECK-LABEL: @scalarize_powi_v2f32( +; CHECK: %powi.i0 = call float @llvm.powi.f32(float %x.i0, i32 %y) +; CHECK: %powi.i1 = call float @llvm.powi.f32(float %x.i1, i32 %y) +; CHECK: %powi.upto0 = insertelement <2 x float> undef, float %powi.i0, i32 0 +; CHECK: %powi = insertelement <2 x float> %powi.upto0, float %powi.i1, i32 1 +; CHECK: ret <2 x float> %powi +define <2 x float> @scalarize_powi_v2f32(<2 x float> %x, i32 %y) #0 { + %powi = call <2 x float> @llvm.powi.v2f32(<2 x float> %x, i32 %y) + ret <2 x float> %powi +}