diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5147,5 +5147,5 @@
   const AArch64TargetLowering *TLI =
       static_cast<const AArch64TargetLowering *>(getTargetLowering());
 
-  return TLI->isAllActivePredicate(N);
+  return TLI->isAllActivePredicate(*CurDAG, N);
 }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -842,7 +842,7 @@
     return 128;
   }
 
-  bool isAllActivePredicate(SDValue N) const;
+  bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) const;
   EVT getPromotedVTForPredicate(EVT VT) const;
 
   EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15079,7 +15079,7 @@
                      Zero);
 }
 
-static bool isAllActivePredicate(SDValue N) {
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
   unsigned NumElts = N.getValueType().getVectorMinNumElements();
 
   // Look through cast.
@@ -15098,6 +15098,27 @@
       N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
     return N.getValueType().getVectorMinNumElements() >= NumElts;
 
+  // If we're compiling for a specific vector-length, we can check if the
+  // pattern's VL equals that of the scalable vector at runtime.
+  if (N.getOpcode() == AArch64ISD::PTRUE) {
+    const auto &Subtarget =
+        static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+    if (MaxSVESize && MinSVESize == MaxSVESize) {
+      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+      unsigned PatNumElts =
+          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+      // The pattern counts lanes of the PTRUE's own element type. That type
+      // can have more lanes than the one we started from when a
+      // REINTERPRET_CAST was looked through above, so compare against the
+      // PTRUE's lane count: e.g. "ptrue p0.b, vl8" viewed as nxv4i1 at
+      // vscale 2 leaves only the first two of its eight lanes active.
+      unsigned PtrueNumElts = N.getValueType().getVectorMinNumElements();
+      return PatNumElts == PtrueNumElts * VScale;
+    }
+  }
+
   return false;
 }
 
@@ -15114,7 +15135,7 @@
   SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
 
   // ISD way to specify an all active predicate.
-  if (isAllActivePredicate(Pg)) {
+  if (isAllActivePredicate(DAG, Pg)) {
     if (UnpredOp)
       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
 
@@ -16785,6 +16806,10 @@
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
+  // Selecting with an all-active predicate always yields the true operand.
+  if (isAllActivePredicate(DAG, N0))
+    return N->getOperand(1);
+
   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
   // into (OR (ASR lhs, N-1), 1), which requires less instructions for the
   // supported types.
@@ -19800,8 +19825,9 @@
   return Op;
 }
 
-bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
-  return ::isAllActivePredicate(N);
+bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
+                                                 SDValue N) const {
+  return ::isAllActivePredicate(DAG, N);
 }
 
 EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
diff --git a/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
@@ -0,0 +1,24 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -verify-machineinstrs < %s | FileCheck %s
+
+; Test that the select returns %true, because the predicate is all active.
+define <vscale x 4 x i32> @select_ptrue_fold_all_active(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_active:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; With vscale_range(2, 2) an nxv4i1 has exactly 8 lanes, so "ptrue vl8" is all active and the select returns %true.
+define <vscale x 4 x i32> @select_ptrue_fold_vl8(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_fold_vl8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 8)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)