Index: include/llvm/IR/Constants.h
===================================================================
--- include/llvm/IR/Constants.h
+++ include/llvm/IR/Constants.h
@@ -486,6 +486,11 @@ public:
 
   /// Return a ConstantVector with the specified constant in each element.
   static Constant *getSplat(unsigned NumElts, Constant *Elt);
+
+  /// Return a splat of \p Elt with the element count \p EC. Fixed-length
+  /// counts defer to getSplat(unsigned, Constant *); scalable counts yield a
+  /// ConstantAggregateZero, an UndefValue, or a shufflevector ConstantExpr.
+  static Constant *getSplat(VectorType::ElementCount EC, Constant *Elt);
 
   /// Specialize the getType() method to always return a VectorType,
   /// which reduces the amount of casting needed in parts of the compiler.
Index: lib/IR/ConstantFold.cpp
===================================================================
--- lib/IR/ConstantFold.cpp
+++ lib/IR/ConstantFold.cpp
@@ -805,6 +805,10 @@
   if (isa<UndefValue>(Idx))
     return UndefValue::get(Val->getType());
 
+  // Do not fold an insertelement into a scalable vector.
+  if (Val->getType()->getVectorIsScalable())
+    return nullptr;
+
   ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx);
   if (!CIdx) return nullptr;
 
@@ -842,6 +846,10 @@
   // Don't break the bitcode reader hack.
   if (isa<ConstantExpr>(Mask)) return nullptr;
 
+  // The per-element folding below needs a fixed source vector length.
+  if (V1->getType()->getVectorIsScalable())
+    return nullptr;
+
   unsigned SrcNumElts = V1->getType()->getVectorNumElements();
 
   // Loop over the shuffle mask, evaluating each element.
Index: lib/IR/Constants.cpp
===================================================================
--- lib/IR/Constants.cpp
+++ lib/IR/Constants.cpp
@@ -241,7 +241,7 @@
 
   // Broadcast a scalar to a vector, if necessary.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    C = ConstantVector::getSplat(VTy->getNumElements(), C);
+    C = ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -258,7 +258,7 @@
   }
 
   VectorType *VTy = cast<VectorType>(Ty);
-  return ConstantVector::getSplat(VTy->getNumElements(),
+  return ConstantVector::getSplat(VTy->getElementCount(),
                                   getAllOnesValue(VTy->getElementType()));
 }
 
@@ -525,7 +525,7 @@
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "True must be vector of i1 or i1.");
-  return ConstantVector::getSplat(VTy->getNumElements(),
+  return ConstantVector::getSplat(VTy->getElementCount(),
                                   ConstantInt::getTrue(Ty->getContext()));
 }
 
@@ -537,7 +537,7 @@
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "False must be vector of i1 or i1.");
-  return ConstantVector::getSplat(VTy->getNumElements(),
+  return ConstantVector::getSplat(VTy->getElementCount(),
                                   ConstantInt::getFalse(Ty->getContext()));
 }
 
@@ -560,7 +560,7 @@
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -584,7 +584,7 @@
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -631,7 +631,7 @@
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -645,7 +645,7 @@
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -656,7 +656,7 @@
   Constant *C = get(Ty->getContext(), NaN);
 
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -667,7 +667,7 @@
   Constant *C = get(Ty->getContext(), NegZero);
 
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -715,7 +715,7 @@
   Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative));
 
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return ConstantVector::getSplat(VTy->getElementCount(), C);
 
   return C;
 }
@@ -1044,6 +1044,30 @@
   return get(Elts);
 }
 
+Constant *ConstantVector::getSplat(VectorType::ElementCount EC, Constant *V) {
+  if (!EC.isScalable())
+    return getSplat(EC.getNumElements(), V);
+
+  // For scalable vectors, build the splat as the constant expression
+  //   shufflevector (insertelement (undef, V, 0), undef, zeroinitializer)
+  // unless V is null or undef, which have direct aggregate forms.
+  Type *Ty = VectorType::get(V->getType(), EC);
+
+  if (V->isNullValue())
+    return ConstantAggregateZero::get(Ty);
+  if (isa<UndefValue>(V))
+    return UndefValue::get(Ty);
+
+  // Move splat value into vector[0].
+  Type *Int32Ty = Type::getInt32Ty(Ty->getContext());
+  Constant *Zero = ConstantInt::get(Int32Ty, 0);
+  V = ConstantExpr::getInsertElement(UndefValue::get(Ty), V, Zero);
+
+  // Broadcast element zero to all lanes.
+  Constant *Zeros = ConstantInt::get(VectorType::get(Int32Ty, EC), 0);
+  return ConstantExpr::getShuffleVector(V, UndefValue::get(Ty), Zeros);
+}
+
 ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) {
   LLVMContextImpl *pImpl = Context.pImpl;
   if (!pImpl->TheNoneToken)
@@ -2052,7 +2076,8 @@
   if (Constant *FC = ConstantFoldShuffleVectorInstruction(V1, V2, Mask))
     return FC;          // Fold a few common cases.
 
-  unsigned NElts = Mask->getType()->getVectorNumElements();
+  VectorType::ElementCount NElts =
+      cast<VectorType>(Mask->getType())->getElementCount();
   Type *EltTy = V1->getType()->getVectorElementType();
   Type *ShufTy = VectorType::get(EltTy, NElts);
 
Index: lib/IR/ConstantsContext.h
===================================================================
--- lib/IR/ConstantsContext.h
+++ lib/IR/ConstantsContext.h
@@ -143,7 +143,7 @@
   ShuffleVectorConstantExpr(Constant *C1, Constant *C2, Constant *C3)
   : ConstantExpr(VectorType::get(
                    cast<VectorType>(C1->getType())->getElementType(),
-                   cast<VectorType>(C3->getType())->getNumElements()),
+                   cast<VectorType>(C3->getType())->getElementCount()),
                  Instruction::ShuffleVector,
                  &Op<0>(), 3) {
     Op<0>() = C1;
Index: lib/IR/Instructions.cpp
===================================================================
--- lib/IR/Instructions.cpp
+++ lib/IR/Instructions.cpp
@@ -1789,7 +1789,7 @@
                                      const Twine &Name,
                                      Instruction *InsertBefore)
 : Instruction(VectorType::get(cast<VectorType>(V1->getType())->getElementType(),
-                cast<VectorType>(Mask->getType())->getNumElements()),
+                cast<VectorType>(Mask->getType())->getElementCount()),
               ShuffleVector,
               OperandTraits<ShuffleVectorInst>::op_begin(this),
               OperandTraits<ShuffleVectorInst>::operands(this),
@@ -1806,7 +1806,7 @@
                                      const Twine &Name,
                                      BasicBlock *InsertAtEnd)
 : Instruction(VectorType::get(cast<VectorType>(V1->getType())->getElementType(),
-                cast<VectorType>(Mask->getType())->getNumElements()),
+                cast<VectorType>(Mask->getType())->getElementCount()),
               ShuffleVector,
               OperandTraits<ShuffleVectorInst>::op_begin(this),
               OperandTraits<ShuffleVectorInst>::operands(this),
Index: test/Transforms/ConstProp/splat.ll
===================================================================
--- /dev/null
+++ test/Transforms/ConstProp/splat.ll
@@ -0,0 +1,47 @@
+; RUN: opt < %s -constprop -S | FileCheck %s
+
+define <4 x i32> @test1() {
+; CHECK-LABEL: @test1
+; CHECK: ret <4 x i32> undef
+  ret <4 x i32> undef
+}
+
+define <n x 4 x i32> @test2() {
+; CHECK-LABEL: @test2
+; CHECK: ret <n x 4 x i32> undef
+  ret <n x 4 x i32> undef
+}
+
+define <4 x i32> @test3() {
+; CHECK-LABEL: @test3
+; CHECK: ret <4 x i32> zeroinitializer
+  ret <4 x i32> zeroinitializer
+}
+
+define <n x 4 x i32> @test4() {
+; CHECK-LABEL: @test4
+; CHECK: ret <n x 4 x i32> zeroinitializer
+  ret <n x 4 x i32> zeroinitializer
+}
+
+; Test const splat sequences become const vectors for non-scalable types.
+define <4 x i32> @test5() {
+; CHECK-LABEL: @test5
+; CHECK: ret <4 x i32> <i32 29, i32 29, i32 29, i32 29>
+  %a = insertelement <4 x i32> undef, i32 29, i32 0
+  %b = shufflevector <4 x i32> %a, <4 x i32> undef, <4 x i32> zeroinitializer
+  ret <4 x i32> %b
+}
+
+; Test const splat sequences are maintained for scalable vectors with the
+; exception that all nodes are constant.
+define <n x 4 x i32> @test6() {
+; CHECK-LABEL: @test6
+; CHECK: ret <n x 4 x i32> shufflevector (
+; CHECK-SAME: <n x 4 x i32> insertelement (<n x 4 x i32> undef, i32 29, i32 0),
+; CHECK-SAME: <n x 4 x i32> undef,
+; CHECK-SAME: <n x 4 x i32> zeroinitializer)
+  %a = insertelement <n x 4 x i32> undef, i32 29, i32 0
+  %b = shufflevector <n x 4 x i32> %a, <n x 4 x i32> undef, <n x 4 x i32> zeroinitializer
+  ret <n x 4 x i32> %b
+}