diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -74,6 +74,7 @@
   bool tryConstantFP16(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
+  bool tryTruncate(SDNode *N);
 
   inline SDValue getI32Imm(unsigned Imm, const SDLoc &DL) {
     return CurDAG->getTargetConstant(Imm, DL, MVT::i32);
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -101,6 +101,10 @@
     if (tryEXTRACT_VECTOR_ELEMENT(N))
       return;
     break;
+  case ISD::TRUNCATE:
+    if (tryTruncate(N))
+      return;
+    break;
   case NVPTXISD::SETP_F16X2:
     SelectSETP_F16X2(N);
     return;
@@ -658,6 +662,96 @@
   return true;
 }
 
+// Try to replace the following DAG pattern (written here as IR):
+//
+//   %high32 = lshr i32 %input, 16   ; ashr is accepted as well
+//   %high   = trunc i32 %high32 to i16
+//   %low    = trunc i32 %input to i16
+//
+// with a single "mov.b32 {low, high}, input" that unpacks both halves at
+// once. The i64 -> 2 x i32 variant is handled the same way via "mov.b64".
+//
+// N may be either of the two truncates. Returns true only if N itself was
+// replaced; if the partner truncate cannot be found the DAG is left untouched.
+bool NVPTXDAGToDAGISel::tryTruncate(SDNode *N) {
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+  MVT PartVT;
+  unsigned PartSize;
+  unsigned MovOpcode;
+  if (SrcVT == MVT::i32) {
+    PartVT = MVT::i16;
+    PartSize = 16;
+    MovOpcode = NVPTX::I32toV2I16;
+  } else if (SrcVT == MVT::i64) {
+    PartVT = MVT::i32;
+    PartSize = 32;
+    MovOpcode = NVPTX::I64toV2I32;
+  } else {
+    return false;
+  }
+  // N must produce exactly one half of its input. Otherwise (e.g. i32 -> i8)
+  // we could rewrite some other pair of truncates and report success while N
+  // itself remains unselected.
+  if (N->getValueType(0) != PartVT)
+    return false;
+
+  // Is Op "X >> PartSize"? SRA is fine too, since the truncate discards the
+  // replicated sign bits. The shift amount must be checked to be a constant
+  // first: getConstantOperandVal() would assert on a variable shift amount.
+  auto IsHighShift = [&](const SDNode *Op) {
+    if (Op->getOpcode() != ISD::SRL && Op->getOpcode() != ISD::SRA)
+      return false;
+    auto *Amt = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    return Amt && Amt->getZExtValue() == PartSize;
+  };
+  // Returns the truncate of V to PartVT, if any.
+  auto FindTrunc = [&](SDValue V) -> SDNode * {
+    for (SDNode *User : V->uses())
+      if (User->getOpcode() == ISD::TRUNCATE && User->getOperand(0) == V &&
+          User->getValueType(0) == PartVT)
+        return User;
+    return nullptr;
+  };
+  // Returns the truncate of (V >> PartSize) to PartVT, if any. Every shift
+  // user of V is considered, not only the first one.
+  auto FindHighTrunc = [&](SDValue V) -> SDNode * {
+    for (SDNode *User : V->uses())
+      if (IsHighShift(User) && User->getOperand(0) == V)
+        if (SDNode *Trunc = FindTrunc(SDValue(User, 0)))
+          return Trunc;
+    return nullptr;
+  };
+
+  SDValue Input;
+  SDNode *High = nullptr, *Low = nullptr;
+  // First, check whether N is the 'high' part: (trunc (srl Input, PartSize)).
+  if (IsHighShift(Src.getNode())) {
+    Input = Src.getOperand(0);
+    High = N;
+    Low = FindTrunc(Input);
+  }
+  // If that didn't work, check whether N is the 'low' part: (trunc Input).
+  if (!Low) {
+    Input = Src;
+    Low = N;
+    High = FindHighTrunc(Input);
+  }
+  if (!High || !Low)
+    return false;
+
+  // PTX "mov.b32 {a, b}, x" places the low-order half of x in 'a', so result 0
+  // of the machine node is the low part and result 1 is the high part.
+  SDNode *Mov =
+      CurDAG->getMachineNode(MovOpcode, SDLoc(N), PartVT, PartVT, Input);
+  SDValue From[] = {SDValue(Low, 0), SDValue(High, 0)};
+  SDValue To[] = {SDValue(Mov, 0), SDValue(Mov, 1)};
+  ReplaceUses(From, To, 2);
+  CurDAG->RemoveDeadNode(Low);
+  CurDAG->RemoveDeadNode(High);
+  return true;
+}
+
 static unsigned int getCodeAddrSpace(MemSDNode *N) {
   const Value *Src = N->getMemOperand()->getValue();
 
diff --git a/llvm/test/CodeGen/NVPTX/idioms.ll b/llvm/test/CodeGen/NVPTX/idioms.ll
--- a/llvm/test/CodeGen/NVPTX/idioms.ll
+++ b/llvm/test/CodeGen/NVPTX/idioms.ll
@@ -31,3 +31,105 @@
   %abs = select i1 %abs.cond, i64 %a, i64 %neg
   ret i64 %abs
 }
+
+%struct.S16 = type { i16, i16 }
+; CHECK-LABEL: i32_to_2xi16(
+define %struct.S16 @i32_to_2xi16(i32 noundef %in) {
+  %low = trunc i32 %in to i16
+  %high32 = lshr i32 %in, 16
+  %high = trunc i32 %high32 to i16
+; CHECK: ld.param.u32 %[[R32:r[0-9]+]], [i32_to_2xi16_param_0];
+; CHECK: mov.b32 {%[[LO:rs[0-9]+]], %[[HI:rs[0-9]+]]}, %[[R32]];
+; CHECK-DAG: st.param.b16 [func_retval0+0], %[[LO]];
+; CHECK-DAG: st.param.b16 [func_retval0+2], %[[HI]];
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+; CHECK-LABEL: i32_to_2xi16_lh(
+define %struct.S16 @i32_to_2xi16_lh(i32 noundef %in) {
+  %high32 = lshr i32 %in, 16
+  %high = trunc i32 %high32 to i16
+  %low = trunc i32 %in to i16
+; CHECK: ld.param.u32 %[[R32:r[0-9]+]], [i32_to_2xi16_lh_param_0];
+; CHECK: mov.b32 {%[[LO:rs[0-9]+]], %[[HI:rs[0-9]+]]}, %[[R32]];
+; CHECK-DAG: st.param.b16 [func_retval0+0], %[[LO]];
+; CHECK-DAG: st.param.b16 [func_retval0+2], %[[HI]];
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+; CHECK-LABEL: i32_to_2xi16_not(
+define %struct.S16 @i32_to_2xi16_not(i32 noundef %in) {
+  %low = trunc i32 %in to i16
+  ; Shift by any value other than 16 blocks the conversion to mov.
+  %high32 = lshr i32 %in, 15
+  %high = trunc i32 %high32 to i16
+; CHECK: cvt.u16.u32
+; CHECK: shr.u32
+; CHECK: cvt.u16.u32
+  %s1 = insertvalue %struct.S16 poison, i16 %low, 0
+  %s = insertvalue %struct.S16 %s1, i16 %high, 1
+  ret %struct.S16 %s
+}
+
+%struct.S32 = type { i32, i32 }
+; CHECK-LABEL: i64_to_2xi32(
+define %struct.S32 @i64_to_2xi32(i64 noundef %in) {
+  %low = trunc i64 %in to i32
+  %high64 = lshr i64 %in, 32
+  %high = trunc i64 %high64 to i32
+; CHECK: ld.param.u64 %[[R64:rd[0-9]+]], [i64_to_2xi32_param_0];
+; CHECK: mov.b64 {%[[LO:r[0-9]+]], %[[HI:r[0-9]+]]}, %[[R64]];
+; CHECK-DAG: st.param.b32 [func_retval0+0], %[[LO]];
+; CHECK-DAG: st.param.b32 [func_retval0+4], %[[HI]];
+  %s1 = insertvalue %struct.S32 poison, i32 %low, 0
+  %s = insertvalue %struct.S32 %s1, i32 %high, 1
+  ret %struct.S32 %s
+}
+
+; CHECK-LABEL: i64_to_2xi32_not(
+define %struct.S32 @i64_to_2xi32_not(i64 noundef %in) {
+  %low = trunc i64 %in to i32
+  ; Shift by any value other than 32 blocks the conversion to mov.
+  %high64 = lshr i64 %in, 31
+  %high = trunc i64 %high64 to i32
+; CHECK: cvt.u32.u64
+; CHECK: shr.u64
+; CHECK: cvt.u32.u64
+  %s1 = insertvalue %struct.S32 poison, i32 %low, 0
+  %s = insertvalue %struct.S32 %s1, i32 %high, 1
+  ret %struct.S32 %s
+}
+
+; The ashr result escapes, so the shift must stay, but both halves can still
+; come from a single mov.
+; CHECK-LABEL: i32_ashr_escape(
+define %struct.S16 @i32_ashr_escape(ptr nocapture noundef readnone %0, i32 noundef %1) {
+; CHECK: mov.b32 {%rs{{[0-9]+}}, %rs{{[0-9]+}}}, %r{{[0-9]+}};
+  %3 = trunc i32 %1 to i16
+  %4 = ashr i32 %1, 16
+  %5 = trunc i32 %4 to i16
+  tail call void @_Z6escapei(i32 noundef %4)
+  %6 = insertvalue %struct.S16 poison, i16 %5, 0
+  %7 = insertvalue %struct.S16 %6, i16 %3, 1
+  ret %struct.S16 %7
+}
+
+; CHECK-LABEL: i32_lshr_sext_escape(
+define %struct.S16 @i32_lshr_sext_escape(ptr nocapture noundef readnone %0, i32 noundef %1) {
+; CHECK: mov.b32 {%rs{{[0-9]+}}, %rs{{[0-9]+}}}, %r{{[0-9]+}};
+  %3 = trunc i32 %1 to i16
+  %4 = shl i32 %1, 16
+  %5 = ashr exact i32 %4, 16
+  tail call void @_Z6escapei(i32 noundef %5)
+  %6 = lshr i32 %1, 16
+  %7 = trunc i32 %6 to i16
+  %8 = insertvalue %struct.S16 poison, i16 %7, 0
+  %9 = insertvalue %struct.S16 %8, i16 %3, 1
+  ret %struct.S16 %9
+}
+
+declare void @_Z6escapei(i32 noundef)