diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -347,9 +347,9 @@ bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID, unsigned ScalarOpdIdx); -/// Identifies if the vector form of the intrinsic has a operand that has -/// an overloaded type. -bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, unsigned OpdIdx); +/// Identifies if the vector form of the intrinsic is overloaded on the type of +/// the operand at index \p OpdIdx, or on the return type if \p OpdIdx is -1. +bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, int OpdIdx); /// Returns intrinsic ID for call. /// For the input call instruction it finds mapping intrinsic and returns diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -86,6 +86,7 @@ case Intrinsic::pow: case Intrinsic::fma: case Intrinsic::fmuladd: + case Intrinsic::is_fpclass: case Intrinsic::powi: case Intrinsic::canonicalize: case Intrinsic::fptosi_sat: @@ -103,6 +104,7 @@ case Intrinsic::abs: case Intrinsic::ctlz: case Intrinsic::cttz: + case Intrinsic::is_fpclass: case Intrinsic::powi: return (ScalarOpdIdx == 1); case Intrinsic::smul_fix: @@ -116,15 +118,17 @@ } bool llvm::isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, - unsigned OpdIdx) { + int OpdIdx) { switch (ID) { case Intrinsic::fptosi_sat: case Intrinsic::fptoui_sat: + return OpdIdx == -1 || OpdIdx == 0; + case Intrinsic::is_fpclass: return OpdIdx == 0; case Intrinsic::powi: - return OpdIdx == 1; + return OpdIdx == -1 || OpdIdx == 1; default: - return false; + return OpdIdx == -1; } } diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -583,7 +583,9 @@ Scattered.resize(NumArgs); SmallVector Tys; - Tys.push_back(VT->getScalarType()); + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(ID, -1)) + Tys.push_back(VT->getScalarType()); // Assumes that any vector type has the same number of elements as the return // vector type, which is true for all current intrinsics. diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -10335,8 +10335,11 @@ Value *ScalarArg = nullptr; std::vector OpVecs; - SmallVector TysForDecl = - {FixedVectorType::get(CI->getType(), E->Scalars.size())}; + SmallVector TysForDecl; + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(IID, -1)) + TysForDecl.push_back( + FixedVectorType::get(CI->getType(), E->Scalars.size())); for (int j = 0, e = CI->arg_size(); j < e; ++j) { ValueList OpVL; // Some intrinsics have scalar arguments. This argument should not be diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -481,7 +481,14 @@ State.setDebugLocFromInst(&CI); for (unsigned Part = 0; Part < State.UF; ++Part) { - SmallVector TysForDecl = {CI.getType()}; + SmallVector TysForDecl; + // Add return type if intrinsic is overloaded on it. + if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, -1)) { + TysForDecl.push_back( + State.VF.isVector() + ? VectorType::get(CI.getType()->getScalarType(), State.VF) + : CI.getType()); + } SmallVector Args; for (const auto &I : enumerate(operands())) { // Some intrinsics have a scalar argument - don't replace it with a @@ -500,9 +507,6 @@ Function *VectorF; if (VectorIntrinsicID != Intrinsic::not_intrinsic) { // Use vector version of the intrinsic. - if (State.VF.isVector()) - TysForDecl[0] = - VectorType::get(CI.getType()->getScalarType(), State.VF); Module *M = State.Builder.GetInsertBlock()->getModule(); VectorF = Intrinsic::getDeclaration(M, VectorIntrinsicID, TysForDecl); assert(VectorF && "Can't retrieve vector intrinsic."); diff --git a/llvm/test/Transforms/InstSimplify/is_fpclass.ll b/llvm/test/Transforms/InstSimplify/is_fpclass.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/is_fpclass.ll @@ -0,0 +1,12 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2 +; RUN: opt < %s -S -passes=instsimplify | FileCheck %s + +define <2 x i1> @f() { +; CHECK-LABEL: define <2 x i1> @f() { +; CHECK-NEXT: ret <2 x i1> zeroinitializer +; + %i = call <2 x i1> @llvm.is.fpclass.v2f16(<2 x half> , i32 3) + ret <2 x i1> %i +} + +declare <2 x i1> @llvm.is.fpclass.v2f16(<2 x half>, i32 immarg) diff --git a/llvm/test/Transforms/LoopVectorize/is_fpclass.ll b/llvm/test/Transforms/LoopVectorize/is_fpclass.ll --- a/llvm/test/Transforms/LoopVectorize/is_fpclass.ll +++ b/llvm/test/Transforms/LoopVectorize/is_fpclass.ll @@ -4,9 +4,28 @@ define void @d() { ; CHECK-LABEL: define void @d() { ; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 true, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]] +; CHECK: vector.ph: +; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ] +; CHECK-NEXT: [[TMP0:%.*]] = add i64 [[INDEX]], 0 +; CHECK-NEXT: [[TMP1:%.*]] = getelementptr float, ptr @d, i64 [[TMP0]] +; CHECK-NEXT: [[TMP2:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> zeroinitializer, i32 0) +; CHECK-NEXT: [[TMP3:%.*]] = select <2 x i1> [[TMP2]], <2 x float> zeroinitializer, <2 x float> zeroinitializer +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr float, ptr [[TMP1]], i32 0 +; CHECK-NEXT: store <2 x float> [[TMP3]], ptr [[TMP4]], align 4 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 2 +; CHECK-NEXT: [[TMP5:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0 +; CHECK-NEXT: br i1 [[TMP5]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]] +; CHECK: middle.block: +; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 0, 0 +; CHECK-NEXT: br i1 [[CMP_N]], label [[EXIT:%.*]], label [[SCALAR_PH]] +; CHECK: scalar.ph: +; CHECK-NEXT: [[BC_RESUME_VAL:%.*]] = phi i64 [ 0, [[MIDDLE_BLOCK]] ], [ 0, [[ENTRY:%.*]] ] ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: -; CHECK-NEXT: [[I:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[I7:%.*]], [[LOOP]] ] +; CHECK-NEXT: [[I:%.*]] = phi i64 [ [[BC_RESUME_VAL]], [[SCALAR_PH]] ], [ [[I7:%.*]], [[LOOP]] ] ; CHECK-NEXT: [[I3:%.*]] = load float, ptr null, align 4 ; CHECK-NEXT: [[I4:%.*]] = getelementptr float, ptr @d, i64 [[I]] ; CHECK-NEXT: [[I5:%.*]] = tail call i1 @llvm.is.fpclass.f32(float 0.000000e+00, i32 0) @@ -14,7 +33,7 @@ ; CHECK-NEXT: store float [[I6]], ptr [[I4]], align 4 ; CHECK-NEXT: [[I7]] = add i64 [[I]], 1 ; CHECK-NEXT: [[I8:%.*]] = icmp eq i64 [[I7]], 0 -; CHECK-NEXT: br i1 [[I8]], label [[EXIT:%.*]], label [[LOOP]] +; CHECK-NEXT: br i1 [[I8]], label [[EXIT]], label [[LOOP]], !llvm.loop [[LOOP3:![0-9]+]] ; CHECK: exit: ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll b/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll --- a/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll +++ b/llvm/test/Transforms/SLPVectorizer/is_fpclass.ll @@ -4,13 +4,8 @@ define <2 x i1> @scalarize_is_fpclass(<2 x float> %x) { ; CHECK-LABEL: define <2 x i1> @scalarize_is_fpclass ; CHECK-SAME: (<2 x float> [[X:%.*]]) { -; CHECK-NEXT: [[X_I0:%.*]] = extractelement <2 x float> [[X]], i32 0 -; CHECK-NEXT: [[ISFPCLASS_I0:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I0]], i32 123) -; CHECK-NEXT: [[X_I1:%.*]] = extractelement <2 x float> [[X]], i32 1 -; CHECK-NEXT: [[ISFPCLASS_I1:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I1]], i32 123) -; CHECK-NEXT: [[ISFPCLASS_UPTO0:%.*]] = insertelement <2 x i1> poison, i1 [[ISFPCLASS_I0]], i32 0 -; CHECK-NEXT: [[ISFPCLASS:%.*]] = insertelement <2 x i1> [[ISFPCLASS_UPTO0]], i1 [[ISFPCLASS_I1]], i32 1 -; CHECK-NEXT: ret <2 x i1> [[ISFPCLASS]] +; CHECK-NEXT: [[TMP1:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> [[X]], i32 123) +; CHECK-NEXT: ret <2 x i1> [[TMP1]] ; %x.i0 = extractelement <2 x float> %x, i32 0 %isfpclass.i0 = call i1 @llvm.is.fpclass.f32(float %x.i0, i32 123) diff --git a/llvm/test/Transforms/Scalarizer/intrinsics.ll b/llvm/test/Transforms/Scalarizer/intrinsics.ll --- a/llvm/test/Transforms/Scalarizer/intrinsics.ll +++ b/llvm/test/Transforms/Scalarizer/intrinsics.ll @@ -212,7 +212,12 @@ define <2 x i1> @scalarize_is_fpclass(<2 x float> %x) #0 { ; CHECK-LABEL: @scalarize_is_fpclass( -; CHECK-NEXT: [[ISFPCLASS:%.*]] = call <2 x i1> @llvm.is.fpclass.v2f32(<2 x float> [[X:%.*]], i32 123) +; CHECK-NEXT: [[X_I0:%.*]] = extractelement <2 x float> [[X:%.*]], i32 0 +; CHECK-NEXT: [[ISFPCLASS_I0:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I0]], i32 123) +; CHECK-NEXT: [[X_I1:%.*]] = extractelement <2 x float> [[X]], i32 1 +; CHECK-NEXT: [[ISFPCLASS_I1:%.*]] = call i1 @llvm.is.fpclass.f32(float [[X_I1]], i32 123) +; CHECK-NEXT: [[ISFPCLASS_UPTO0:%.*]] = insertelement <2 x i1> poison, i1 [[ISFPCLASS_I0]], i32 0 +; CHECK-NEXT: [[ISFPCLASS:%.*]] = insertelement <2 x i1> [[ISFPCLASS_UPTO0]], i1 [[ISFPCLASS_I1]], i32 1 ; CHECK-NEXT: ret <2 x i1> [[ISFPCLASS]] ; %isfpclass = call <2 x i1> @llvm.is.fpclass(<2 x float> %x, i32 123)