diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -288,9 +288,11 @@ return findScalarElement(III->getOperand(0), EltNo); } - if (ShuffleVectorInst *SVI = dyn_cast(V)) { + ShuffleVectorInst *SVI = dyn_cast(V); + // Restrict the following transformation to fixed-length vector. + if (SVI && isa(SVI->getType())) { unsigned LHSWidth = - cast(SVI->getOperand(0)->getType())->getNumElements(); + cast(SVI->getOperand(0)->getType())->getNumElements(); int InEl = SVI->getMaskValue(EltNo); if (InEl < 0) return UndefValue::get(VTy->getElementType()); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -332,14 +332,18 @@ // find a previously computed scalar that was inserted into the vector. auto *IndexC = dyn_cast(Index); if (IndexC) { - unsigned NumElts = EI.getVectorOperandType()->getNumElements(); + ElementCount EC = EI.getVectorOperandType()->getElementCount(); + unsigned NumElts = EC.Min; // InstSimplify should handle cases where the index is invalid. - if (!IndexC->getValue().ule(NumElts)) + // For fixed-length vector, it's invalid to extract out-of-range element. + if (!EC.Scalable && IndexC->getValue().uge(NumElts)) return nullptr; // This instruction only demands the single element from the input vector. - if (NumElts != 1) { + // Skip for scalable type, the number of elements is unknown at + // compile-time. + if (!EC.Scalable && NumElts != 1) { // If the input vector has a single use, simplify it based on this use // property. if (SrcVec->hasOneUse()) { @@ -417,11 +421,13 @@ } else if (auto *SVI = dyn_cast(I)) { // If this is extracting an element from a shufflevector, figure out where // it came from and extract from the appropriate input element instead. - if (auto *Elt = dyn_cast(Index)) { - int SrcIdx = SVI->getMaskValue(Elt->getZExtValue()); + // Restrict the following transformation to fixed-length vector. + if (isa(SVI->getType()) && isa(Index)) { + int SrcIdx = + SVI->getMaskValue(cast(Index)->getZExtValue()); Value *Src; - unsigned LHSWidth = - cast(SVI->getOperand(0)->getType())->getNumElements(); + unsigned LHSWidth = cast(SVI->getOperand(0)->getType()) + ->getNumElements(); if (SrcIdx < 0) return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); @@ -432,9 +438,8 @@ Src = SVI->getOperand(1); } Type *Int32Ty = Type::getInt32Ty(EI.getContext()); - return ExtractElementInst::Create(Src, - ConstantInt::get(Int32Ty, - SrcIdx, false)); + return ExtractElementInst::Create( + Src, ConstantInt::get(Int32Ty, SrcIdx, false)); } } else if (auto *CI = dyn_cast(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll @@ -0,0 +1,148 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instcombine -S | FileCheck %s + +define i32 @extractelement_in_range( %a) { +; CHECK-LABEL: @extractelement_in_range( +; CHECK-NEXT: [[R:%.*]] = extractelement [[A:%.*]], i64 1 +; CHECK-NEXT: ret i32 [[R]] +; + %r = extractelement %a, i64 1 + ret i32 %r +} + +define i32 @extractelement_maybe_out_of_range( %a) { +; CHECK-LABEL: @extractelement_maybe_out_of_range( +; CHECK-NEXT: [[R:%.*]] = extractelement [[A:%.*]], i64 4 +; CHECK-NEXT: ret i32 [[R]] +; + %r = extractelement %a, i64 4 + ret i32 %r +} + +define i32 @extractelement_bitcast(float %f) { +; CHECK-LABEL: @extractelement_bitcast( +; CHECK-NEXT: [[R:%.*]] = bitcast float [[F:%.*]] to i32 +; CHECK-NEXT: ret i32 [[R]] +; + %vec_float = insertelement undef, float %f, i32 0 + %vec_int = bitcast %vec_float to + %r = extractelement %vec_int, i32 0 + ret i32 %r +} + +define i8 @extractelement_bitcast_to_trunc( %a, i32 %x) { +; CHECK-LABEL: @extractelement_bitcast_to_trunc( +; CHECK-NEXT: [[R:%.*]] = trunc i32 [[X:%.*]] to i8 +; CHECK-NEXT: ret i8 [[R]] +; + %vec = insertelement %a, i32 %x, i32 1 + %vec_cast = bitcast %vec to + %r = extractelement %vec_cast, i32 4 + ret i8 %r +} + +; TODO: Instcombine could remove the insert. +define i8 @extractelement_bitcast_wrong_insert( %a, i32 %x) { +; CHECK-LABEL: @extractelement_bitcast_wrong_insert( +; CHECK-NEXT: [[VEC:%.*]] = insertelement [[A:%.*]], i32 [[X:%.*]], i32 1 +; CHECK-NEXT: [[VEC_CAST:%.*]] = bitcast [[VEC]] to +; CHECK-NEXT: [[R:%.*]] = extractelement [[VEC_CAST]], i32 2 +; CHECK-NEXT: ret i8 [[R]] +; + %vec = insertelement %a, i32 %x, i32 1 ; <- This insert could be removed. + %vec_cast = bitcast %vec to + %r = extractelement %vec_cast, i32 2 + ret i8 %r +} + +; TODO: Instcombine could optimize to return %v. +define i32 @extractelement_shuffle_in_range(i32 %v) { +; CHECK-LABEL: @extractelement_shuffle_in_range( +; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i32 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], undef, zeroinitializer +; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i32 1 +; CHECK-NEXT: ret i32 [[R]] +; + %in = insertelement undef, i32 %v, i32 0 + %splat = shufflevector %in, undef, zeroinitializer + %r = extractelement %splat, i32 1 + ret i32 %r +} + +define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) { +; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range( +; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i32 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], undef, zeroinitializer +; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i32 4 +; CHECK-NEXT: ret i32 [[R]] +; + %in = insertelement undef, i32 %v, i32 0 + %splat = shufflevector %in, undef, zeroinitializer + %r = extractelement %splat, i32 4 + ret i32 %r +} + +define i32 @extractelement_shuffle_invalid_index(i32 %v) { +; CHECK-LABEL: @extractelement_shuffle_invalid_index( +; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i32 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], undef, zeroinitializer +; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i32 -1 +; CHECK-NEXT: ret i32 [[R]] +; + %in = insertelement undef, i32 %v, i32 0 + %splat = shufflevector %in, undef, zeroinitializer + %r = extractelement %splat, i32 -1 + ret i32 %r +} + + +define i32 @extractelement_shuffle_symbolic_index(i32 %v, i32 %idx) { +; CHECK-LABEL: @extractelement_shuffle_symbolic_index( +; CHECK-NEXT: [[IN:%.*]] = insertelement undef, i32 [[V:%.*]], i32 0 +; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector [[IN]], undef, zeroinitializer +; CHECK-NEXT: [[R:%.*]] = extractelement [[SPLAT]], i32 [[IDX:%.*]] +; CHECK-NEXT: ret i32 [[R]] +; + %in = insertelement undef, i32 %v, i32 0 + %splat = shufflevector %in, undef, zeroinitializer + %r = extractelement %splat, i32 %idx + ret i32 %r +} + +define @extractelement_insertelement_same_positions( %vec) { +; CHECK-LABEL: @extractelement_insertelement_same_positions( +; CHECK-NEXT: ret [[VEC:%.*]] +; + %vec.e0 = extractelement %vec, i32 0 + %vec.e1 = extractelement %vec, i32 1 + %vec.e2 = extractelement %vec, i32 2 + %vec.e3 = extractelement %vec, i32 3 + %1 = insertelement %vec, i32 %vec.e0, i32 0 + %2 = insertelement %1, i32 %vec.e1, i32 1 + %3 = insertelement %2, i32 %vec.e2, i32 2 + %4 = insertelement %3, i32 %vec.e3, i32 3 + ret %4 +} + +define @extractelement_insertelement_diff_positions( %vec) { +; CHECK-LABEL: @extractelement_insertelement_diff_positions( +; CHECK-NEXT: [[VEC_E0:%.*]] = extractelement [[VEC:%.*]], i32 4 +; CHECK-NEXT: [[VEC_E1:%.*]] = extractelement [[VEC]], i32 5 +; CHECK-NEXT: [[VEC_E2:%.*]] = extractelement [[VEC]], i32 6 +; CHECK-NEXT: [[VEC_E3:%.*]] = extractelement [[VEC]], i32 7 +; CHECK-NEXT: [[TMP1:%.*]] = insertelement undef, i32 [[VEC_E0]], i32 0 +; CHECK-NEXT: [[TMP2:%.*]] = insertelement [[TMP1]], i32 [[VEC_E1]], i32 1 +; CHECK-NEXT: [[TMP3:%.*]] = insertelement [[TMP2]], i32 [[VEC_E2]], i32 2 +; CHECK-NEXT: [[TMP4:%.*]] = insertelement [[TMP3]], i32 [[VEC_E3]], i32 3 +; CHECK-NEXT: ret [[TMP4]] +; + %vec.e0 = extractelement %vec, i32 4 + %vec.e1 = extractelement %vec, i32 5 + %vec.e2 = extractelement %vec, i32 6 + %vec.e3 = extractelement %vec, i32 7 + %1 = insertelement %vec, i32 %vec.e0, i32 0 + %2 = insertelement %1, i32 %vec.e1, i32 1 + %3 = insertelement %2, i32 %vec.e2, i32 2 + %4 = insertelement %3, i32 %vec.e3, i32 3 + ret %4 +}