diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -398,6 +398,16 @@
     }
   }
 
+  // Handle scalable vector splats separately: when the element counts match,
+  // any cast (bitcasts included) can be pushed into the splatted element.
+  if (auto *SrcSVTy = dyn_cast<ScalableVectorType>(V->getType()))
+    if (auto *DestSVTy = dyn_cast<ScalableVectorType>(DestTy))
+      if (DestSVTy->getElementCount() == SrcSVTy->getElementCount())
+        if (Constant *Splat = V->getSplatValue())
+          return ConstantVector::getSplat(
+              DestSVTy->getElementCount(),
+              ConstantExpr::getCast(opc, Splat, DestSVTy->getElementType()));
+
   // If the cast operand is a constant vector, perform the cast by
   // operating on each element. In the cast of bitcasts, the element
   // count may be mismatched; don't attempt to handle that here.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1655,13 +1655,22 @@
 // TODO: Make these support undef elements.
 static Type *shrinkFPConstantVector(Value *V) {
   auto *CV = dyn_cast<Constant>(V);
-  auto *CVVTy = dyn_cast<FixedVectorType>(V->getType());
-  if (!CV || !CVVTy)
+  auto *CVT = dyn_cast<VectorType>(V->getType());
+  if (!CV || !CVT)
     return nullptr;
 
-  Type *MinType = nullptr;
+  // Handle the splat case separately, because that works for scalable
+  // vectors, whereas the rest of this function does not.
+  if (Constant *SplatVal = CV->getSplatValue())
+    if (auto *CFP = dyn_cast<ConstantFP>(SplatVal))
+      if (Type *T = shrinkFPConstant(CFP))
+        return VectorType::get(T, CVT->getElementCount());
 
-  unsigned NumElts = CVVTy->getNumElements();
+  if (!isa<FixedVectorType>(CVT))
+    return nullptr;
+
+  Type *MinType = nullptr;
+  unsigned NumElts = cast<FixedVectorType>(CVT)->getNumElements();
 
   // For fixed-width vectors we find the minimal type by looking
   // through the constant values of the vector.
@@ -1696,15 +1705,8 @@
     if (Type *T = shrinkFPConstant(CFP))
       return T;
 
-  // We can only correctly find a minimum type for a scalable vector when it is
-  // a splat. For splats of constant values the fpext is wrapped up as a
-  // ConstantExpr.
-  if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
-    if (FPCExt->getOpcode() == Instruction::FPExt)
-      return FPCExt->getOperand(0)->getType();
-
-  // Try to shrink a vector of FP constants. This returns nullptr on scalable
-  // vectors
+  // Try to shrink a vector of FP constants. Scalable vectors are only handled
+  // when they are a splat of a constant FP value.
   if (Type *T = shrinkFPConstantVector(V))
     return T;
 
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-free-instructions.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-free-instructions.ll
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-free-instructions.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/sink-free-instructions.ll
@@ -494,3 +494,24 @@
   %vmull1 = tail call <8 x i16> @llvm.aarch64.neon.pmull.v8i16(<8 x i8> %s3, <8 x i8> %s4)
   ret <8 x i16> %vmull1
 }
+
+; Ensure sinking add/sub operands doesn't crash when one of them is a Const or ConstExpr.
+define @zext_scalable_constexpr( %a, i1 %p) { +; CHECK-LABEL: @zext_scalable_constexpr( +; CHECK-NEXT: [[TMP1:%.*]] = zext [[A:%.*]] to +; CHECK-NEXT: br i1 [[P:%.*]], label [[ZEXT_THEN:%.*]], label [[ZEXT_EXIT:%.*]] +; CHECK: zext.then: +; CHECK-NEXT: [[TMP2:%.*]] = add [[TMP1]], shufflevector ( insertelement ( poison, i16 2, i32 0), poison, zeroinitializer) +; CHECK-NEXT: ret [[TMP2]] +; CHECK: zext.exit: +; CHECK-NEXT: ret [[TMP1]] +; + %1 = zext %a to + br i1 %p, label %zext.then, label %zext.exit +zext.then: + %2 = add %1, zext ( shufflevector ( insertelement ( undef, i8 2, i32 0), undef, zeroinitializer) to ) + ret %2 + +zext.exit: + ret %1 +} diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll --- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll +++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll @@ -2,7 +2,7 @@ define @shrink_splat_scalable_extend( %a) { ; CHECK-LABEL: @shrink_splat_scalable_extend - ; CHECK-NEXT: %[[FADD:.*]] = fadd %a, shufflevector ( insertelement ( undef, float -1.000000e+00, i32 0), undef, zeroinitializer) + ; CHECK-NEXT: %[[FADD:.*]] = fadd %a, shufflevector ( insertelement ( poison, float -1.000000e+00, i32 0), poison, zeroinitializer) ; CHECK-NEXT: ret %[[FADD]] %1 = shufflevector insertelement ( undef, float -1.000000e+00, i32 0), undef, zeroinitializer %2 = fpext %a to diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll --- a/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/vscale-inseltpoison.ll @@ -221,7 +221,7 @@ define @bitcast() { ; CHECK-LABEL: @bitcast( -; CHECK-NEXT: ret bitcast ( shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer) to ) +; CHECK-NEXT: ret shufflevector ( insertelement ( poison, float 0x36A0000000000000, i32
0), poison, zeroinitializer) ; %i1 = insertelement poison, i32 1, i32 0 %i2 = shufflevector %i1, poison, zeroinitializer diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll b/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll --- a/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/vscale.ll @@ -221,7 +221,7 @@ define @bitcast() { ; CHECK-LABEL: @bitcast( -; CHECK-NEXT: ret bitcast ( shufflevector ( insertelement ( undef, i32 1, i32 0), undef, zeroinitializer) to ) +; CHECK-NEXT: ret shufflevector ( insertelement ( poison, float 0x36A0000000000000, i32 0), poison, zeroinitializer) ; %i1 = insertelement undef, i32 1, i32 0 %i2 = shufflevector %i1, undef, zeroinitializer diff --git a/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll b/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll --- a/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll +++ b/llvm/test/Transforms/InstSimplify/vscale-inseltpoison.ll @@ -140,7 +140,7 @@ define @bitcast() { ; CHECK-LABEL: @bitcast( -; CHECK-NEXT: ret bitcast ( shufflevector ( insertelement ( poison, i32 1, i32 0), poison, zeroinitializer) to ) +; CHECK-NEXT: ret shufflevector ( insertelement ( poison, float 0x36A0000000000000, i32 0), poison, zeroinitializer) ; %i1 = insertelement poison, i32 1, i32 0 %i2 = shufflevector %i1, poison, zeroinitializer diff --git a/llvm/test/Transforms/InstSimplify/vscale.ll b/llvm/test/Transforms/InstSimplify/vscale.ll --- a/llvm/test/Transforms/InstSimplify/vscale.ll +++ b/llvm/test/Transforms/InstSimplify/vscale.ll @@ -152,7 +152,7 @@ define @bitcast() { ; CHECK-LABEL: @bitcast( -; CHECK-NEXT: ret bitcast ( shufflevector ( insertelement ( undef, i32 1, i32 0), undef, zeroinitializer) to ) +; CHECK-NEXT: ret shufflevector ( insertelement ( poison, float 0x36A0000000000000, i32 0), poison, zeroinitializer) ; %i1 = insertelement undef, i32 1, i32 0 %i2 = shufflevector %i1, undef, zeroinitializer