diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -1575,24 +1575,64 @@ Type *MinType = nullptr; - unsigned NumElts = cast(CVVTy)->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - auto *CFP = dyn_cast_or_null(CV->getAggregateElement(i)); + auto EC = cast(CVVTy)->getElementCount(); + + auto getMinTypeInConstant = [](Constant *constant, Type *MinType) -> Type * { + auto *CFP = dyn_cast_or_null(constant); if (!CFP) return nullptr; - Type *T = shrinkFPConstant(CFP); + auto *T = shrinkFPConstant(CFP); if (!T) return nullptr; // If we haven't found a type yet or this type has a larger mantissa than // our previous type, this is our new minimal type. if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth()) + return T; + + return MinType; + }; + + // We can only correctly find a MinType for a ScalableVector if the vector + // is a splat-vector; otherwise we can't shrink due to the runtime-defined + // vscale value. + if (EC.isScalable()) { + if (CV->getSplatValue()) { + if (auto *T = getMinTypeInConstant(CV->getSplatValue(), MinType)) + return VectorType::get(T, EC); + return nullptr; + } + + // Scalable splat vectors can be nested within an instruction, and as such + // we must search instruction operands for the minimum type, i.e. for an + // fpext we may have the following: fpext ( ... to ); we must search the first operand for the + // minimum type. 
+ for (unsigned i = 0; i < CV->getNumOperands(); i++) { + auto *CV2 = dyn_cast(CV->getOperand(i)); + if (!CV2 || !CV2->getSplatValue()) + return nullptr; + + if (auto *T = getMinTypeInConstant(CV2->getSplatValue(), MinType)) + MinType = T; + else + return nullptr; + } + + return VectorType::get(MinType, EC); + } + + // For fixed-width vectors we find the min-type by looking + // through the constant values of the vector + for (unsigned i = 0; i != EC.getFixedValue(); ++i) { + if (auto *T = getMinTypeInConstant(CV->getAggregateElement(i), MinType)) MinType = T; + else + return nullptr; } - // Make a vector type from the minimal type. - return FixedVectorType::get(MinType, NumElts); + return VectorType::get(MinType, EC); } /// Find the minimum FP type we can safely truncate to. diff --git a/llvm/test/Transforms/InstCombine/AArch64/instcombine-vectors.ll b/llvm/test/Transforms/InstCombine/AArch64/instcombine-vectors.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/AArch64/instcombine-vectors.ll @@ -0,0 +1,17 @@ +; RUN: opt -instcombine -mtriple=aarch64-linux-gnu -mattr=+sve -S -o - < %s 2>%t | FileCheck %s +; RUN: FileCheck --check-prefix=WARN --allow-empty %s <%t + +; If this check fails please read test/CodeGen/AArch64/README for instructions on how to resolve it. +; WARN-NOT: warning + +define @shrink_splat_scalable_extend( %a) { + ; CHECK-LABEL: @shrink_splat_scalable_extend + ; CHECK-NEXT: %1 = fadd %a, shufflevector ( insertelement ( undef, float -1.000000e+00, i32 0), undef, zeroinitializer) + ; CHECK-NEXT: ret %1 + %1 = shufflevector insertelement ( undef, float -1.000000e+00, i32 0), undef, zeroinitializer + %2 = fpext %a to + %3 = fpext %1 to + %4 = fadd %2, %3 + %5 = fptrunc %4 to + ret %5 +}