diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -31,6 +31,7 @@
 HANDLE_NODETYPE(VEC_SHL)
 HANDLE_NODETYPE(VEC_SHR_S)
 HANDLE_NODETYPE(VEC_SHR_U)
+HANDLE_NODETYPE(NARROW_U)
 HANDLE_NODETYPE(EXTEND_LOW_S)
 HANDLE_NODETYPE(EXTEND_LOW_U)
 HANDLE_NODETYPE(EXTEND_HIGH_S)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -174,6 +174,8 @@
     setTargetDAGCombine(ISD::FP_ROUND);
     setTargetDAGCombine(ISD::CONCAT_VECTORS);
 
+    setTargetDAGCombine(ISD::TRUNCATE);
+
     // Support saturating add for i8x16 and i16x8
     for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
       for (auto T : {MVT::v16i8, MVT::v8i16})
@@ -2434,6 +2436,136 @@
   return DAG.getNode(Op, SDLoc(N), ResVT, Source);
 }
 
+/// Extract the \p VectorWidth-bit chunk of \p Vec that contains element
+/// \p IdxVal. If \p Vec is a BUILD_VECTOR, a smaller BUILD_VECTOR built from
+/// that chunk's operands is returned instead of an EXTRACT_SUBVECTOR.
+static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
+                                const SDLoc &DL, unsigned VectorWidth) {
+  EVT VT = Vec.getValueType();
+  EVT ElVT = VT.getVectorElementType();
+  assert(VectorWidth <= VT.getSizeInBits() && "Chunk wider than the vector");
+  unsigned Factor = VT.getSizeInBits() / VectorWidth;
+  EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
+                                  VT.getVectorNumElements() / Factor);
+
+  // Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR.
+  unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits();
+  assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
+
+  // This is the index of the first element of the VectorWidth-bit chunk
+  // we want. Since ElemsPerChunk is a power of 2 just need to clear bits.
+  IdxVal &= ~(ElemsPerChunk - 1);
+
+  // If the input is a buildvector just emit a smaller one.
+  if (Vec.getOpcode() == ISD::BUILD_VECTOR)
+    return DAG.getBuildVector(ResultVT, DL,
+                              Vec->ops().slice(IdxVal, ElemsPerChunk));
+
+  SDValue VecIdx = DAG.getVectorIdxConstant(IdxVal, DL);
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx);
+}
+
+/// Truncate the vector \p In to \p DstVT using a tree of WebAssembly narrow
+/// nodes of kind \p Opcode (e.g. WebAssemblyISD::NARROW_U).
+///
+/// Only a 256-bit -> 128-bit truncate is emitted directly, as one narrow of
+/// the two 128-bit halves; wider sources are halved recursively until they
+/// reach that shape. Returns an empty SDValue if \p In cannot be reduced that
+/// way. The narrow nodes saturate, so the result only matches a plain
+/// truncate if the caller has already cleared the bits being dropped (see
+/// performTruncateCombine) and the destination elements are at most 16 bits.
+static SDValue truncateVectorWithNARROW(unsigned Opcode, EVT DstVT, SDValue In,
+                                        const SDLoc &DL, SelectionDAG &DAG) {
+  EVT SrcVT = In.getValueType();
+
+  // No truncation required, we might get here due to recursive calls.
+  if (SrcVT == DstVT)
+    return In;
+
+  unsigned SrcSizeInBits = SrcVT.getSizeInBits();
+  unsigned NumElems = SrcVT.getVectorNumElements();
+  if (!isPowerOf2_32(NumElems))
+    return SDValue();
+  assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
+  assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
+
+  // Only the 256 -> 128 bit case is narrowed directly. Any other shape is
+  // split in half and recursed on, which only reaches the direct case if the
+  // halves are at least 256 bits wide. Bail out otherwise instead of halving
+  // down to vector types narrower than a single element.
+  bool IsDirectNarrow = SrcVT.is256BitVector() && DstVT.is128BitVector();
+  if (!IsDirectNarrow && SrcSizeInBits < 512)
+    return SDValue();
+
+  LLVMContext &Ctx = *DAG.getContext();
+  EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
+
+  // Narrow to the largest type possible:
+  // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
+  EVT InVT = MVT::i16, OutVT = MVT::i8;
+  if (SrcVT.getScalarSizeInBits() > 16) {
+    InVT = MVT::i32;
+    OutVT = MVT::i16;
+  }
+  unsigned SubSizeInBits = SrcSizeInBits / 2;
+  InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
+  OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
+
+  // Split lower/upper subvectors.
+  SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
+  SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
+
+  // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
+  if (IsDirectNarrow) {
+    Lo = DAG.getBitcast(InVT, Lo);
+    Hi = DAG.getBitcast(InVT, Hi);
+    SDValue Res = DAG.getNode(Opcode, DL, OutVT, Lo, Hi);
+    return DAG.getBitcast(DstVT, Res);
+  }
+
+  // Recursively narrow lower/upper subvectors, concat result and narrow again.
+  EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
+  Lo = truncateVectorWithNARROW(Opcode, PackedVT, Lo, DL, DAG);
+  Hi = truncateVectorWithNARROW(Opcode, PackedVT, Hi, DL, DAG);
+  // Propagate failure rather than feeding null operands to CONCAT_VECTORS.
+  if (!Lo || !Hi)
+    return SDValue();
+
+  PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
+  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
+  return truncateVectorWithNARROW(Opcode, DstVT, Res, DL, DAG);
+}
+
+/// Rewrite a vector TRUNCATE (currently only v16i32 -> v16i8) as a tree of
+/// NARROW_U nodes. The input is first masked to the output element width so
+/// that the unsigned saturation performed by NARROW_U never changes a value.
+static SDValue performTruncateCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  auto &DAG = DCI.DAG;
+
+  SDValue In = N->getOperand(0);
+  EVT InVT = In.getValueType();
+  if (!InVT.isSimple())
+    return SDValue();
+
+  EVT OutVT = N->getValueType(0);
+  if (!OutVT.isVector())
+    return SDValue();
+
+  unsigned NumElems = OutVT.getVectorNumElements();
+  EVT OutSVT = OutVT.getVectorElementType();
+  EVT InSVT = InVT.getVectorElementType();
+  // TODO: only v16i32 => v16i8 here. Can cover more.
+  if (!(InSVT == MVT::i32 && OutSVT == MVT::i8 && NumElems == 16))
+    return SDValue();
+
+  SDLoc DL(N);
+  APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(),
+                                    OutVT.getScalarSizeInBits());
+  In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT));
+  return truncateVectorWithNARROW(WebAssemblyISD::NARROW_U, OutVT, In, DL, DAG);
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -2450,5 +2582,7 @@
   case ISD::FP_ROUND:
   case ISD::CONCAT_VECTORS:
     return performVectorTruncZeroCombine(N, DCI);
+  case ISD::TRUNCATE:
+    return performTruncateCombine(N, DCI);
   }
 }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1261,6 +1261,14 @@
 defm "" : SIMDNarrow<I16x8, 101>;
 defm "" : SIMDNarrow<I32x4, 133>;
 
+// WebAssemblyISD::NARROW_U
+def wasm_narrow_t : SDTypeProfile<1, 2, []>;
+def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>;
+def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))),
+          (NARROW_U_I8x16 $left, $right)>;
+def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))),
+          (NARROW_U_I16x8 $left, $right)>;
+
 // Bitcasts are nops
 // Matching bitcast t1 to t1 causes strange errors, so avoid repeating types
 foreach t1 = AllVecs in
diff --git a/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll b/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-vector-trunc.ll
@@ -0,0 +1,22 @@
+; RUN: llc < %s -asm-verbose=false -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+; Test that a vector trunc correctly optimizes and lowers to narrow instructions
+
+target triple = "wasm32-unknown-unknown"
+
+; CHECK-LABEL: trunc16i32_16i8:
+; CHECK-NEXT: .functype trunc16i32_16i8 (v128, v128, v128, v128) -> (v128)
+; CHECK-NEXT: v128.const $push{{[0-9]+}}=, 255, 255, 255, 255
+; CHECK: v128.and
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u $push[[R1:[0-9]+]]=, $pop{{[0-9]+}}, $pop{{[0-9]+}}
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u $push[[R2:[0-9]+]]=, $pop{{[0-9]+}}, $pop{{[0-9]+}}
+; CHECK-NEXT: i8x16.narrow_i16x8_u $push[[R3:[0-9]+]]=, $pop[[R1]], $pop[[R2]]
+; CHECK-NEXT: return $pop[[R3]]
+define <16 x i8> @trunc16i32_16i8(<16 x i32> %a) {
+entry:
+  %0 = trunc <16 x i32> %a to <16 x i8>
+  ret <16 x i8> %0
+}