diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,6 +74,7 @@
   bool tryConstantFP16(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
+  bool tryTruncate(SDNode *N);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -101,6 +101,10 @@
     if (tryEXTRACT_VECTOR_ELEMENT(N))
       return;
     break;
+  case ISD::TRUNCATE:
+    if (tryTruncate(N))
+      return;
+    break;
   case NVPTXISD::SETP_F16X2:
     SelectSETP_F16X2(N);
     return;
@@ -658,6 +662,103 @@
   return true;
 }
 
+// Try to replace the following IR:
+//
+//   %high32 = lshr i32 %input, 16
+//   %high = trunc i32 %high32 to i16
+//   %low = trunc i32 %input to i16
+//
+// with a single "mov.b32 {low, high}, input", or with the "mov.b64"
+// equivalent that splits an i64 into two i32 halves. PTX unpacks the least
+// significant half of the source into the first destination register.
+//
+// N is either of the two TRUNCATE nodes; its sibling is located through the
+// users of the shared input. Returns true if both truncates were replaced by
+// the move, and false if the pattern did not match, in which case N still has
+// to be selected by the caller.
+//
+bool NVPTXDAGToDAGISel::tryTruncate(SDNode *N) {
+  // Only scalar i32 -> 2 x i16 and i64 -> 2 x i32 splits are handled.
+  const EVT InputVT = N->getOperand(0).getValueType();
+  if (InputVT != MVT::i32 && InputVT != MVT::i64)
+    return false;
+  const bool Is32 = InputVT == MVT::i32;
+  const EVT PartVT = Is32 ? MVT::i16 : MVT::i32;
+  const unsigned PartSize = Is32 ? 16 : 32;
+  // N itself has to produce exactly one half. Without this check something
+  // like 'trunc i32 to i8' would be paired with an i16 sibling.
+  if (N->getValueType(0) != PartVT)
+    return false;
+
+  // Returns a user of Input that truncates it to PartVT, or nullptr.
+  auto FindTrunc = [&](SDValue Input) -> SDNode * {
+    for (SDNode *User : Input->uses())
+      if (User->getOpcode() == ISD::TRUNCATE &&
+          User->getOperand(0) == Input && User->getValueType(0) == PartVT)
+        return User;
+    return nullptr;
+  };
+
+  // True if Op shifts right by exactly PartSize bits. SRL and SRA are
+  // interchangeable here: they only differ in the bits that the following
+  // truncate to PartVT discards.
+  auto IsShiftRight = [&](const SDNode *Op) {
+    if (Op->getOpcode() != ISD::SRL && Op->getOpcode() != ISD::SRA)
+      return false;
+    // The shift amount is not necessarily a constant, and
+    // getConstantOperandVal() asserts when it is not.
+    const auto *Amount = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    return Amount && Amount->getAPIntValue() == PartSize;
+  };
+
+  // Returns a PartVT truncate of any "Input >> PartSize" user, or nullptr.
+  auto FindSrlTrunc = [&](SDValue Input) -> SDNode * {
+    for (SDNode *User : Input->uses()) {
+      if (!IsShiftRight(User) || User->getOperand(0) != Input)
+        continue;
+      if (SDNode *Trunc = FindTrunc(SDValue(User, 0)))
+        return Trunc;
+    }
+    return nullptr;
+  };
+
+  SDValue Input;
+  SDNode *High = nullptr, *Low = nullptr;
+
+  // This 'trunc' may be for the low part of the move, the high part, or
+  // neither. The tricky part is when the input value itself is produced by a
+  // right shift, which makes the 'low' trunc indistinguishable from 'high'.
+  // First, check if it's the 'high' part.
+  if (IsShiftRight(N->getOperand(0).getNode())) {
+    Input = N->getOperand(0)->getOperand(0);
+    High = N;
+    Low = FindTrunc(Input);
+  }
+  // If that didn't work, check if it's the 'low' part.
+  // TODO: For some reason I can't trigger this in tests. We always seem to
+  // process the trunc(ashr(value)) first. I don't know if I can rely on that,
+  // but to be safe, we'll also handle the case when we get to see trunc(value)
+  // first.
+  if (!(High && Low)) {
+    Input = N->getOperand(0);
+    High = FindSrlTrunc(Input);
+    Low = N;
+  }
+  if (!(High && Low))
+    return false;
+
+  // Result 0 of the move is the low half of Input and result 1 the high half.
+  SDNode *Mov = CurDAG->getMachineNode(
+      Is32 ? NVPTX::I32toV2I16 : NVPTX::I64toV2I32, SDLoc(N), PartVT, PartVT,
+      Input);
+  SDValue F[] = {SDValue(Low, 0), SDValue(High, 0)};
+  SDValue T[] = {SDValue(Mov, 0), SDValue(Mov, 1)};
+  ReplaceUses(F, T, 2);
+  CurDAG->RemoveDeadNode(Low);
+  CurDAG->RemoveDeadNode(High);
+  return true;
+}
+
 static unsigned int getCodeAddrSpace(MemSDNode *N) {
   const Value *Src = N->getMemOperand()->getValue();
 
diff --git a/llvm/test/CodeGen/NVPTX/idioms.ll b/llvm/test/CodeGen/NVPTX/idioms.ll
--- a/llvm/test/CodeGen/NVPTX/idioms.ll
+++ b/llvm/test/CodeGen/NVPTX/idioms.ll
@@ -5,6 +5,9 @@
 ; RUN: %if ptxas %{ llc < %s -march=nvptx -mcpu=sm_20 | %ptxas-verify %}
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
 
+%struct.S16 = type { i16, i16 }
+%struct.S32 = type { i32, i32 }
+
 ; CHECK-LABEL: abs_i16(
 define i16 @abs_i16(i16 %a) {
 ; CHECK: abs.s16
@@ -31,3 +34,87 @@
   %abs = select i1 %abs.cond, i64 %a, i64 %neg
   ret i64 %abs
 }
+
+; CHECK-LABEL: i32_to_2xi16(
+define %struct.S16 @i32_to_2xi16(i32 noundef %in) {
+  %low = trunc i32 %in to i16
+  %high32 = lshr i32 %in, 16
+  %high = trunc i32 %high32 to i16
+; CHECK: ld.param.u32 %[[R32:r[0-9]+]], [i32_to_2xi16_param_0];
+; CHECK: mov.b32 {%rs{{[0-9]+}}, %rs{{[0-9]+}}}, %[[R32]];
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+; CHECK-LABEL: i32_to_2xi16_lh(
+; Same as above, but with rearranged order of low/high parts.
+define %struct.S16 @i32_to_2xi16_lh(i32 noundef %in) {
+  %high32 = lshr i32 %in, 16
+  %high = trunc i32 %high32 to i16
+  %low = trunc i32 %in to i16
+; CHECK: ld.param.u32 %[[R32:r[0-9]+]], [i32_to_2xi16_lh_param_0];
+; CHECK: mov.b32 {%rs{{[0-9]+}}, %rs{{[0-9]+}}}, %[[R32]];
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+
+; CHECK-LABEL: i32_to_2xi16_not(
+define %struct.S16 @i32_to_2xi16_not(i32 noundef %in) {
+  %low = trunc i32 %in to i16
+  ; Shift by any value other than 16 blocks the conversion to mov.
+  %high32 = lshr i32 %in, 15
+  %high = trunc i32 %high32 to i16
+; CHECK: cvt.u16.u32
+; CHECK: shr.u32
+; CHECK: cvt.u16.u32
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+; CHECK-LABEL: i64_to_2xi32(
+define %struct.S32 @i64_to_2xi32(i64 noundef %in) {
+  %low = trunc i64 %in to i32
+  %high64 = lshr i64 %in, 32
+  %high = trunc i64 %high64 to i32
+; CHECK: ld.param.u64 %[[R64:rd[0-9]+]], [i64_to_2xi32_param_0];
+; CHECK: mov.b64 {%r{{[0-9]+}}, %r{{[0-9]+}}}, %[[R64]];
+  %s1 = insertvalue %struct.S32 poison, i32 %low, 0
+  %s = insertvalue %struct.S32 %s1, i32 %high, 1
+  ret %struct.S32 %s
+}
+
+; CHECK-LABEL: i64_to_2xi32_not(
+define %struct.S32 @i64_to_2xi32_not(i64 noundef %in) {
+  %low = trunc i64 %in to i32
+  ; Shift by any value other than 32 blocks the conversion to mov.
+  %high64 = lshr i64 %in, 31
+  %high = trunc i64 %high64 to i32
+; CHECK: cvt.u32.u64
+; CHECK: shr.u64
+; CHECK: cvt.u32.u64
+  %s1 = insertvalue %struct.S32 poison, i32 %low, 0
+  %s = insertvalue %struct.S32 %s1, i32 %high, 1
+  ret %struct.S32 %s
+}
+
+; CHECK-LABEL: i32_to_2xi16_shr(
+; Make sure we do not get confused when our input itself is [al]shr.
+define %struct.S16 @i32_to_2xi16_shr(i32 noundef %i){
+  call void @escape_int(i32 %i) ; Force %i to be loaded completely.
+  %i1 = ashr i32 %i, 16
+  %l = trunc i32 %i1 to i16
+  %h32 = ashr i32 %i1, 16
+  %h = trunc i32 %h32 to i16
+; CHECK: ld.param.u32 %r1, [i32_to_2xi16_shr_param_0];
+; CHECK: shr.s32 %r2, %r1, 16;
+; CHECK: mov.b32 {%rs1, %rs2}, %r2;
+  %s0 = insertvalue %struct.S16 poison, i16 %l, 0
+  %s1 = insertvalue %struct.S16 %s0, i16 %h, 1
+  ret %struct.S16 %s1
+}
+declare dso_local void @escape_int(i32 noundef)
+