diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -2287,10 +2287,43 @@
   // Only manually lower vector shifts
   assert(Op.getSimpleValueType().isVector());
 
-  auto ShiftVal = DAG.getSplatValue(Op.getOperand(1));
+  uint64_t LaneBits = Op.getValueType().getScalarSizeInBits();
+  auto ShiftVal = Op.getOperand(1);
+
+  // Try to skip bitmask operation since it is implied inside shift instruction
+  auto SkipImpliedMask = [](SDValue MaskOp, uint64_t MaskBits) {
+    if (MaskOp.getOpcode() != ISD::AND)
+      return MaskOp;
+    SDValue LHS = MaskOp.getOperand(0);
+    SDValue RHS = MaskOp.getOperand(1);
+    if (MaskOp.getValueType().isVector()) {
+      APInt MaskVal;
+      if (!ISD::isConstantSplatVector(RHS.getNode(), MaskVal))
+        std::swap(LHS, RHS);
+
+      if (ISD::isConstantSplatVector(RHS.getNode(), MaskVal) &&
+          MaskVal == MaskBits)
+        MaskOp = LHS;
+    } else {
+      if (!isa<ConstantSDNode>(RHS.getNode()))
+        std::swap(LHS, RHS);
+
+      auto ConstantRHS = dyn_cast<ConstantSDNode>(RHS.getNode());
+      if (ConstantRHS && ConstantRHS->getAPIntValue() == MaskBits)
+        MaskOp = LHS;
+    }
+
+    return MaskOp;
+  };
+
+  // Skip vector and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
+  ShiftVal = DAG.getSplatValue(ShiftVal);
   if (!ShiftVal)
     return unrollVectorShift(Op, DAG);
 
+  // Skip scalar and operation
+  ShiftVal = SkipImpliedMask(ShiftVal, LaneBits - 1);
   // Use anyext because none of the high bits can affect the shift
   ShiftVal = DAG.getAnyExtOrTrunc(ShiftVal, DL, MVT::i32);
 
diff --git a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
--- a/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
+++ b/llvm/test/CodeGen/WebAssembly/masked-shifts.ll
@@ -106,10 +106,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
 ; CHECK-NEXT: i8x16.shl
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -145,10 +141,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
 ; CHECK-NEXT: i8x16.shr_s
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -184,10 +176,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.splat
-; CHECK-NEXT: v128.const 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i8x16.extract_lane_u 0
 ; CHECK-NEXT: i8x16.shr_u
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <16 x i8> undef, i8 %x, i32 0
@@ -222,10 +210,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
 ; CHECK-NEXT: i16x8.shl
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -259,10 +243,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
 ; CHECK-NEXT: i16x8.shr_s
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -296,10 +276,6 @@
 ; CHECK-NEXT: # %bb.0:
 ; CHECK-NEXT: local.get 0
 ; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i16x8.splat
-; CHECK-NEXT: v128.const 15, 15, 15, 15, 15, 15, 15, 15
-; CHECK-NEXT: v128.and
-; CHECK-NEXT: i16x8.extract_lane_u 0
 ; CHECK-NEXT: i16x8.shr_u
 ; CHECK-NEXT: # fallthrough-return
   %t = insertelement <8 x i16> undef, i16 %x, i32 0
@@ -519,6 +495,22 @@
   ret <2 x i64> %a
 }
 
+define <2 x i64> @shl_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: shl_v2i64_i32_late:
+; CHECK: .functype shl_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shl
+; CHECK-NEXT: # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> zeroinitializer
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = shl <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @ashr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: ashr_v2i64_i32:
 ; CHECK: .functype ashr_v2i64_i32 (v128, i32) -> (v128)
@@ -535,6 +527,22 @@
   ret <2 x i64> %a
 }
 
+define <2 x i64> @ashr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: ashr_v2i64_i32_late:
+; CHECK: .functype ashr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shr_s
+; CHECK-NEXT: # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> zeroinitializer
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = ashr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
+
 define <2 x i64> @lshr_v2i64_i32(<2 x i64> %v, i32 %x) {
 ; CHECK-LABEL: lshr_v2i64_i32:
 ; CHECK: .functype lshr_v2i64_i32 (v128, i32) -> (v128)
@@ -551,3 +559,18 @@
   ret <2 x i64> %a
 }
 
+define <2 x i64> @lshr_v2i64_i32_late(<2 x i64> %v, i32 %x) {
+; CHECK-LABEL: lshr_v2i64_i32_late:
+; CHECK: .functype lshr_v2i64_i32_late (v128, i32) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i64x2.shr_u
+; CHECK-NEXT: # fallthrough-return
+  %z = zext i32 %x to i64
+  %t = insertelement <2 x i64> undef, i64 %z, i32 0
+  %s = shufflevector <2 x i64> %t, <2 x i64> undef, <2 x i32> zeroinitializer
+  %m = and <2 x i64> %s, <i64 63, i64 63>
+  %a = lshr <2 x i64> %v, %m
+  ret <2 x i64> %a
+}
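
Note (not part of the patch): skipping the AND in SkipImpliedMask is only sound because WebAssembly SIMD shift instructions take the shift count modulo the lane width, so masking the count with LaneBits - 1 can never change the result. Below is a minimal standalone C++ sketch of that equivalence for a single i8x16.shl lane; the function name and the test harness are illustrative only and do not come from the LLVM tree.

#include <cassert>
#include <cstdint>

// Scalar model of one lane of wasm's i8x16.shl: the instruction itself
// reduces the shift count modulo the 8-bit lane width.
static uint8_t i8x16_shl_lane(uint8_t Lane, uint32_t Count) {
  return static_cast<uint8_t>(Lane << (Count % 8));
}

int main() {
  for (uint32_t Count = 0; Count < 64; ++Count) {
    // Pre-masking the count with LaneBits - 1 (here 7) is redundant: the
    // implied modulo produces the same lane value, which is what lets the
    // lowering drop the explicit AND node feeding the shift amount.
    assert(i8x16_shl_lane(0x81, Count & 7) == i8x16_shl_lane(0x81, Count));
  }
  return 0;
}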