diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -4496,8 +4496,8 @@ } Value *CodeGenFunction::EmitNeonSplat(Value *V, Constant *C) { - unsigned nElts = V->getType()->getVectorNumElements(); - Value* SV = llvm::ConstantVector::getSplat(nElts, C); + ElementCount EC = V->getType()->getVectorElementCount(); + Value *SV = llvm::ConstantVector::getSplat(EC, C); return Builder.CreateShuffleVector(V, V, SV, "lane"); } @@ -8701,7 +8701,7 @@ llvm::VectorType::get(VTy->getElementType(), VTy->getNumElements() / 2) : VTy; llvm::Constant *cst = cast(Ops[3]); - Value *SV = llvm::ConstantVector::getSplat(VTy->getNumElements(), cst); + Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(), cst); Ops[1] = Builder.CreateBitCast(Ops[1], SourceTy); Ops[1] = Builder.CreateShuffleVector(Ops[1], Ops[1], SV, "lane"); @@ -8730,7 +8730,7 @@ llvm::Type *STy = llvm::VectorType::get(VTy->getElementType(), VTy->getNumElements() * 2); Ops[2] = Builder.CreateBitCast(Ops[2], STy); - Value* SV = llvm::ConstantVector::getSplat(VTy->getNumElements(), + Value *SV = llvm::ConstantVector::getSplat(VTy->getElementCount(), cast(Ops[3])); Ops[2] = Builder.CreateShuffleVector(Ops[2], Ops[2], SV, "lane"); diff --git a/llvm/include/llvm/Analysis/Utils/Local.h b/llvm/include/llvm/Analysis/Utils/Local.h --- a/llvm/include/llvm/Analysis/Utils/Local.h +++ b/llvm/include/llvm/Analysis/Utils/Local.h @@ -63,7 +63,7 @@ // Splat the constant if needed. if (IntIdxTy->isVectorTy() && !OpC->getType()->isVectorTy()) - OpC = ConstantVector::getSplat(IntIdxTy->getVectorNumElements(), OpC); + OpC = ConstantVector::getSplat(IntIdxTy->getVectorElementCount(), OpC); Constant *Scale = ConstantInt::get(IntIdxTy, Size); Constant *OC = ConstantExpr::getIntegerCast(OpC, IntIdxTy, true /*SExt*/); diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/llvm/include/llvm/IR/Constants.h @@ -517,7 +517,7 @@ public: /// Return a ConstantVector with the specified constant in each element. - static Constant *getSplat(unsigned NumElts, Constant *Elt); + static Constant *getSplat(ElementCount EC, Constant *Elt); /// Specialize the getType() method to always return a VectorType, /// which reduces the amount of casting needed in parts of the compiler. diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -707,9 +707,8 @@ Offset = Offset.sextOrTrunc(IntIdxTy->getIntegerBitWidth()); Constant *OffsetIntPtr = ConstantInt::get(IntIdxTy, Offset); - if (V->getType()->isVectorTy()) - return ConstantVector::getSplat(V->getType()->getVectorNumElements(), - OffsetIntPtr); + if (VectorType *VecTy = dyn_cast(V->getType())) + return ConstantVector::getSplat(VecTy->getElementCount(), OffsetIntPtr); return OffsetIntPtr; } diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -6565,19 +6565,23 @@ UseSplat = true; } - unsigned End = getTransitionType()->getVectorNumElements(); + ElementCount EC = getTransitionType()->getVectorElementCount(); if (UseSplat) - return ConstantVector::getSplat(End, Val); - - SmallVector ConstVec; - UndefValue *UndefVal = UndefValue::get(Val->getType()); - for (unsigned Idx = 0; Idx != End; ++Idx) { - if (Idx == ExtractIdx) - ConstVec.push_back(Val); - else - ConstVec.push_back(UndefVal); - } - return ConstantVector::get(ConstVec); + return ConstantVector::getSplat(EC, Val); + + if (!EC.Scalable) { + SmallVector ConstVec; + UndefValue *UndefVal = UndefValue::get(Val->getType()); + for (unsigned Idx = 0; Idx != EC.Min; ++Idx) { + if (Idx == ExtractIdx) + ConstVec.push_back(Val); + else + ConstVec.push_back(UndefVal); + } + return ConstantVector::get(ConstVec); + } else + llvm_unreachable( + "Generate scalable vector for non-splat is unimplemented"); } /// Check if promoting to a vector type an operand at \p OperandIdx diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -2229,8 +2229,7 @@ Constant *Idx0 = cast(Idxs[0]); if (Idxs.size() == 1 && (Idx0->isNullValue() || isa(Idx0))) return GEPTy->isVectorTy() && !C->getType()->isVectorTy() - ? ConstantVector::getSplat( - cast(GEPTy)->getNumElements(), C) + ? ConstantVector::getSplat(GEPTy->getVectorElementCount(), C) : C; if (C->isNullValue()) { diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -370,7 +370,7 @@ // Broadcast a scalar to a vector, if necessary. if (VectorType *VTy = dyn_cast(Ty)) - C = ConstantVector::getSplat(VTy->getNumElements(), C); + C = ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -387,7 +387,7 @@ } VectorType *VTy = cast(Ty); - return ConstantVector::getSplat(VTy->getNumElements(), + return ConstantVector::getSplat(VTy->getElementCount(), getAllOnesValue(VTy->getElementType())); } @@ -681,7 +681,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *TrueC = ConstantInt::getTrue(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), TrueC); + return ConstantVector::getSplat(VTy->getElementCount(), TrueC); return TrueC; } @@ -689,7 +689,7 @@ assert(Ty->isIntOrIntVectorTy(1) && "Type not i1 or vector of i1."); ConstantInt *FalseC = ConstantInt::getFalse(Ty->getContext()); if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), FalseC); + return ConstantVector::getSplat(VTy->getElementCount(), FalseC); return FalseC; } @@ -712,7 +712,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -736,7 +736,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -781,7 +781,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -793,7 +793,7 @@ // For vectors, broadcast the value. if (auto *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -806,7 +806,7 @@ // For vectors, broadcast the value. if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -817,7 +817,7 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -828,7 +828,7 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -839,7 +839,7 @@ Constant *C = get(Ty->getContext(), NaN); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -850,7 +850,7 @@ Constant *C = get(Ty->getContext(), NegZero); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -898,7 +898,7 @@ Constant *C = get(Ty->getContext(), APFloat::getInf(Semantics, Negative)); if (VectorType *VTy = dyn_cast(Ty)) - return ConstantVector::getSplat(VTy->getNumElements(), C); + return ConstantVector::getSplat(VTy->getElementCount(), C); return C; } @@ -1204,15 +1204,35 @@ return nullptr; } -Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) { - // If this splat is compatible with ConstantDataVector, use it instead of - // ConstantVector. - if ((isa(V) || isa(V)) && - ConstantDataSequential::isElementTypeCompatible(V->getType())) - return ConstantDataVector::getSplat(NumElts, V); +Constant *ConstantVector::getSplat(ElementCount EC, Constant *V) { + if (!EC.Scalable) { + // If this splat is compatible with ConstantDataVector, use it instead of + // ConstantVector. + if ((isa(V) || isa(V)) && + ConstantDataSequential::isElementTypeCompatible(V->getType())) + return ConstantDataVector::getSplat(EC.Min, V); - SmallVector Elts(NumElts, V); - return get(Elts); + SmallVector Elts(EC.Min, V); + return get(Elts); + } + + Type *VTy = VectorType::get(V->getType(), EC); + + if (V->isNullValue()) + return ConstantAggregateZero::get(VTy); + else if (isa(V)) + return UndefValue::get(VTy); + + Type *I32Ty = Type::getInt32Ty(VTy->getContext()); + + // Move scalar into vector. + Constant *UndefV = UndefValue::get(VTy); + V = ConstantExpr::getInsertElement(UndefV, V, ConstantInt::get(I32Ty, 0)); + // Build shuffle mask to perform the splat. + Type *MaskTy = VectorType::get(I32Ty, EC); + Constant *Zeros = ConstantAggregateZero::get(MaskTy); + // Splat. + return ConstantExpr::getShuffleVector(V, UndefV, Zeros); } ConstantTokenNone *ConstantTokenNone::get(LLVMContext &Context) { @@ -2098,15 +2118,15 @@ unsigned AS = C->getType()->getPointerAddressSpace(); Type *ReqTy = DestTy->getPointerTo(AS); - unsigned NumVecElts = 0; - if (C->getType()->isVectorTy()) - NumVecElts = C->getType()->getVectorNumElements(); + ElementCount EltCount = {0, false}; + if (VectorType *VecTy = dyn_cast(C->getType())) + EltCount = VecTy->getElementCount(); else for (auto Idx : Idxs) - if (Idx->getType()->isVectorTy()) - NumVecElts = Idx->getType()->getVectorNumElements(); + if (VectorType *VecTy = dyn_cast(Idx->getType())) + EltCount = VecTy->getElementCount(); - if (NumVecElts) - ReqTy = VectorType::get(ReqTy, NumVecElts); + if (EltCount.Min != 0) + ReqTy = VectorType::get(ReqTy, EltCount); if (OnlyIfReducedTy == ReqTy) return nullptr; @@ -2117,12 +2137,12 @@ ArgVec.push_back(C); for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { assert((!Idxs[i]->getType()->isVectorTy() || - Idxs[i]->getType()->getVectorNumElements() == NumVecElts) && + Idxs[i]->getType()->getVectorElementCount() == EltCount) && "getelementptr index type missmatch"); Constant *Idx = cast(Idxs[i]); - if (NumVecElts && !Idxs[i]->getType()->isVectorTy()) - Idx = ConstantVector::getSplat(NumVecElts, Idx); + if (EltCount.Min != 0 && !Idxs[i]->getType()->isVectorTy()) + Idx = ConstantVector::getSplat(EltCount, Idx); ArgVec.push_back(Idx); } @@ -2759,7 +2779,7 @@ return getFP(V->getContext(), Elts); } } - return ConstantVector::getSplat(NumElts, V); + return ConstantVector::getSplat({NumElts, false}, V); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5379,8 +5379,9 @@ if (ScalarC && ScalarM) { // We allow undefs in matching, but this transform removes those for safety. // Demanded elements analysis should be able to recover some/all of that. - C = ConstantVector::getSplat(V1Ty->getVectorNumElements(), ScalarC); - M = ConstantVector::getSplat(M->getType()->getVectorNumElements(), ScalarM); + C = ConstantVector::getSplat(V1Ty->getVectorElementCount(), ScalarC); + M = ConstantVector::getSplat(M->getType()->getVectorElementCount(), + ScalarM); Value *NewCmp = IsFP ? Builder.CreateFCmp(Pred, V1, C) : Builder.CreateICmp(Pred, V1, C); return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -774,7 +774,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } @@ -809,7 +809,7 @@ APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -1717,9 +1717,10 @@ // FIXME: If the step is non-constant, we create the vector splat with // IRBuilder. IRBuilder can constant-fold the multiply, but it doesn't // handle a constant vector splat. - Value *SplatVF = isa(Mul) - ? ConstantVector::getSplat(VF, cast(Mul)) - : Builder.CreateVectorSplat(VF, Mul); + Value *SplatVF = + isa(Mul) + ? ConstantVector::getSplat({VF, false}, cast(Mul)) + : Builder.CreateVectorSplat(VF, Mul); Builder.restoreIP(CurrIP); // We may need to add the step a number of times, depending on the unroll @@ -3731,7 +3732,7 @@ // incoming scalar reduction. VectorStart = ReductionStartValue; } else { - Identity = ConstantVector::getSplat(VF, Iden); + Identity = ConstantVector::getSplat({VF, false}, Iden); // This vector is the Identity vector where the first element is the // incoming scalar reduction. diff --git a/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll b/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/AArch64/scalable-vector-promotion.ll @@ -0,0 +1,23 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -mtriple=aarch64 -codegenprepare -S < %s | FileCheck %s + +; This test intends to check vector promotion for scalable vector. Current target lowering +; rejects scalable vector before reaching getConstantVector() in CodeGenPrepare. This test +; will assert once target lowering is ready, then we can bring in implementation for non-splat +; codepath for scalable vector. + +define void @simpleOneInstructionPromotion(* %addr1, i32* %dest) { +; CHECK-LABEL: @simpleOneInstructionPromotion( +; CHECK-NEXT: [[IN1:%.*]] = load , * [[ADDR1:%.*]], align 8 +; CHECK-NEXT: [[EXTRACT:%.*]] = extractelement [[IN1]], i32 1 +; CHECK-NEXT: [[OUT:%.*]] = or i32 [[EXTRACT]], 1 +; CHECK-NEXT: store i32 [[OUT]], i32* [[DEST:%.*]], align 4 +; CHECK-NEXT: ret void +; + %in1 = load , * %addr1, align 8 + %extract = extractelement %in1, i32 1 + %out = or i32 %extract, 1 + store i32 %out, i32* %dest, align 4 + ret void +} + diff --git a/llvm/test/Transforms/InstSimplify/gep.ll b/llvm/test/Transforms/InstSimplify/gep.ll --- a/llvm/test/Transforms/InstSimplify/gep.ll +++ b/llvm/test/Transforms/InstSimplify/gep.ll @@ -103,3 +103,69 @@ ret <8 x i64*> %el } +; Check ConstantExpr::getGetElementPtr() using ElementCount for size queries - begin. + +; Constant ptr + +define i32* @ptr_idx_scalar() { +; CHECK-LABEL: @ptr_idx_scalar( +; CHECK-NEXT: ret i32* inttoptr (i64 4 to i32*) +; + %gep = getelementptr <4 x i32>, <4 x i32>* null, i64 0, i64 1 + ret i32* %gep +} + +define <2 x i32*> @ptr_idx_vector() { +; CHECK-LABEL: @ptr_idx_vector( +; CHECK-NEXT: ret <2 x i32*> getelementptr (i32, i32* null, <2 x i64> ) +; + %gep = getelementptr i32, i32* null, <2 x i64> + ret <2 x i32*> %gep +} + +define <4 x i32*> @ptr_idx_mix_scalar_vector(){ +; CHECK-LABEL: @ptr_idx_mix_scalar_vector( +; CHECK-NEXT: ret <4 x i32*> getelementptr ([42 x [3 x i32]], [42 x [3 x i32]]* null, <4 x i64> zeroinitializer, <4 x i64> , <4 x i64> zeroinitializer) +; + %gep = getelementptr [42 x [3 x i32]], [42 x [3 x i32]]* null, i64 0, <4 x i64> , i64 0 + ret <4 x i32*> %gep +} + +; Constant vector + +define <4 x i32*> @vector_idx_scalar() { +; CHECK-LABEL: @vector_idx_scalar( +; CHECK-NEXT: ret <4 x i32*> getelementptr (i32, <4 x i32*> zeroinitializer, <4 x i64> ) +; + %gep = getelementptr i32, <4 x i32*> zeroinitializer, i64 1 + ret <4 x i32*> %gep +} + +define <4 x i32*> @vector_idx_vector() { +; CHECK-LABEL: @vector_idx_vector( +; CHECK-NEXT: ret <4 x i32*> getelementptr (i32, <4 x i32*> zeroinitializer, <4 x i64> ) +; + %gep = getelementptr i32, <4 x i32*> zeroinitializer, <4 x i64> + ret <4 x i32*> %gep +} + +%struct = type { double, float } +define <4 x float*> @vector_idx_mix_scalar_vector() { +; CHECK-LABEL: @vector_idx_mix_scalar_vector( +; CHECK-NEXT: ret <4 x float*> getelementptr (%struct, <4 x %struct*> zeroinitializer, <4 x i64> zeroinitializer, <4 x i32> ) +; + %gep = getelementptr %struct, <4 x %struct*> zeroinitializer, i32 0, <4 x i32> + ret <4 x float*> %gep +} + +; Constant scalable + +define @scalable_idx_scalar() { +; CHECK-LABEL: @scalable_idx_scalar( +; CHECK-NEXT: ret getelementptr (i32, zeroinitializer, shufflevector ( insertelement ( undef, i64 1, i32 0), undef, zeroinitializer)) +; + %gep = getelementptr i32, zeroinitializer, i64 1 + ret %gep +} + +; Check ConstantExpr::getGetElementPtr() using ElementCount for size queries - end. diff --git a/llvm/unittests/FuzzMutate/OperationsTest.cpp b/llvm/unittests/FuzzMutate/OperationsTest.cpp --- a/llvm/unittests/FuzzMutate/OperationsTest.cpp +++ b/llvm/unittests/FuzzMutate/OperationsTest.cpp @@ -92,8 +92,8 @@ ConstantStruct::get(StructType::create(Ctx, "OpaqueStruct")); Constant *a = ConstantArray::get(ArrayType::get(i32->getType(), 2), {i32, i32}); - Constant *v8i8 = ConstantVector::getSplat(8, i8); - Constant *v4f16 = ConstantVector::getSplat(4, f16); + Constant *v8i8 = ConstantVector::getSplat({8, false}, i8); + Constant *v4f16 = ConstantVector::getSplat({4, false}, f16); Constant *p0i32 = ConstantPointerNull::get(PointerType::get(i32->getType(), 0)); diff --git a/llvm/unittests/IR/VerifierTest.cpp b/llvm/unittests/IR/VerifierTest.cpp --- a/llvm/unittests/IR/VerifierTest.cpp +++ b/llvm/unittests/IR/VerifierTest.cpp @@ -57,7 +57,7 @@ ConstantInt *CI = ConstantInt::get(ITy, 0); // Valid type : freeze(<2 x i32>) - Constant *CV = ConstantVector::getSplat(2, CI); + Constant *CV = ConstantVector::getSplat({2, false}, CI); FreezeInst *FI_vec = new FreezeInst(CV); FI_vec->insertBefore(RI);