diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1436,7 +1436,7 @@
   return std::nullopt;
 }
 
-bool SimplifyValuePattern(SmallVector<Value *> &Vec) {
+bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
   size_t VecSize = Vec.size();
   if (VecSize == 1)
     return true;
@@ -1446,13 +1446,20 @@
 
   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
        RHS != Vec.end(); LHS++, RHS++) {
-    if (*LHS != nullptr && *RHS != nullptr && *LHS == *RHS)
-      continue;
-    return false;
+    if (*LHS != nullptr && *RHS != nullptr) {
+      if (*LHS == *RHS)
+        continue;
+      else
+        return false;
+    }
+    if (!AllowPoison)
+      return false;
+    if (*LHS == nullptr && *RHS != nullptr)
+      *LHS = *RHS;
   }
 
   Vec.resize(HalfVecSize);
-  SimplifyValuePattern(Vec);
+  SimplifyValuePattern(Vec, AllowPoison);
   return true;
 }
 
@@ -1476,7 +1483,9 @@
     CurrentInsertElt = InsertElt->getOperand(0);
   }
 
-  if (!SimplifyValuePattern(Elts))
+  bool AllowPoison =
+      isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
+  if (!SimplifyValuePattern(Elts, AllowPoison))
     return std::nullopt;
 
   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
@@ -1484,9 +1493,13 @@
   Builder.SetInsertPoint(&II);
   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
   for (size_t I = 0; I < Elts.size(); I++) {
+    if (Elts[I] == nullptr)
+      continue;
     InsertEltChain = Builder.CreateInsertElement(InsertEltChain, Elts[I],
                                                  Builder.getInt64(I));
   }
+  if (InsertEltChain == nullptr)
+    return std::nullopt;
 
   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
   // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-dupqlane.ll
@@ -96,12 +96,11 @@
 ; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
 ; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
 ; CHECK-NEXT: [[TMP3:%.*]] = insertelement <8 x half> [[TMP2]], half [[C:%.*]], i64 2
-; CHECK-NEXT: [[TMP4:%.*]] = insertelement <8 x half> [[TMP3]], half [[A]], i64 4
-; CHECK-NEXT: [[TMP5:%.*]] = insertelement <8 x half> [[TMP4]], half [[B]], i64 5
-; CHECK-NEXT: [[TMP6:%.*]] = insertelement <8 x half> [[TMP5]], half [[C]], i64 6
-; CHECK-NEXT: [[TMP7:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP6]], i64 0)
-; CHECK-NEXT: [[TMP8:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP7]], i64 0)
-; CHECK-NEXT: ret <vscale x 8 x half> [[TMP8]]
+; CHECK-NEXT: [[TMP4:%.*]] = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT: [[TMP5:%.*]] = bitcast <vscale x 8 x half> [[TMP4]] to <vscale x 2 x i64>
+; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <vscale x 2 x i64> [[TMP5]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP7:%.*]] = bitcast <vscale x 2 x i64> [[TMP6]] to <vscale x 8 x half>
+; CHECK-NEXT: ret <vscale x 8 x half> [[TMP7]]
 ;
   %1 = insertelement <8 x half> poison, half %a, i64 0
   %2 = insertelement <8 x half> %1, half %b, i64 1
@@ -114,6 +113,57 @@
   ret <vscale x 8 x half> %8
 }
 
+define dso_local <vscale x 8 x half> @dupq_f16_abnull_pattern(half %a, half %b) {
+; CHECK-LABEL: @dupq_f16_abnull_pattern(
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast <vscale x 8 x half> [[TMP3]] to <vscale x 4 x i32>
+; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <vscale x 4 x i32> [[TMP4]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: [[TMP6:%.*]] = bitcast <vscale x 4 x i32> [[TMP5]] to <vscale x 8 x half>
+; CHECK-NEXT: ret <vscale x 8 x half> [[TMP6]]
+;
+  %1 = insertelement <8 x half> poison, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> %2, i64 0)
+  %4 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %3, i64 0)
+  ret <vscale x 8 x half> %4
+}
+
+define dso_local <vscale x 8 x half> @neg_dupq_f16_non_poison_fixed(half %a, half %b, <8 x half> %v) {
+; CHECK-LABEL: @neg_dupq_f16_non_poison_fixed(
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> [[V:%.*]], half [[A:%.*]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT: ret <vscale x 8 x half> [[TMP4]]
+;
+  %1 = insertelement <8 x half> %v, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 0
+  %4 = insertelement <8 x half> %3, half %b, i64 1
+  %5 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> poison, <8 x half> %4, i64 0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %5, i64 0)
+  ret <vscale x 8 x half> %6
+}
+
+define dso_local <vscale x 8 x half> @neg_dupq_f16_into_non_poison_scalable(half %a, half %b, <vscale x 8 x half> %v) {
+; CHECK-LABEL: @neg_dupq_f16_into_non_poison_scalable(
+; CHECK-NEXT: [[TMP1:%.*]] = insertelement <8 x half> poison, half [[A:%.*]], i64 0
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <8 x half> [[TMP1]], half [[B:%.*]], i64 1
+; CHECK-NEXT: [[TMP3:%.*]] = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> [[V:%.*]], <8 x half> [[TMP2]], i64 0)
+; CHECK-NEXT: [[TMP4:%.*]] = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> [[TMP3]], i64 0)
+; CHECK-NEXT: ret <vscale x 8 x half> [[TMP4]]
+;
+  %1 = insertelement <8 x half> poison, half %a, i64 0
+  %2 = insertelement <8 x half> %1, half %b, i64 1
+  %3 = insertelement <8 x half> %2, half %a, i64 0
+  %4 = insertelement <8 x half> %3, half %b, i64 1
+  %5 = tail call <vscale x 8 x half> @llvm.vector.insert.nxv8f16.v8f16(<vscale x 8 x half> %v, <8 x half> %4, i64 0)
+  %6 = tail call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %5, i64 0)
+  ret <vscale x 8 x half> %6
+}
+
 ; Insert %c to override the last element in the insertelement chain, which will fail to combine
 define dso_local <vscale x 8 x half> @neg_dupq_f16_abcd_pattern_double_insert(half %a, half %b, half %c, half %d) {