diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22297,6 +22297,12 @@
     return DAG.getNode(ISD::INSERT_SUBVECTOR, SDLoc(N), VT, N0.getOperand(0),
                        N1, N2);
 
+  // Combine INSERT_SUBVECTOR(UNDEF, SPLAT(X)) -> SPLAT(X)
+  if (N0.isUndef() && N1.getOpcode() == ISD::SPLAT_VECTOR) {
+    SDValue Scalar = DAG.getSplatValue(N1, true);
+    return DAG.getSplatVector(VT, SDLoc(N), Scalar);
+  }
+
   // Eliminate an intermediate insert into an undef vector:
   // insert_subvector undef, (insert_subvector undef, X, 0), N2 -->
   // insert_subvector undef, X, N2
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10990,6 +10990,9 @@
   return SDValue();
 }
 
+static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
+                                                EVT VT);
+
 SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
                                                      SelectionDAG &DAG) const {
   assert(Op.getValueType().isScalableVector() &&
@@ -11041,10 +11044,7 @@
     if (Vec0.isUndef())
       return Op;
 
-    Optional<unsigned> PredPattern =
-        getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
-    auto PredTy = VT.changeVectorElementType(MVT::i1);
-    SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
+    SDValue PTrue = getPredicateForFixedLengthVector(DAG, DL, InVT);
     SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
     return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
   }
@@ -13970,12 +13970,112 @@
   return SDValue();
 }
 
+/// Recursively propagate the scalable predicate type to the leaves.
+/// \p PassthruVal is the value for lanes that are disabled under the active
+/// vector length (i.e. when sizeof(SVE register) > sizeof(fixed-width vector)).
+/// \p Gaps tells whether it is allowed to have gaps in the result
+/// vector that weren't there when the vectors were all fixed-width.
+/// Gaps can be inserted when concatenating two scalable vectors that
+/// were previously fixed-width vectors, if the active vector length
+/// does not match the fixed-width vector length.
+static SDValue propagatePredicateTy(TargetLowering::DAGCombinerInfo &DCI,
+                                    SDValue V, bool PassthruVal, bool Gaps) {
+  SelectionDAG &DAG = DCI.DAG;
+  SDLoc DL(V.getNode());
+
+  const auto &Subtarget =
+      static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+  if (!Subtarget.hasSVE() || !V->hasOneUse())
+    return SDValue();
+
+  // End of recursion.
+  if (V.getValueType().isScalableVector() || isa(V.getNode()))
+    return V;
+
+  switch (V.getOpcode()) {
+  case ISD::SETCC: {
+    SDValue CmpOp0 = V.getOperand(0);
+    SDValue CmpOp1 = V.getOperand(1);
+    SDValue CmpPred = V.getOperand(2);
+    EVT FixedVT = V.getOperand(0).getValueType();
+    EVT ScalableVT = getContainerForFixedLengthVector(DAG, FixedVT);
+    EVT PredVT = ScalableVT.changeVectorElementType(MVT::i1);
+
+    SDValue ZeroIdx = DAG.getConstant(0, DL, MVT::i64);
+    SDValue Passthru = ScalableVT.isInteger()
+                           ? DAG.getConstant(PassthruVal, DL, ScalableVT)
+                           : DAG.getConstantFP(PassthruVal, DL, ScalableVT);
+    CmpOp0 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVT, Passthru,
+                         CmpOp0, ZeroIdx);
+    CmpOp1 = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ScalableVT, Passthru,
+                         CmpOp1, ZeroIdx);
+    return DAG.getNode(ISD::SETCC, DL, PredVT, {CmpOp0, CmpOp1, CmpPred});
+  }
+  case ISD::OR:
+  case ISD::XOR:
+    if (SDValue Op0 =
+            propagatePredicateTy(DCI, V.getOperand(0), PassthruVal, Gaps))
+      if (SDValue Op1 =
+              propagatePredicateTy(DCI, V.getOperand(1), PassthruVal, Gaps))
+        return DAG.getNode(V.getOpcode(), DL, Op0.getValueType(), Op0, Op1);
+    break;
+  case ISD::TRUNCATE:
+    if (SDValue Op =
+            propagatePredicateTy(DCI, V.getOperand(0), PassthruVal, Gaps))
+      return Op;
+    break;
+  case ISD::CONCAT_VECTORS: {
+    if (!Gaps && Subtarget.getMinSVEVectorSizeInBits() !=
+                     Subtarget.getMaxSVEVectorSizeInBits())
+      return SDValue();
+
+    SmallVector<SDValue> ConcatOps;
+    for (unsigned I = 0; I < V.getNumOperands(); ++I) {
+      SDValue Op =
+          propagatePredicateTy(DCI, V.getOperand(I), PassthruVal, Gaps);
+      if (!Op)
+        return SDValue();
+      ConcatOps.push_back(Op);
+    }
+
+    // Now generate a new concat_vectors with the new scalable types.
+    ElementCount NewConcatEC =
+        ConcatOps[0].getValueType().getVectorElementCount() * ConcatOps.size();
+    EVT NewConcatVT = EVT::getVectorVT(
+        *DAG.getContext(), ConcatOps[0].getValueType().getVectorElementType(),
+        NewConcatEC);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, NewConcatVT, ConcatOps);
+  }
+  case ISD::VECREDUCE_OR: {
+    if (SDValue NewOp =
+            propagatePredicateTy(DCI, V.getOperand(0), PassthruVal, Gaps))
+      return DAG.getNode(ISD::VECREDUCE_OR, DL, V.getValueType(), NewOp);
+    break;
+  }
+  default:
+    break;
+  }
+
+  // Can't propagate scalable predicate any further.
+  return SDValue();
+}
+
 static SDValue performANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   SelectionDAG &DAG = DCI.DAG;
   SDValue LHS = N->getOperand(0);
   EVT VT = N->getValueType(0);
-  if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
+  if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
+    return SDValue();
+
+  if (auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1))) {
+    if (C->getZExtValue() == 1 &&
+        N->getOperand(0)->getOpcode() == ISD::VECREDUCE_OR)
+      return propagatePredicateTy(DCI, N->getOperand(0), /*PassthruVal=*/false,
+                                  /*Gaps=*/true);
+  }
+
+  if (!VT.isVector())
     return SDValue();
 
   if (VT.isScalableVector())
@@ -16739,6 +16839,9 @@
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
+  if (isAllActivePredicate(N0))
+    return N->getOperand(1);
+
   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
   // into (OR (ASR lhs, N-1), 1), which requires less instructions for the
   // supported types.
@@ -18806,6 +18909,7 @@
     break;
   case MVT::i16:
   case MVT::f16:
+  case MVT::bf16:
     MaskVT = MVT::nxv8i1;
     break;
   case MVT::i32:
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-ptest.ll
@@ -0,0 +1,97 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
+
+define i1 @ptest_v16i1_256bit_min_sve(float* %a, float * %b) vscale_range(2, 0) {
+; CHECK-LABEL: ptest_v16i1_256bit_min_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov x8, #8
+; CHECK-NEXT:    ptrue p0.s, vl8
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    ptrue p1.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x0]
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z2.s
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
+; CHECK-NEXT:    fcmeq p0.s, p1/z, z0.s, #0.0
+; CHECK-NEXT:    fcmeq p2.s, p1/z, z1.s, #0.0
+; CHECK-NEXT:    not p0.b, p1/z, p0.b
+; CHECK-NEXT:    not p1.b, p1/z, p2.b
+; CHECK-NEXT:    uzp1 p0.h, p1.h, p0.h
+; CHECK-NEXT:    ptrue p1.h
+; CHECK-NEXT:    ptest p1, p0.b
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %v0 = bitcast float* %a to <16 x float>*
+  %v1 = load <16 x float>, <16 x float>* %v0, align 4
+  %v2 = fcmp une <16 x float> %v1, zeroinitializer
+  %v3 = call i1 @llvm.vector.reduce.or.i1.v16i1 (<16 x i1> %v2)
+  ret i1 %v3
+}
+
+define i1 @ptest_v16i1_512bit_min_sve(float* %a, float * %b) vscale_range(4, 0) {
+; CHECK-LABEL: ptest_v16i1_512bit_min_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl16
+; CHECK-NEXT:    mov z1.s, #0 // =0x0
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    not p1.b, p0/z, p1.b
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %v0 = bitcast float* %a to <16 x float>*
+  %v1 = load <16 x float>, <16 x float>* %v0, align 4
+  %v2 = fcmp une <16 x float> %v1, zeroinitializer
+  %v3 = call i1 @llvm.vector.reduce.or.i1.v16i1 (<16 x i1> %v2)
+  ret i1 %v3
+}
+
+define i1 @ptest_v16i1_512bit_sve(float* %a, float * %b) vscale_range(4, 4) {
+; CHECK-LABEL: ptest_v16i1_512bit_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    not p1.b, p0/z, p1.b
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %v0 = bitcast float* %a to <16 x float>*
+  %v1 = load <16 x float>, <16 x float>* %v0, align 4
+  %v2 = fcmp une <16 x float> %v1, zeroinitializer
+  %v3 = call i1 @llvm.vector.reduce.or.i1.v16i1 (<16 x i1> %v2)
+  ret i1 %v3
+}
+
+define i1 @ptest_or_v16i1_512bit_min_sve(float* %a, float * %b) vscale_range(4, 0) {
+; CHECK-LABEL: ptest_or_v16i1_512bit_min_sve:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl16
+; CHECK-NEXT:    mov z2.s, #0 // =0x0
+; CHECK-NEXT:    ld1w { z0.s }, p0/z, [x0]
+; CHECK-NEXT:    ld1w { z1.s }, p0/z, [x1]
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z2.s
+; CHECK-NEXT:    sel z1.s, p0, z1.s, z2.s
+; CHECK-NEXT:    ptrue p0.s
+; CHECK-NEXT:    fcmeq p1.s, p0/z, z0.s, #0.0
+; CHECK-NEXT:    fcmeq p2.s, p0/z, z1.s, #0.0
+; CHECK-NEXT:    not p1.b, p0/z, p1.b
+; CHECK-NEXT:    not p2.b, p0/z, p2.b
+; CHECK-NEXT:    orr p1.b, p0/z, p1.b, p2.b
+; CHECK-NEXT:    ptest p0, p1.b
+; CHECK-NEXT:    cset w0, ne
+; CHECK-NEXT:    ret
+  %v0 = bitcast float* %a to <16 x float>*
+  %v1 = load <16 x float>, <16 x float>* %v0, align 4
+  %v2 = fcmp une <16 x float> %v1, zeroinitializer
+  %v3 = bitcast float* %b to <16 x float>*
+  %v4 = load <16 x float>, <16 x float>* %v3, align 4
+  %v5 = fcmp une <16 x float> %v4, zeroinitializer
+  %v6 = or <16 x i1> %v2, %v5
+  %v7 = call i1 @llvm.vector.reduce.or.i1.v16i1 (<16 x i1> %v6)
+  ret i1 %v7
+}
+
+declare i1 @llvm.vector.reduce.or.i1.v16i1(<16 x i1>)
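
Not part of the patch: for anyone who wants a smaller local reproducer of the pattern the new performANDCombine path targets (a fixed-width fcmp feeding llvm.vector.reduce.or inside a vscale_range function), a sketch in the same style as the tests above is shown below. The function name and the exact 256-bit vector length are illustrative assumptions rather than anything taken from the patch, and no CHECK lines are asserted; the IR can be fed to the same llc -mtriple=aarch64 -mattr=+sve invocation used in the RUN line.

; Illustrative reproducer only, not part of the patch. Mirrors
; @ptest_v16i1_512bit_sve above, but at an exact 256-bit vector length.
define i1 @ptest_v8i1_256bit_sve_sketch(float* %a) vscale_range(2, 2) {
  %v0 = bitcast float* %a to <8 x float>*
  %v1 = load <8 x float>, <8 x float>* %v0, align 4
  %v2 = fcmp une <8 x float> %v1, zeroinitializer
  %v3 = call i1 @llvm.vector.reduce.or.i1.v8i1(<8 x i1> %v2)
  ret i1 %v3
}

declare i1 @llvm.vector.reduce.or.i1.v8i1(<8 x i1>)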