diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def --- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def @@ -31,6 +31,7 @@ HANDLE_NODETYPE(VEC_SHL) HANDLE_NODETYPE(VEC_SHR_S) HANDLE_NODETYPE(VEC_SHR_U) +HANDLE_NODETYPE(NARROW_U) HANDLE_NODETYPE(EXTEND_LOW_S) HANDLE_NODETYPE(EXTEND_LOW_U) HANDLE_NODETYPE(EXTEND_HIGH_S) diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp --- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp @@ -176,6 +176,8 @@ setTargetDAGCombine(ISD::FP_ROUND); setTargetDAGCombine(ISD::CONCAT_VECTORS); + setTargetDAGCombine(ISD::TRUNCATE); + // Support saturating add for i8x16 and i16x8 for (auto Op : {ISD::SADDSAT, ISD::UADDSAT}) for (auto T : {MVT::v16i8, MVT::v8i16}) @@ -2609,6 +2611,114 @@ return DAG.getNode(Op, SDLoc(N), ResVT, Source); } +// Helper to extract VectorWidth bits from Vec, starting from IdxVal. +static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG, + const SDLoc &DL, unsigned VectorWidth) { + EVT VT = Vec.getValueType(); + EVT ElVT = VT.getVectorElementType(); + unsigned Factor = VT.getSizeInBits() / VectorWidth; + EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT, + VT.getVectorNumElements() / Factor); + + // Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR + unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits(); + assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2"); + + // This is the index of the first element of the VectorWidth-bit chunk + // we want. Since ElemsPerChunk is a power of 2 just need to clear bits. + IdxVal &= ~(ElemsPerChunk - 1); + + // If the input is a buildvector just emit a smaller one. + if (Vec.getOpcode() == ISD::BUILD_VECTOR) + return DAG.getBuildVector(ResultVT, DL, + Vec->ops().slice(IdxVal, ElemsPerChunk)); + + SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, DL); + return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx); +} + +// Helper to recursively truncate vector elements in half with NARROW_U. DstVT +// is the expected destination value type after recursion. In is the initial +// input. Note that the input should have enough leading zero bits to prevent +// NARROW_U from saturating results. +static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL, + SelectionDAG &DAG) { + EVT SrcVT = In.getValueType(); + + // No truncation required, we might get here due to recursive calls. + if (SrcVT == DstVT) + return In; + + unsigned SrcSizeInBits = SrcVT.getSizeInBits(); + unsigned NumElems = SrcVT.getVectorNumElements(); + if (!isPowerOf2_32(NumElems)) + return SDValue(); + assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation"); + assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation"); + + LLVMContext &Ctx = *DAG.getContext(); + EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2); + + // Narrow to the largest type possible: + // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u. 
+ EVT InVT = MVT::i16, OutVT = MVT::i8; + if (SrcVT.getScalarSizeInBits() > 16) { + InVT = MVT::i32; + OutVT = MVT::i16; + } + unsigned SubSizeInBits = SrcSizeInBits / 2; + InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits()); + OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits()); + + // Split lower/upper subvectors. + SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits); + SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits); + + // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors. + if (SrcVT.is256BitVector() && DstVT.is128BitVector()) { + Lo = DAG.getBitcast(InVT, Lo); + Hi = DAG.getBitcast(InVT, Hi); + SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi); + return DAG.getBitcast(DstVT, Res); + } + + // Recursively narrow lower/upper subvectors, concat result and narrow again. + EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2); + Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG); + Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG); + + PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems); + SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi); + return truncateVectorWithNARROW(DstVT, Res, DL, DAG); +} + +static SDValue performTruncateCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI) { + auto &DAG = DCI.DAG; + + SDValue In = N->getOperand(0); + EVT InVT = In.getValueType(); + if (!InVT.isSimple()) + return SDValue(); + + EVT OutVT = N->getValueType(0); + if (!OutVT.isVector()) + return SDValue(); + + EVT OutSVT = OutVT.getVectorElementType(); + EVT InSVT = InVT.getVectorElementType(); + // Currently only cover truncate to v16i8 or v8i16. + if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) && + (OutSVT == MVT::i8 || OutSVT == MVT::i16) && OutVT.is128BitVector())) + return SDValue(); + + SDLoc DL(N); + APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(), + OutVT.getScalarSizeInBits()); + In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT)); + return truncateVectorWithNARROW(OutVT, In, DL, DAG); +} + SDValue WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const { @@ -2625,5 +2735,7 @@ case ISD::FP_ROUND: case ISD::CONCAT_VECTORS: return performVectorTruncZeroCombine(N, DCI); + case ISD::TRUNCATE: + return performTruncateCombine(N, DCI); } } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -1278,6 +1278,14 @@ defm "" : SIMDNarrow; defm "" : SIMDNarrow; +// WebAssemblyISD::NARROW_U +def wasm_narrow_t : SDTypeProfile<1, 2, []>; +def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>; +def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))), + (NARROW_U_I8x16 $left, $right)>; +def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))), + (NARROW_U_I16x8 $left, $right)>; + // Bitcasts are nops // Matching bitcast t1 to t1 causes strange errors, so avoid repeating types foreach t1 = AllVecs in diff --git a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll --- a/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll +++ b/llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll @@ -532,7 +532,7 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) { ; CHECK-LABEL: stest_f16i16: ; CHECK: .functype 
stest_f16i16 (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128) -; CHECK-NEXT: .local v128, v128 +; CHECK-NEXT: .local v128, v128, v128 ; CHECK-NEXT: # %bb.0: # %entry ; CHECK-NEXT: local.get 5 ; CHECK-NEXT: call __truncsfhf2 @@ -578,6 +578,9 @@ ; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768 ; CHECK-NEXT: local.tee 9 ; CHECK-NEXT: i32x4.max_s +; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535 +; CHECK-NEXT: local.tee 10 +; CHECK-NEXT: v128.and ; CHECK-NEXT: local.get 4 ; CHECK-NEXT: i32.trunc_sat_f32_s ; CHECK-NEXT: i32x4.splat @@ -594,7 +597,9 @@ ; CHECK-NEXT: i32x4.min_s ; CHECK-NEXT: local.get 9 ; CHECK-NEXT: i32x4.max_s -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: local.get 10 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptosi <8 x half> %x to <8 x i32> @@ -666,7 +671,7 @@ ; CHECK-NEXT: i32x4.replace_lane 3 ; CHECK-NEXT: local.get 8 ; CHECK-NEXT: i32x4.min_u -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptoui <8 x half> %x to <8 x i32> @@ -741,7 +746,7 @@ ; CHECK-NEXT: i32x4.min_s ; CHECK-NEXT: local.get 9 ; CHECK-NEXT: i32x4.max_s -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptosi <8 x half> %x to <8 x i32> @@ -2106,7 +2111,7 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) { ; CHECK-LABEL: stest_f16i16_mm: ; CHECK: .functype stest_f16i16_mm (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128) -; CHECK-NEXT: .local v128, v128 +; CHECK-NEXT: .local v128, v128, v128 ; CHECK-NEXT: # %bb.0: # %entry ; CHECK-NEXT: local.get 5 ; CHECK-NEXT: call __truncsfhf2 @@ -2152,6 +2157,9 @@ ; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768 ; CHECK-NEXT: local.tee 9 ; CHECK-NEXT: i32x4.max_s +; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535 +; CHECK-NEXT: local.tee 10 +; CHECK-NEXT: v128.and ; CHECK-NEXT: local.get 4 ; CHECK-NEXT: i32.trunc_sat_f32_s ; CHECK-NEXT: i32x4.splat @@ -2168,7 +2176,9 @@ ; CHECK-NEXT: i32x4.min_s ; CHECK-NEXT: local.get 9 ; CHECK-NEXT: i32x4.max_s -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: local.get 10 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptosi <8 x half> %x to <8 x i32> @@ -2238,7 +2248,7 @@ ; CHECK-NEXT: i32x4.replace_lane 3 ; CHECK-NEXT: local.get 8 ; CHECK-NEXT: i32x4.min_u -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptoui <8 x half> %x to <8 x i32> @@ -2312,7 +2322,7 @@ ; CHECK-NEXT: i32x4.min_s ; CHECK-NEXT: local.get 9 ; CHECK-NEXT: i32x4.max_s -; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 +; CHECK-NEXT: i16x8.narrow_i32x4_u ; CHECK-NEXT: # fallthrough-return entry: %conv = fptosi <8 x half> %x to <8 x i32> diff --git a/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll b/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll @@ -0,0 +1,141 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -verify-machineinstrs -mattr=+simd128 | FileCheck %s + +; 
Test that a vector trunc correctly optimizes and lowers to narrow instructions + +target triple = "wasm32-unknown-unknown" + +define <16 x i8> @trunc16i64_16i8(<16 x i64> %a) { +; CHECK-LABEL: trunc16i64_16i8: +; CHECK: .functype trunc16i64_16i8 (v128, v128, v128, v128, v128, v128, v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: v128.const 255, 255 +; CHECK-NEXT: local.tee 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 3 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 5 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: local.get 6 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 7 +; CHECK-NEXT: local.get 8 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: i8x16.narrow_i16x8_u +; CHECK-NEXT: # fallthrough-return +entry: + %0 = trunc <16 x i64> %a to <16 x i8> + ret <16 x i8> %0 +} + +define <16 x i8> @trunc16i32_16i8(<16 x i32> %a) { +; CHECK-LABEL: trunc16i32_16i8: +; CHECK: .functype trunc16i32_16i8 (v128, v128, v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: v128.const 255, 255, 255, 255 +; CHECK-NEXT: local.tee 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 3 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: i8x16.narrow_i16x8_u +; CHECK-NEXT: # fallthrough-return +entry: + %0 = trunc <16 x i32> %a to <16 x i8> + ret <16 x i8> %0 +} + +define <16 x i8> @trunc16i16_16i8(<16 x i16> %a) { +; CHECK-LABEL: trunc16i16_16i8: +; CHECK: .functype trunc16i16_16i8 (v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: v128.const 255, 255, 255, 255, 255, 255, 255, 255 +; CHECK-NEXT: local.tee 2 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i8x16.narrow_i16x8_u +; CHECK-NEXT: # fallthrough-return +entry: + %0 = trunc <16 x i16> %a to <16 x i8> + ret <16 x i8> %0 +} + +define <8 x i16> @trunc8i64_8i16(<8 x i64> %a) { +; CHECK-LABEL: trunc8i64_8i16: +; CHECK: .functype trunc8i64_8i16 (v128, v128, v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: v128.const 65535, 65535 +; CHECK-NEXT: local.tee 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 3 +; CHECK-NEXT: local.get 4 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: # fallthrough-return +entry: + %0 = trunc <8 x i64> %a to <8 x i16> + ret <8 x i16> %0 +} + +define <8 x i16> 
@trunc8i32_8i16(<8 x i32> %a) { +; CHECK-LABEL: trunc8i32_8i16: +; CHECK: .functype trunc8i32_8i16 (v128, v128) -> (v128) +; CHECK-NEXT: .local v128 +; CHECK-NEXT: # %bb.0: # %entry +; CHECK-NEXT: local.get 0 +; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535 +; CHECK-NEXT: local.tee 2 +; CHECK-NEXT: v128.and +; CHECK-NEXT: local.get 1 +; CHECK-NEXT: local.get 2 +; CHECK-NEXT: v128.and +; CHECK-NEXT: i16x8.narrow_i32x4_u +; CHECK-NEXT: # fallthrough-return +entry: + %0 = trunc <8 x i32> %a to <8 x i16> + ret <8 x i16> %0 +}
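
A note on why performTruncateCombine masks the source with the destination's low-bit mask (the v128.and against the 255/65535 constants visible in the tests above): i8x16.narrow_i16x8_u and i16x8.narrow_i32x4_u saturate each lane to the unsigned range of the narrower type instead of wrapping, so without the mask a lane that does not fit would clamp to the maximum value rather than keep its low bits as trunc requires. Below is a minimal scalar sketch of one lane under that assumption; narrow_u_lane and the sample value are illustrative only, not part of the patch.

```cpp
#include <cstdint>
#include <cstdio>

// Scalar model of one lane of i16x8.narrow_i32x4_u: unsigned saturation to 16 bits.
static uint16_t narrow_u_lane(int32_t v) {
  if (v < 0)
    return 0;
  if (v > 0xFFFF)
    return 0xFFFF;
  return static_cast<uint16_t>(v);
}

int main() {
  int32_t lane = 0x12345;                          // does not fit in 16 bits
  uint16_t saturated = narrow_u_lane(lane);        // 0xFFFF: saturation, not truncation
  uint16_t masked = narrow_u_lane(lane & 0xFFFF);  // 0x2345: matches trunc semantics
  std::printf("saturated=0x%X masked=0x%X trunc=0x%X\n", saturated, masked,
              static_cast<uint16_t>(lane));
  return 0;
}
```

Clearing the high bits up front is what lets truncateVectorWithNARROW assume its input never saturates, as its function comment requires.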