Index: llvm/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1938,8 +1938,10 @@
   /// immediately after an "SRA X, 2", we know that the top 3 bits are all equal
   /// to each other, so we return 3. The DemandedElts argument allows
   /// us to only collect the minimum sign bits of the requested vector elements.
-  /// Targets can implement the ComputeNumSignBitsForTarget method in the
-  /// TargetLowering class to allow target nodes to be understood.
+  /// For scalable vectors, the DemandedElts must be getVectorMinNumElements in
+  /// size and all lanes must be demanded. Targets can implement the
+  /// ComputeNumSignBitsForTarget method in the TargetLowering class to allow
+  /// target nodes to be understood.
   unsigned ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
                               unsigned Depth = 0) const;
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -3898,12 +3898,8 @@
 unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // TODO: Assume we don't know anything for now.
-  if (VT.isScalableVector())
-    return 1;
-
   APInt DemandedElts = VT.isVector()
-                           ? APInt::getAllOnes(VT.getVectorNumElements())
+                           ? APInt::getAllOnes(VT.getVectorMinNumElements())
                            : APInt(1, 1);
   return ComputeNumSignBits(Op, DemandedElts, Depth);
 }
@@ -3917,6 +3913,9 @@
   unsigned Tmp, Tmp2;
   unsigned FirstAnswer = 1;
 
+  assert((!VT.isScalableVector() || DemandedElts.isAllOnes()) &&
+         "Expected all demanded lanes from scalable vectors");
+
   if (auto *C = dyn_cast<ConstantSDNode>(Op)) {
     const APInt &Val = C->getAPIntValue();
     return Val.getNumSignBits();
   }
@@ -3925,9 +3924,6 @@
   if (Depth >= MaxRecursionDepth)
     return 1;  // Limit search depth.
 
-  if (!DemandedElts || VT.isScalableVector())
-    return 1;  // No demanded elts, better to assume we don't know anything.
-
   unsigned Opcode = Op.getOpcode();
   switch (Opcode) {
   default: break;
@@ -3992,6 +3988,9 @@
   }
 
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return 1;
+
     SDValue N0 = Op.getOperand(0);
     EVT SrcVT = N0.getValueType();
     unsigned SrcBits = SrcVT.getScalarSizeInBits();
@@ -4049,9 +4048,12 @@
     Tmp2 = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
     return std::max(Tmp, Tmp2);
   case ISD::SIGN_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return 1;
+
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
-    APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorNumElements());
+    APInt DemandedSrcElts = DemandedElts.zext(SrcVT.getVectorMinNumElements());
     Tmp = VTBits - SrcVT.getScalarSizeInBits();
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth+1) + Tmp;
   }
@@ -4274,6 +4276,9 @@
     return std::max(std::min(KnownSign - rIndex * BitWidth, BitWidth), 0);
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return 1;
+
     // If we know the element index, split the demand between the
     // source vector and the inserted element, otherwise assume we need
     // the original demanded vector elements and the value.
@@ -4331,6 +4336,9 @@
     return ComputeNumSignBits(InVec, DemandedSrcElts, Depth + 1);
   }
   case ISD::EXTRACT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return 1;
+
     // Offset the demanded elts by the subvector index.
     SDValue Src = Op.getOperand(0);
     // Bail until we can represent demanded elements for scalable vectors.
@@ -4342,6 +4350,9 @@
     return ComputeNumSignBits(Src, DemandedSrcElts, Depth + 1);
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return 1;
+
     // Determine the minimum number of sign bits across all demanded
     // elts of the input vectors. Early out if the result is already 1.
     Tmp = std::numeric_limits<unsigned>::max();
@@ -4360,6 +4371,9 @@
     return Tmp;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return 1;
+
     // Demand any elements from the subvector and the remainder from the src its
     // inserted into.
     SDValue Src = Op.getOperand(0);
Index: llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
+++ llvm/test/CodeGen/AArch64/sve-masked-gather-legalize.ll
@@ -95,7 +95,7 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    ptrue p1.d
 ; CHECK-NEXT:    sxth z0.d, p1/m, z0.d
-; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, sxtw #2]
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [x0, z0.d, lsl #2]
 ; CHECK-NEXT:    ret
   %ptrs = getelementptr float, float* %base, <vscale x 2 x i16> %indices
   %data = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32(<vscale x 2 x float*> %ptrs, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
Index: llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
+++ llvm/test/CodeGen/AArch64/sve-smulo-sdnode.ll
@@ -9,15 +9,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtb z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtb z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i8>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i8(<vscale x 2 x i8> %x, <vscale x 2 x i8> %y)
@@ -35,15 +30,10 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxtb z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxtb z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i8>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i8(<vscale x 4 x i8> %x, <vscale x 4 x i8> %y)
@@ -61,15 +51,10 @@
 ; CHECK-NEXT:    ptrue p0.h
 ; CHECK-NEXT:    sxtb z1.h, p0/m, z1.h
 ; CHECK-NEXT:    sxtb z0.h, p0/m, z0.h
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.h, p0/m, z2.h, z1.h
 ; CHECK-NEXT:    mul z0.h, p0/m, z0.h, z1.h
-; CHECK-NEXT:    asr z1.h, z0.h, #15
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtb z3.h, p0/m, z0.h
-; CHECK-NEXT:    cmpne p1.h, p0/z, z2.h, z1.h
-; CHECK-NEXT:    cmpne p0.h, p0/z, z3.h, z0.h
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtb z1.h, p0/m, z0.h
+; CHECK-NEXT:    cmpne p0.h, p0/z, z1.h, z0.h
 ; CHECK-NEXT:    mov z0.h, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 8 x i8>, <vscale x 8 x i1> } @llvm.smul.with.overflow.nxv8i8(<vscale x 8 x i8> %x, <vscale x 8 x i8> %y)
@@ -175,15 +160,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxth z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxth z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i16>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i16(<vscale x 2 x i16> %x, <vscale x 2 x i16> %y)
@@ -201,15 +181,10 @@
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    sxth z1.s, p0/m, z1.s
 ; CHECK-NEXT:    sxth z0.s, p0/m, z0.s
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.s, p0/m, z2.s, z1.s
 ; CHECK-NEXT:    mul z0.s, p0/m, z0.s, z1.s
-; CHECK-NEXT:    asr z1.s, z0.s, #31
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxth z3.s, p0/m, z0.s
-; CHECK-NEXT:    cmpne p1.s, p0/z, z2.s, z1.s
-; CHECK-NEXT:    cmpne p0.s, p0/z, z3.s, z0.s
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxth z1.s, p0/m, z0.s
+; CHECK-NEXT:    cmpne p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    mov z0.s, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 4 x i16>, <vscale x 4 x i1> } @llvm.smul.with.overflow.nxv4i16(<vscale x 4 x i16> %x, <vscale x 4 x i16> %y)
@@ -315,15 +290,10 @@
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:    sxtw z1.d, p0/m, z1.d
 ; CHECK-NEXT:    sxtw z0.d, p0/m, z0.d
-; CHECK-NEXT:    movprfx z2, z0
-; CHECK-NEXT:    smulh z2.d, p0/m, z2.d, z1.d
 ; CHECK-NEXT:    mul z0.d, p0/m, z0.d, z1.d
-; CHECK-NEXT:    asr z1.d, z0.d, #63
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    sxtw z3.d, p0/m, z0.d
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, z1.d
-; CHECK-NEXT:    cmpne p0.d, p0/z, z3.d, z0.d
-; CHECK-NEXT:    sel p0.b, p0, p0.b, p1.b
+; CHECK-NEXT:    movprfx z1, z0
+; CHECK-NEXT:    sxtw z1.d, p0/m, z0.d
+; CHECK-NEXT:    cmpne p0.d, p0/z, z1.d, z0.d
 ; CHECK-NEXT:    mov z0.d, p0/m, #0 // =0x0
 ; CHECK-NEXT:    ret
   %a = call { <vscale x 2 x i32>, <vscale x 2 x i1> } @llvm.smul.with.overflow.nxv2i32(<vscale x 2 x i32> %x, <vscale x 2 x i32> %y)
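
Usage sketch (not part of the patch): the snippet below shows the calling convention the updated entry points expect, i.e. an all-ones DemandedElts of getVectorMinNumElements() width for scalable types, matching the new assertion. The helper name isSignExtendedFromWidth and its DAG/Val/FromBits parameters are illustrative assumptions; only SelectionDAG::ComputeNumSignBits, EVT::getVectorMinNumElements and APInt::getAllOnes come from the tree.

#include "llvm/ADT/APInt.h"
#include "llvm/CodeGen/SelectionDAG.h"
using namespace llvm;

// Returns true if every lane of Val is known to be sign-extended from
// FromBits bits. A value sign-extended from FromBits in a W-bit lane has
// at least W - FromBits + 1 known sign bits.
static bool isSignExtendedFromWidth(SelectionDAG &DAG, SDValue Val,
                                    unsigned FromBits) {
  EVT VT = Val.getValueType();
  // For vectors (fixed or scalable) demand all lanes; for scalable types the
  // mask width is the minimum element count, as required by the new assert.
  unsigned NumElts = VT.isVector() ? VT.getVectorMinNumElements() : 1;
  APInt DemandedElts = APInt::getAllOnes(NumElts);
  unsigned SignBits = DAG.ComputeNumSignBits(Val, DemandedElts);
  return SignBits >= VT.getScalarSizeInBits() - FromBits + 1;
}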