diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17971,38 +17971,9 @@
          DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
          "Expected legal fixed length vector!");
 
-  int PgPattern;
-  switch (VT.getVectorNumElements()) {
-  default:
-    llvm_unreachable("unexpected element count for SVE predicate");
-  case 1:
-    PgPattern = AArch64SVEPredPattern::vl1;
-    break;
-  case 2:
-    PgPattern = AArch64SVEPredPattern::vl2;
-    break;
-  case 4:
-    PgPattern = AArch64SVEPredPattern::vl4;
-    break;
-  case 8:
-    PgPattern = AArch64SVEPredPattern::vl8;
-    break;
-  case 16:
-    PgPattern = AArch64SVEPredPattern::vl16;
-    break;
-  case 32:
-    PgPattern = AArch64SVEPredPattern::vl32;
-    break;
-  case 64:
-    PgPattern = AArch64SVEPredPattern::vl64;
-    break;
-  case 128:
-    PgPattern = AArch64SVEPredPattern::vl128;
-    break;
-  case 256:
-    PgPattern = AArch64SVEPredPattern::vl256;
-    break;
-  }
+  unsigned PgPattern =
+      getSVEPredPatternFromNumElements(VT.getVectorNumElements());
+  assert(PgPattern && "Unexpected element count for SVE predicate");
 
   // TODO: For vectors that are exactly getMaxSVEVectorSizeInBits big, we can
   // use AArch64SVEPredPattern::all, which can enable the use of unpredicated
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -593,39 +593,11 @@
       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
 
   // Can the intrinsic's predicate be converted to a known constant index?
-  unsigned Idx;
-  switch (PTruePattern) {
-  default:
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
+  if (!MinNumElts)
     return None;
-  case AArch64SVEPredPattern::vl1:
-    Idx = 0;
-    break;
-  case AArch64SVEPredPattern::vl2:
-    Idx = 1;
-    break;
-  case AArch64SVEPredPattern::vl3:
-    Idx = 2;
-    break;
-  case AArch64SVEPredPattern::vl4:
-    Idx = 3;
-    break;
-  case AArch64SVEPredPattern::vl5:
-    Idx = 4;
-    break;
-  case AArch64SVEPredPattern::vl6:
-    Idx = 5;
-    break;
-  case AArch64SVEPredPattern::vl7:
-    Idx = 6;
-    break;
-  case AArch64SVEPredPattern::vl8:
-    Idx = 7;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    Idx = 15;
-    break;
-  }
 
+  unsigned Idx = MinNumElts - 1;
   // Increment the index if extracting the element after the last active
   // predicate element.
   if (IsAfter)
@@ -678,24 +650,9 @@
     return IC.replaceInstUsesWith(II, VScale);
   }
 
-  unsigned MinNumElts = 0;
-  switch (Pattern) {
-  default:
+  unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
+  if (!MinNumElts)
     return None;
-  case AArch64SVEPredPattern::vl1:
-  case AArch64SVEPredPattern::vl2:
-  case AArch64SVEPredPattern::vl3:
-  case AArch64SVEPredPattern::vl4:
-  case AArch64SVEPredPattern::vl5:
-  case AArch64SVEPredPattern::vl6:
-  case AArch64SVEPredPattern::vl7:
-  case AArch64SVEPredPattern::vl8:
-    MinNumElts = Pattern;
-    break;
-  case AArch64SVEPredPattern::vl16:
-    MinNumElts = 16;
-    break;
-  }
 
   return NumElts >= MinNumElts
              ? Optional<Instruction *>(IC.replaceInstUsesWith(
diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
--- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
+++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h
@@ -454,6 +454,79 @@
   #include "AArch64GenSystemOperands.inc"
 }
 
+/// Return the number of active elements for a VL1 to VL256 predicate pattern,
+/// or zero for any other pattern.
+inline unsigned getNumElementsFromSVEPredPattern(unsigned Pattern) {
+  switch (Pattern) {
+  default:
+    return 0;
+  case AArch64SVEPredPattern::vl1:
+    return 1;
+  case AArch64SVEPredPattern::vl2:
+    return 2;
+  case AArch64SVEPredPattern::vl3:
+    return 3;
+  case AArch64SVEPredPattern::vl4:
+    return 4;
+  case AArch64SVEPredPattern::vl5:
+    return 5;
+  case AArch64SVEPredPattern::vl6:
+    return 6;
+  case AArch64SVEPredPattern::vl7:
+    return 7;
+  case AArch64SVEPredPattern::vl8:
+    return 8;
+  case AArch64SVEPredPattern::vl16:
+    return 16;
+  case AArch64SVEPredPattern::vl32:
+    return 32;
+  case AArch64SVEPredPattern::vl64:
+    return 64;
+  case AArch64SVEPredPattern::vl128:
+    return 128;
+  case AArch64SVEPredPattern::vl256:
+    return 256;
+  }
+}
+
+/// Return the VL1 to VL256 predicate pattern that has exactly
+/// \p MinNumElts active elements, or zero when no such pattern exists
+/// (e.g. for 9 or 512 elements). This is the inverse of
+/// getNumElementsFromSVEPredPattern, so callers that need a valid pattern
+/// must check for (or assert on) a non-zero result.
+inline unsigned getSVEPredPatternFromNumElements(unsigned MinNumElts) {
+  switch (MinNumElts) {
+  default:
+    return 0;
+  case 1:
+    return AArch64SVEPredPattern::vl1;
+  case 2:
+    return AArch64SVEPredPattern::vl2;
+  case 3:
+    return AArch64SVEPredPattern::vl3;
+  case 4:
+    return AArch64SVEPredPattern::vl4;
+  case 5:
+    return AArch64SVEPredPattern::vl5;
+  case 6:
+    return AArch64SVEPredPattern::vl6;
+  case 7:
+    return AArch64SVEPredPattern::vl7;
+  case 8:
+    return AArch64SVEPredPattern::vl8;
+  case 16:
+    return AArch64SVEPredPattern::vl16;
+  case 32:
+    return AArch64SVEPredPattern::vl32;
+  case 64:
+    return AArch64SVEPredPattern::vl64;
+  case 128:
+    return AArch64SVEPredPattern::vl128;
+  case 256:
+    return AArch64SVEPredPattern::vl256;
+  }
+}
+
 namespace AArch64ExactFPImm {
 struct ExactFPImm {
   const char *Name;