diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3962,11 +3962,10 @@
 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
   EVT VT = Op.getValueType();
-  // TODO: Assume we don't know anything for now.
-  if (VT.isScalableVector())
-    return 1;
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes. This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return ComputeNumSignBits(Op, DemandedElts, Depth);
@@ -3989,7 +3988,7 @@
   if (Depth >= MaxRecursionDepth)
     return 1; // Limit search depth.
 
-  if (!DemandedElts || VT.isScalableVector())
+  if (!DemandedElts)
     return 1; // No demanded elts, better to assume we don't know anything.
 
   unsigned Opcode = Op.getOpcode();
@@ -4004,7 +4003,16 @@
   case ISD::MERGE_VALUES:
     return ComputeNumSignBits(Op.getOperand(Op.getResNo()), DemandedElts,
                               Depth + 1);
+  case ISD::SPLAT_VECTOR: {
+    // Check if the sign bits of source go down as far as the truncated value.
+    unsigned NumSrcBits = Op.getOperand(0).getValueSizeInBits();
+    unsigned NumSrcSignBits = ComputeNumSignBits(Op.getOperand(0), Depth + 1);
+    if (NumSrcSignBits > (NumSrcBits - VTBits))
+      return NumSrcSignBits - (NumSrcBits - VTBits);
+    break;
+  }
   case ISD::BUILD_VECTOR:
+    assert(!VT.isScalableVector());
     Tmp = VTBits;
     for (unsigned i = 0, e = Op.getNumOperands(); (i < e) && (Tmp > 1); ++i) {
       if (!DemandedElts[i])
@@ -4049,6 +4057,8 @@
   }
 
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue N0 = Op.getOperand(0);
     EVT SrcVT = N0.getValueType();
     unsigned SrcBits = SrcVT.getScalarSizeInBits();
@@ -4106,6 +4116,8 @@
     Tmp2 = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
     return std::max(Tmp, Tmp2);
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
@@ -4323,6 +4335,8 @@
     break;
   }
   case ISD::EXTRACT_ELEMENT: {
+    if (VT.isScalableVector())
+      return 1;
     const int KnownSign = ComputeNumSignBits(Op.getOperand(0), Depth+1);
     const int BitWidth = Op.getValueSizeInBits();
     const int Items = Op.getOperand(0).getValueSizeInBits() / BitWidth;
@@ -4336,6 +4350,8 @@
     return std::clamp(KnownSign - rIndex * BitWidth, 0, BitWidth);
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     // If we know the element index, split the demand between the
     // source vector and the inserted element, otherwise assume we need
     // the original demanded vector elements and the value.
@@ -4366,6 +4382,8 @@
     return Tmp;
   }
   case ISD::EXTRACT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
     SDValue InVec = Op.getOperand(0);
     SDValue EltNo = Op.getOperand(1);
     EVT VecVT = InVec.getValueType();
@@ -4404,6 +4422,8 @@
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return 1;
    // Determine the minimum number of sign bits across all demanded
    // elts of the input vectors. Early out if the result is already 1.
     Tmp = std::numeric_limits<unsigned>::max();
@@ -4422,6 +4442,8 @@
     return Tmp;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return 1;
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -4492,7 +4514,7 @@
     // We only need to handle vectors - computeKnownBits should handle
     // scalar cases.
     Type *CstTy = Cst->getType();
-    if (CstTy->isVectorTy() &&
+    if (CstTy->isVectorTy() && !VT.isScalableVector() &&
         (NumElts * VTBits) == CstTy->getPrimitiveSizeInBits() &&
         VTBits == CstTy->getScalarSizeInBits()) {
       Tmp = VTBits;
@@ -4527,6 +4549,10 @@
       Opcode == ISD::INTRINSIC_WO_CHAIN ||
       Opcode == ISD::INTRINSIC_W_CHAIN ||
       Opcode == ISD::INTRINSIC_VOID) {
+    // TODO: This can probably be removed once target code is audited. This
+    // is here purely to reduce patch size and review complexity.
+    if (VT.isScalableVector())
+      return 1;
     unsigned NumBits =
         TLI->ComputeNumSignBitsForTargetNode(Op, DemandedElts, *this, Depth);
     if (NumBits > 1)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -95,7 +95,7 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.d
 ; CHECK-NEXT:    sxth z0.d, p1/m, z0.d
-; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, sxtw #2]
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, lsl #2]
 ; CHECK-NEXT:    ret
   %ptrs = getelementptr float, float* %base, <vscale x 2 x i16> %indices
   %data = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32(<vscale x 2 x float*> %ptrs, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
diff --git a/llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll b/llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
--- a/llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
+++ b/llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
@@ -9,15 +9,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtb z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtb z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
@@ -35,15 +30,10 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxtb z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxtb z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
@@ -61,15 +51,10 @@
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z0.h, p0/m, z0.h
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
-; CHECK-NEXT:    asr z1.h, z0.h, #15
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.h, p0/m, z0.h
-; CHECK-NEXT:    cmpne p1.h, p0/z, z2.h, z1.h
-; CHECK-NEXT:    cmpne p0.h, p0/z, z3.h, z0.h
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.h, p0/m, z0.h
+; CHECK-NEXT:    cmpne p0.h, p0/z, z1.h, z0.h
 ; CHECK-NEXT:    mov z0.h, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.smul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
@@ -175,15 +160,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxth z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxth z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
@@ -201,15 +181,10 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxth z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
@@ -315,15 +290,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtw z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtw z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtw z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtw z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vdiv-vp.ll
@@ -12,11 +12,8 @@
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vdiv.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vdiv.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
diff --git a/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmax-vp.ll
@@ -12,11 +12,8 @@
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vmax.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vmax.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
diff --git a/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vmin-vp.ll
@@ -12,11 +12,8 @@
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vmin.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vmin.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
diff --git a/llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vrem-vp.ll
@@ -12,11 +12,8 @@
 ; CHECK-NEXT:    vsetvli a2, zero, e8, m1, ta, ma
 ; CHECK-NEXT:    vadd.vv v8, v8, v8
 ; CHECK-NEXT:    vsra.vi v8, v8, 1
-; CHECK-NEXT:    vmv.v.x v9, a0
-; CHECK-NEXT:    vadd.vv v9, v9, v9
-; CHECK-NEXT:    vsra.vi v9, v9, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, ma
-; CHECK-NEXT:    vrem.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    vrem.vx v8, v8, a0, v0.t
 ; CHECK-NEXT:    ret
   %elt.head = insertelement <vscale x 8 x i7> poison, i7 %b, i32 0
   %vb = shufflevector <vscale x 8 x i7> %elt.head, <vscale x 8 x i7> poison, <vscale x 8 x i32> zeroinitializer
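
For reference, the arithmetic in the new ISD::SPLAT_VECTOR case is what appears to drive the RISC-V test changes above. Below is a minimal standalone C++ sketch of that rule only; it is not the LLVM API, and the names splatNumSignBits, EltBits, SrcBits and SrcSignBits are hypothetical, chosen purely for illustration. The idea: a splat implicitly truncates its scalar operand to the vector element width, so any sign bits of the source that survive the truncation hold for every lane, no matter how many lanes the scalable vector has at run time.

#include <cassert>

// Sketch: sign-bit count for a splat whose scalar operand may be wider than
// the vector element type (the splat implicitly truncates it).
unsigned splatNumSignBits(unsigned EltBits,       // vector element width
                          unsigned SrcBits,       // width of the splatted scalar
                          unsigned SrcSignBits) { // known sign bits of the scalar
  assert(SrcBits >= EltBits && "splat operand is at least element-sized");
  // Sign bits that survive the implicit truncation are valid for every lane.
  if (SrcSignBits > SrcBits - EltBits)
    return SrcSignBits - (SrcBits - EltBits);
  return 1; // Nothing survives; only the sign bit itself is known.
}

As a worked example, splatNumSignBits(8, 64, 58) == 2: an i7 value sign-extended into a 64-bit register has at least 58 sign bits, so its splat into e8 lanes is already known to carry 2 sign bits. That is exactly what the removed vadd.vv/vsra.vi pair on v9 was recomputing, which seems to be why the masked operations can now use the vdiv.vx, vmax.vx, vmin.vx and vrem.vx scalar-operand forms directly on a0.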