diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -3637,7 +3637,7 @@ There are certain limitations on the type: * The type can be used for function parameters and return values. * The supported LLVM operations on this type are strictly limited to ``load``, - ``store``, ``phi`` and ``alloca`` instructions. + ``store``, ``phi``, ``select`` and ``alloca`` instructions. The predicate-as-counter type is a scalable type. diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -377,6 +377,8 @@ if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) { addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass); + setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom); + setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand); } // Compute derived properties from the register classes @@ -8916,6 +8918,22 @@ SDLoc DL(Op); EVT Ty = Op.getValueType(); + if (Ty == MVT::aarch64svcount) { + TVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, TVal); + FVal = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, FVal); + EVT CCVT = CCVal.getValueType(); + SDValue ID = + DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, CCVT); + SDValue Zero = DAG.getConstant(0, DL, CCVT); + SDValue SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, CCVT, CCVal, + DAG.getValueType(MVT::i1)); + SDValue SplatPred = + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv16i1, ID, Zero, SplatVal); + SDValue Sel = + DAG.getNode(ISD::VSELECT, DL, MVT::nxv16i1, SplatPred, TVal, FVal); + return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, Ty, Sel); + } + if (Ty.isScalableVector()) { SDValue TruncCC = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, CCVal); MVT PredVT = MVT::getVectorVT(MVT::i1, Ty.getVectorElementCount()); @@ -19838,7 +19856,7 @@ if 
(N0.getOpcode() != ISD::SETCC) return SDValue(); - if (ResVT.isScalableVector()) + if (ResVT.isScalableVector() || ResVT == MVT::aarch64svcount) return SDValue(); // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -2445,6 +2445,10 @@ def : Pat<(nxv2f64 (bitconvert (nxv8bf16 ZPR:$src))), (nxv2f64 ZPR:$src)>; } + // These allow casting from/to the opaque aarch64svcount type. + def : Pat<(aarch64svcount (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + def : Pat<(nxv16i1 (reinterpret_cast (aarch64svcount PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; + // These allow casting from/to unpacked predicate types. def : Pat<(nxv16i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -1629,6 +1629,10 @@ return &SI; } + // The code below doesn't work on aarch64_svcount. + if (TrueVal->getType()->isAArch64SvcountTy()) + return nullptr; + // FIXME: This code is nearly duplicated in InstSimplify. Using/refactoring // decomposeBitTestICmp() might help. 
{ diff --git a/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll --- a/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll +++ b/llvm/test/CodeGen/AArch64/sme-aarch64-svcount.ll @@ -115,3 +115,32 @@ call void @take_svcount_5(aarch64_svcount %arg, aarch64_svcount %arg, aarch64_svcount %arg, aarch64_svcount %arg, aarch64_svcount %arg) ret void } + +; Test code-generation of aarch64_svcount being used in a select to support e.g. +; return k < 42 ? x : y; +; where x and y are of type svcount_t. +define aarch64_svcount @foo(i32 %k, aarch64_svcount %x, aarch64_svcount %y) { +; CHECKO0-LABEL: foo: +; CHECKO0: // %bb.0: +; CHECKO0-NEXT: mov p2.b, p1.b +; CHECKO0-NEXT: mov p1.b, p0.b +; CHECKO0-NEXT: subs w8, w0, #42 +; CHECKO0-NEXT: cset w8, lt +; CHECKO0-NEXT: sbfx w9, w8, #0, #1 +; CHECKO0-NEXT: mov w8, wzr +; CHECKO0-NEXT: whilelo p0.b, w8, w9 +; CHECKO0-NEXT: sel p0.b, p0, p1.b, p2.b +; CHECKO0-NEXT: ret +; +; CHECKO3-LABEL: foo: +; CHECKO3: // %bb.0: +; CHECKO3-NEXT: cmp w0, #42 +; CHECKO3-NEXT: cset w8, lt +; CHECKO3-NEXT: sbfx w8, w8, #0, #1 +; CHECKO3-NEXT: whilelo p2.b, wzr, w8 +; CHECKO3-NEXT: sel p0.b, p2, p0.b, p1.b +; CHECKO3-NEXT: ret + %cmp = icmp slt i32 %k, 42 + %x.y = select i1 %cmp, aarch64_svcount %x, aarch64_svcount %y + ret aarch64_svcount %x.y +} diff --git a/llvm/test/Transforms/InstCombine/AArch64/sme-svcount.ll b/llvm/test/Transforms/InstCombine/AArch64/sme-svcount.ll --- a/llvm/test/Transforms/InstCombine/AArch64/sme-svcount.ll +++ b/llvm/test/Transforms/InstCombine/AArch64/sme-svcount.ll @@ -10,3 +10,11 @@ %res = load aarch64_svcount, ptr %ptr ret aarch64_svcount %res } + +; Test that instcombine doesn't try to query the (scalable) size of aarch64_svcount +; in foldSelectInstWithICmp.
+define aarch64_svcount @test_combine_on_select(aarch64_svcount %x, aarch64_svcount %y, i32 %k) { + %cmp = icmp sgt i32 %k, 42 + %x.y = select i1 %cmp, aarch64_svcount %x, aarch64_svcount %y + ret aarch64_svcount %x.y +}