[PatternMatch] Match scalable vector splats in cst_pred_ty

cst_pred_ty previously only entered its vector path for FixedVectorType
operands, so integer splat constants of scalable vector type were not
matched there (e.g. `mul %z, splat(0)` on <vscale x ...> operands).
Check for any VectorType first and try the splat value; fall back to the
per-element scan only for fixed width vectors, whose element count is
known at compile time. cstfp_pred_ty is restructured the same way so the
two matchers agree. New InstCombine tests cover `mul`/`fmul fast` by a
scalable splat of zero.

NOTE(review): this patch text was reconstructed from a copy whose newlines
and angle-bracketed text (C++ template arguments, `<vscale x N x T>` IR
types) had been stripped. The `vscale x 2` element counts in the new tests
are a best guess; confirm against the upstream commit before landing.

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -262,18 +262,23 @@
   return constantint_match<Val>();
 }
 
-/// This helper class is used to match scalar and fixed width vector integer
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
+/// This helper class is used to match integer constant scalars, vector splats,
+/// and fixed width vectors that satisfy a specified predicate.
+/// For fixed width vector constants, undefined elements are ignored.
 template <typename Predicate> struct cst_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CI = dyn_cast<ConstantInt>(V))
       return this->isValue(CI->getValue());
-    if (const auto *FVTy = dyn_cast<FixedVectorType>(V->getType())) {
+    if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
       if (const auto *C = dyn_cast<Constant>(V)) {
         if (const auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue()))
           return this->isValue(CI->getValue());
 
+        // Number of elements of a scalable vector unknown at compile time
+        auto *FVTy = dyn_cast<FixedVectorType>(VTy);
+        if (!FVTy)
+          return false;
+
         // Non-splat vector constant: check each element for a match.
         unsigned NumElts = FVTy->getNumElements();
         assert(NumElts != 0 && "Constant vector with no elements?");
@@ -321,25 +326,25 @@
   }
 };
 
-/// This helper class is used to match scalar and vector floating-point
-/// constants that satisfy a specified predicate.
-/// For vector constants, undefined elements are ignored.
+/// This helper class is used to match float constant scalars, vector splats,
+/// and fixed width vectors that satisfy a specified predicate.
+/// For fixed width vector constants, undefined elements are ignored.
 template <typename Predicate> struct cstfp_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CF = dyn_cast<ConstantFP>(V))
       return this->isValue(CF->getValueAPF());
-    if (V->getType()->isVectorTy()) {
+    if (const auto *VTy = dyn_cast<VectorType>(V->getType())) {
       if (const auto *C = dyn_cast<Constant>(V)) {
         if (const auto *CF = dyn_cast_or_null<ConstantFP>(C->getSplatValue()))
           return this->isValue(CF->getValueAPF());
 
         // Number of elements of a scalable vector unknown at compile time
-        if (isa<ScalableVectorType>(V->getType()))
+        auto *FVTy = dyn_cast<FixedVectorType>(VTy);
+        if (!FVTy)
           return false;
 
         // Non-splat vector constant: check each element for a match.
-        unsigned NumElts =
-            cast<FixedVectorType>(V->getType())->getNumElements();
+        unsigned NumElts = FVTy->getNumElements();
         assert(NumElts != 0 && "Constant vector with no elements?");
         bool HasNonUndefElements = false;
         for (unsigned i = 0; i != NumElts; ++i) {
diff --git a/llvm/test/Transforms/InstCombine/fmul.ll b/llvm/test/Transforms/InstCombine/fmul.ll
--- a/llvm/test/Transforms/InstCombine/fmul.ll
+++ b/llvm/test/Transforms/InstCombine/fmul.ll
@@ -1164,3 +1164,12 @@
   %mul = fmul fast double %sqr, %sel
   ret double %mul
 }
+
+; fastmath => z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x float> @mul_scalable_splat_zero(<vscale x 2 x float> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT:    ret <vscale x 2 x float> zeroinitializer
+  %shuf = shufflevector <vscale x 2 x float> insertelement (<vscale x 2 x float> undef, float 0.0, i32 0), <vscale x 2 x float> undef, <vscale x 2 x i32> zeroinitializer
+  %t3 = fmul fast <vscale x 2 x float> %shuf, %z
+  ret <vscale x 2 x float> %t3
+}
diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll
--- a/llvm/test/Transforms/InstCombine/mul.ll
+++ b/llvm/test/Transforms/InstCombine/mul.ll
@@ -680,3 +680,12 @@
   %m = mul <4 x i32> %r, %r
   ret <4 x i32> %m
 }
+
+; z * splat(0) = splat(0), even for scalable vectors
+define <vscale x 2 x i64> @mul_scalable_splat_zero(<vscale x 2 x i64> %z) {
+; CHECK-LABEL: @mul_scalable_splat_zero(
+; CHECK-NEXT:    ret <vscale x 2 x i64> zeroinitializer
+  %shuf = shufflevector <vscale x 2 x i64> insertelement (<vscale x 2 x i64> undef, i64 0, i32 0), <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
+  %t3 = mul <vscale x 2 x i64> %shuf, %z
+  ret <vscale x 2 x i64> %t3
+}