NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which text inside angle brackets (C++ template arguments, "<all true>" in a
comment, and the scalable vector types in the .ll test) had been dropped.
Those spans and the line breaks are reconstructed; confirm them against the
upstream revision before applying.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5147,5 +5147,5 @@
   const AArch64TargetLowering *TLI =
       static_cast<const AArch64TargetLowering *>(getTargetLowering());
 
-  return TLI->isAllActivePredicate(N);
+  return TLI->isAllActivePredicate(*CurDAG, N);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -844,7 +844,7 @@
     return 128;
   }
 
-  bool isAllActivePredicate(SDValue N) const;
+  bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) const;
   EVT getPromotedVTForPredicate(EVT VT) const;
 
   EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15087,7 +15087,17 @@
                      Zero);
 }
 
-static bool isAllActivePredicate(SDValue N) {
+// Returns true if N is an AArch64ISD::PFALSE node, looking through any
+// AArch64ISD::REINTERPRET_CAST nodes that wrap it.
+static bool isAllInactivePredicate(SDValue N) {
+  // Look through cast.
+  while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    N = N.getOperand(0);
+
+  return N.getOpcode() == AArch64ISD::PFALSE;
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
   unsigned NumElts = N.getValueType().getVectorMinNumElements();
 
   // Look through cast.
@@ -15106,6 +15116,21 @@
       N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
     return N.getValueType().getVectorMinNumElements() >= NumElts;
 
+  // If we're compiling for a specific vector-length, we can check if the
+  // pattern's VL equals that of the scalable vector at runtime.
+  if (N.getOpcode() == AArch64ISD::PTRUE) {
+    const auto &Subtarget =
+        static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+    if (MaxSVESize && MinSVESize == MaxSVESize) {
+      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+      unsigned PatNumElts =
+          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+      return PatNumElts == (NumElts * VScale);
+    }
+  }
+
   return false;
 }
 
@@ -15122,7 +15147,7 @@
   SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
 
   // ISD way to specify an all active predicate.
-  if (isAllActivePredicate(Pg)) {
+  if (isAllActivePredicate(DAG, Pg)) {
     if (UnpredOp)
       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
 
@@ -16793,6 +16818,14 @@
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
+  // A select on an all-active predicate always yields the true operand, and a
+  // select on an all-inactive predicate always yields the false operand.
+  if (isAllActivePredicate(DAG, N0))
+    return N->getOperand(1);
+
+  if (isAllInactivePredicate(N0))
+    return N->getOperand(2);
+
   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
   // into (OR (ASR lhs, N-1), 1), which requires less instructions for the
   // supported types.
@@ -19364,7 +19397,7 @@
   default:
     return SDValue();
   case ISD::VECREDUCE_OR:
-    if (isAllActivePredicate(Pg))
+    if (isAllActivePredicate(DAG, Pg))
       // The predicate can be 'Op' because
       // vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
       return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
@@ -19813,8 +19846,9 @@
   return Op;
 }
 
-bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
-  return ::isAllActivePredicate(N);
+bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
+                                                 SDValue N) const {
+  return ::isAllActivePredicate(DAG, N);
 }
 
 EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
diff --git a/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -verify-machineinstrs < %s | FileCheck %s
+
+; Test that the select returns %true, because the predicate is all active.
+define <vscale x 4 x i32> @select_ptrue_fold_all_active(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) {
+; CHECK-LABEL: select_ptrue_fold_all_active:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select returns %true, because the predicate is all active for vscale_range(2, 2)
+define <vscale x 4 x i32> @select_ptrue_fold_vl8(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_fold_vl8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 8)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select returns %false, because the predicate is all inactive.
+define <vscale x 16 x i8> @select_ptrue_fold_all_inactive(<vscale x 16 x i8> %true, <vscale x 16 x i8> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1> zeroinitializer)
+  %res = select <vscale x 16 x i1> %p, <vscale x 16 x i8> %true, <vscale x 16 x i8> %false
+  ret <vscale x 16 x i8> %res
+}
+
+; Test that the select returns %false when the all-inactive predicate is converted from svbool to nxv4i1.
+define <vscale x 4 x i32> @select_ptrue_fold_all_inactive_reinterpret(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive_reinterpret:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> zeroinitializer)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select remains, because predicate is not all active (only half lanes are set for vscale_range(2, 2))
+define <vscale x 4 x i32> @select_ptrue_no_fold_vl4(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_no_fold_vl4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 4)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1>)