diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -18022,38 +18022,9 @@ DAG.getTargetLoweringInfo().isTypeLegal(VT) && "Expected legal fixed length vector!"); - int PgPattern; - switch (VT.getVectorNumElements()) { - default: - llvm_unreachable("unexpected element count for SVE predicate"); - case 1: - PgPattern = AArch64SVEPredPattern::vl1; - break; - case 2: - PgPattern = AArch64SVEPredPattern::vl2; - break; - case 4: - PgPattern = AArch64SVEPredPattern::vl4; - break; - case 8: - PgPattern = AArch64SVEPredPattern::vl8; - break; - case 16: - PgPattern = AArch64SVEPredPattern::vl16; - break; - case 32: - PgPattern = AArch64SVEPredPattern::vl32; - break; - case 64: - PgPattern = AArch64SVEPredPattern::vl64; - break; - case 128: - PgPattern = AArch64SVEPredPattern::vl128; - break; - case 256: - PgPattern = AArch64SVEPredPattern::vl256; - break; - } + unsigned PgPattern = + getSVEPredPatternFromNumElements(VT.getVectorNumElements()); + assert(PgPattern && "Unexpected element count for SVE predicate"); // TODO: For vectors that are exactly getMaxSVEVectorSizeInBits big, we can // use AArch64SVEPredPattern::all, which can enable the use of unpredicated diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -593,39 +593,11 @@ cast(IntrPG->getOperand(0))->getZExtValue(); // Can the intrinsic's predicate be converted to a known constant index? 
- unsigned Idx; - switch (PTruePattern) { - default: + unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern); + if (!MinNumElts) return None; - case AArch64SVEPredPattern::vl1: - Idx = 0; - break; - case AArch64SVEPredPattern::vl2: - Idx = 1; - break; - case AArch64SVEPredPattern::vl3: - Idx = 2; - break; - case AArch64SVEPredPattern::vl4: - Idx = 3; - break; - case AArch64SVEPredPattern::vl5: - Idx = 4; - break; - case AArch64SVEPredPattern::vl6: - Idx = 5; - break; - case AArch64SVEPredPattern::vl7: - Idx = 6; - break; - case AArch64SVEPredPattern::vl8: - Idx = 7; - break; - case AArch64SVEPredPattern::vl16: - Idx = 15; - break; - } + unsigned Idx = MinNumElts - 1; // Increment the index if extracting the element after the last active // predicate element. if (IsAfter) @@ -678,26 +650,9 @@ return IC.replaceInstUsesWith(II, VScale); } - unsigned MinNumElts = 0; - switch (Pattern) { - default: - return None; - case AArch64SVEPredPattern::vl1: - case AArch64SVEPredPattern::vl2: - case AArch64SVEPredPattern::vl3: - case AArch64SVEPredPattern::vl4: - case AArch64SVEPredPattern::vl5: - case AArch64SVEPredPattern::vl6: - case AArch64SVEPredPattern::vl7: - case AArch64SVEPredPattern::vl8: - MinNumElts = Pattern; - break; - case AArch64SVEPredPattern::vl16: - MinNumElts = 16; - break; - } + unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern); - return NumElts >= MinNumElts + return MinNumElts && NumElts >= MinNumElts ? Optional(IC.replaceInstUsesWith( II, ConstantInt::get(II.getType(), MinNumElts))) : None; diff --git a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h --- a/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h +++ b/llvm/lib/Target/AArch64/Utils/AArch64BaseInfo.h @@ -454,6 +454,60 @@ #include "AArch64GenSystemOperands.inc" } +/// Return the number of active elements for VL1 to VL256 predicate pattern, +/// zero for all other patterns. 
+inline unsigned getNumElementsFromSVEPredPattern(unsigned Pattern) {
+  switch (Pattern) {
+  default:
+    return 0;
+  case AArch64SVEPredPattern::vl1:
+  case AArch64SVEPredPattern::vl2:
+  case AArch64SVEPredPattern::vl3:
+  case AArch64SVEPredPattern::vl4:
+  case AArch64SVEPredPattern::vl5:
+  case AArch64SVEPredPattern::vl6:
+  case AArch64SVEPredPattern::vl7:
+  case AArch64SVEPredPattern::vl8:
+    return Pattern; // VLn is encoded as n for n = 1 to 8.
+  case AArch64SVEPredPattern::vl16:
+    return 16;
+  case AArch64SVEPredPattern::vl32:
+    return 32;
+  case AArch64SVEPredPattern::vl64:
+    return 64;
+  case AArch64SVEPredPattern::vl128:
+    return 128;
+  case AArch64SVEPredPattern::vl256:
+    return 256;
+  }
+}
+
+/// Return the VL predicate pattern with exactly \p MinNumElts active
+/// elements, i.e. the inverse of getNumElementsFromSVEPredPattern. For any
+/// element count that no VL1 to VL256 pattern covers, return zero (the same
+/// "no pattern" convention getNumElementsFromSVEPredPattern uses) so callers
+/// can diagnose unsupported counts themselves rather than aborting here.
+inline unsigned getSVEPredPatternFromNumElements(unsigned MinNumElts) {
+  // VL1 to VL8 are encoded as their element count, matching the mapping in
+  // getNumElementsFromSVEPredPattern above.
+  if (MinNumElts >= 1 && MinNumElts <= 8)
+    return MinNumElts;
+  switch (MinNumElts) {
+  default:
+    return 0;
+  case 16:
+    return AArch64SVEPredPattern::vl16;
+  case 32:
+    return AArch64SVEPredPattern::vl32;
+  case 64:
+    return AArch64SVEPredPattern::vl64;
+  case 128:
+    return AArch64SVEPredPattern::vl128;
+  case 256:
+    return AArch64SVEPredPattern::vl256;
+  }
+}
+
 namespace AArch64ExactFPImm {
 struct ExactFPImm {
   const char *Name;