diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -262,18 +262,23 @@ return constantint_match(); } -/// This helper class is used to match scalar and fixed width vector integer -/// constants that satisfy a specified predicate. -/// For vector constants, undefined elements are ignored. +/// This helper class is used to match integer constant scalars, vector splats, +/// and fixed width vectors that satisfy a specified predicate. +/// For fixed width vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { template bool match(ITy *V) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); - if (const auto *FVTy = dyn_cast(V->getType())) { + if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CI = dyn_cast_or_null(C->getSplatValue())) return this->isValue(CI->getValue()); + // Number of elements of a scalable vector unknown at compile time + auto *FVTy = dyn_cast(VTy); + if (!FVTy) + return false; + // Non-splat vector constant: check each element for a match. unsigned NumElts = FVTy->getNumElements(); assert(NumElts != 0 && "Constant vector with no elements?"); @@ -321,25 +326,25 @@ } }; -/// This helper class is used to match scalar and vector floating-point -/// constants that satisfy a specified predicate. -/// For vector constants, undefined elements are ignored. +/// This helper class is used to match float constant scalars, vector splats, +/// and fixed width vectors that satisfy a specified predicate. +/// For fixed width vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { template bool match(ITy *V) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); - if (V->getType()->isVectorTy()) { + if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CF = dyn_cast_or_null(C->getSplatValue())) return this->isValue(CF->getValueAPF()); // Number of elements of a scalable vector unknown at compile time - if (isa(V->getType())) + auto *FVTy = dyn_cast(VTy); + if (!FVTy) return false; // Non-splat vector constant: check each element for a match. - unsigned NumElts = - cast(V->getType())->getNumElements(); + unsigned NumElts = FVTy->getNumElements(); assert(NumElts != 0 && "Constant vector with no elements?"); bool HasNonUndefElements = false; for (unsigned i = 0; i != NumElts; ++i) { diff --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll --- a/llvm/test/Transforms/InstCombine/fmul.ll +++ b/llvm/test/Transforms/InstCombine/fmul.ll @@ -1164,3 +1164,12 @@ %mul = fmul fast double %sqr, %sel ret double %mul } + +; fastmath => z * splat(0) = splat(0), even for scalable vectors +define @mul_scalable_splat_zero( %z) { +; CHECK-LABEL: @mul_scalable_splat_zero( +; CHECK-NEXT: ret zeroinitializer + %shuf = shufflevector insertelement ( undef, float 0.0, i32 0), undef, zeroinitializer + %t3 = fmul fast %shuf, %z + ret %t3 +} diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -856,3 +856,12 @@ %m = mul <4 x i32> %r, %r ret <4 x i32> %m } + +; z * splat(0) = splat(0), even for scalable vectors +define @mul_scalable_splat_zero( %z) { +; CHECK-LABEL: @mul_scalable_splat_zero( +; CHECK-NEXT: ret zeroinitializer + %shuf = shufflevector insertelement ( undef, i64 0, i32 0), undef, zeroinitializer + %t3 = mul %shuf, %z + ret %t3 +} diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -1325,6 +1325,183 @@ m_SpecificInt(10)))); } +namespace { + +struct is_unsigned_zero_pred { + bool isValue(const APInt &C) { return C.isNullValue(); } +}; + +struct is_float_zero_pred { + bool isValue(const APFloat &C) { return C.isZero(); } +}; + +template struct always_true_pred { + bool isValue(const T &) { return true; } +}; + +template struct always_false_pred { + bool isValue(const T &) { return false; } +}; + +struct is_unsigned_max_pred { + bool isValue(const APInt &C) { return C.isMaxValue(); } +}; + +struct is_float_nan_pred { + bool isValue(const APFloat &C) { return C.isNaN(); } +}; + +} // namespace + +TEST_F(PatternMatchTest, ConstantPredicateType) { + + // Scalar integer + APInt U32Max = APInt::getAllOnesValue(32); + APInt U32Zero = APInt::getNullValue(32); + APInt U32DeadBeef(32, 0xDEADBEEF); + + Type *U32Ty = Type::getInt32Ty(Ctx); + + Constant *CU32Max = Constant::getIntegerValue(U32Ty, U32Max); + Constant *CU32Zero = Constant::getIntegerValue(U32Ty, U32Zero); + Constant *CU32DeadBeef = Constant::getIntegerValue(U32Ty, U32DeadBeef); + + EXPECT_TRUE(match(CU32Max, cst_pred_ty())); + EXPECT_FALSE(match(CU32Max, cst_pred_ty())); + EXPECT_TRUE(match(CU32Max, cst_pred_ty>())); + EXPECT_FALSE(match(CU32Max, cst_pred_ty>())); + + EXPECT_FALSE(match(CU32Zero, cst_pred_ty())); + EXPECT_TRUE(match(CU32Zero, cst_pred_ty())); + EXPECT_TRUE(match(CU32Zero, cst_pred_ty>())); + EXPECT_FALSE(match(CU32Zero, cst_pred_ty>())); + + EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty())); + EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty())); + EXPECT_TRUE(match(CU32DeadBeef, cst_pred_ty>())); + EXPECT_FALSE(match(CU32DeadBeef, cst_pred_ty>())); + + // Scalar float + APFloat F32NaN = APFloat::getNaN(APFloat::IEEEsingle()); + APFloat F32Zero = APFloat::getZero(APFloat::IEEEsingle()); + APFloat F32Pi(3.14f); + + Type *F32Ty = Type::getFloatTy(Ctx); + + Constant *CF32NaN = ConstantFP::get(F32Ty, F32NaN); + Constant *CF32Zero = ConstantFP::get(F32Ty, F32Zero); + Constant *CF32Pi = ConstantFP::get(F32Ty, F32Pi); + + EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty())); + EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty())); + EXPECT_TRUE(match(CF32NaN, cstfp_pred_ty>())); + EXPECT_FALSE(match(CF32NaN, cstfp_pred_ty>())); + + EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty())); + EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty())); + EXPECT_TRUE(match(CF32Zero, cstfp_pred_ty>())); + EXPECT_FALSE(match(CF32Zero, cstfp_pred_ty>())); + + EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty())); + EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty())); + EXPECT_TRUE(match(CF32Pi, cstfp_pred_ty>())); + EXPECT_FALSE(match(CF32Pi, cstfp_pred_ty>())); + + ElementCount FixedEC(4, false); + ElementCount ScalableEC(4, true); + + // Vector splat + + for (auto EC : {FixedEC, ScalableEC}) { + // integer + + Constant *CSplatU32Max = ConstantVector::getSplat(EC, CU32Max); + Constant *CSplatU32Zero = ConstantVector::getSplat(EC, CU32Zero); + Constant *CSplatU32DeadBeef = ConstantVector::getSplat(EC, CU32DeadBeef); + + EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty())); + EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty())); + EXPECT_TRUE(match(CSplatU32Max, cst_pred_ty>())); + EXPECT_FALSE(match(CSplatU32Max, cst_pred_ty>())); + + EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty())); + EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty())); + EXPECT_TRUE(match(CSplatU32Zero, cst_pred_ty>())); + EXPECT_FALSE(match(CSplatU32Zero, cst_pred_ty>())); + + EXPECT_FALSE(match(CSplatU32DeadBeef, cst_pred_ty())); + EXPECT_FALSE( + match(CSplatU32DeadBeef, cst_pred_ty())); + EXPECT_TRUE( + match(CSplatU32DeadBeef, cst_pred_ty>())); + EXPECT_FALSE( + match(CSplatU32DeadBeef, cst_pred_ty>())); + + // float + + Constant *CSplatF32NaN = ConstantVector::getSplat(EC, CF32NaN); + Constant *CSplatF32Zero = ConstantVector::getSplat(EC, CF32Zero); + Constant *CSplatF32Pi = ConstantVector::getSplat(EC, CF32Pi); + + EXPECT_TRUE(match(CSplatF32NaN, cstfp_pred_ty())); + EXPECT_FALSE(match(CSplatF32NaN, cstfp_pred_ty())); + EXPECT_TRUE( + match(CSplatF32NaN, cstfp_pred_ty>())); + EXPECT_FALSE( + match(CSplatF32NaN, cstfp_pred_ty>())); + + EXPECT_FALSE(match(CSplatF32Zero, cstfp_pred_ty())); + EXPECT_TRUE(match(CSplatF32Zero, cstfp_pred_ty())); + EXPECT_TRUE( + match(CSplatF32Zero, cstfp_pred_ty>())); + EXPECT_FALSE( + match(CSplatF32Zero, cstfp_pred_ty>())); + + EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty())); + EXPECT_FALSE(match(CSplatF32Pi, cstfp_pred_ty())); + EXPECT_TRUE(match(CSplatF32Pi, cstfp_pred_ty>())); + EXPECT_FALSE( + match(CSplatF32Pi, cstfp_pred_ty>())); + } + + // Int arbitrary vector + + Constant *CMixedU32 = ConstantVector::get({CU32Max, CU32Zero, CU32DeadBeef}); + Constant *CU32Undef = UndefValue::get(U32Ty); + Constant *CU32MaxWithUndef = + ConstantVector::get({CU32Undef, CU32Max, CU32Undef}); + + EXPECT_FALSE(match(CMixedU32, cst_pred_ty())); + EXPECT_FALSE(match(CMixedU32, cst_pred_ty())); + EXPECT_TRUE(match(CMixedU32, cst_pred_ty>())); + EXPECT_FALSE(match(CMixedU32, cst_pred_ty>())); + + EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty())); + EXPECT_FALSE(match(CU32MaxWithUndef, cst_pred_ty())); + EXPECT_TRUE(match(CU32MaxWithUndef, cst_pred_ty>())); + EXPECT_FALSE( + match(CU32MaxWithUndef, cst_pred_ty>())); + + // Float arbitrary vector + + Constant *CMixedF32 = ConstantVector::get({CF32NaN, CF32Zero, CF32Pi}); + Constant *CF32Undef = UndefValue::get(F32Ty); + Constant *CF32NaNWithUndef = + ConstantVector::get({CF32Undef, CF32NaN, CF32Undef}); + + EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty())); + EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty())); + EXPECT_TRUE(match(CMixedF32, cstfp_pred_ty>())); + EXPECT_FALSE(match(CMixedF32, cstfp_pred_ty>())); + + EXPECT_TRUE(match(CF32NaNWithUndef, cstfp_pred_ty())); + EXPECT_FALSE(match(CF32NaNWithUndef, cstfp_pred_ty())); + EXPECT_TRUE( + match(CF32NaNWithUndef, cstfp_pred_ty>())); + EXPECT_FALSE( + match(CF32NaNWithUndef, cstfp_pred_ty>())); +} + template struct MutableConstTest : PatternMatchTest { }; typedef ::testing::Types,