diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -638,18 +638,10 @@ if (E0.empty() || E1.empty()) return false; - unsigned Op = NVPTX::SplitF16x2; - // If the vector has been BITCAST'ed from i32, we can use original - // value directly and avoid register-to-register move. - SDValue Source = Vector; - if (Vector->getOpcode() == ISD::BITCAST) { - Op = NVPTX::SplitI32toF16x2; - Source = Vector->getOperand(0); - } // Merge (f16 extractelt(V, 0), f16 extractelt(V,1)) - // into f16,f16 SplitF16x2(V) - SDNode *ScatterOp = - CurDAG->getMachineNode(Op, SDLoc(N), MVT::f16, MVT::f16, Source); + // into f16,f16 I32toV2I16(V) + SDNode *ScatterOp = CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), + MVT::f16, MVT::f16, Vector); for (auto *Node : E0) ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0)); for (auto *Node : E1) diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -2647,8 +2647,6 @@ defm LDV_i16 : LD_VEC; defm LDV_i32 : LD_VEC; defm LDV_i64 : LD_VEC; - defm LDV_f16 : LD_VEC; - defm LDV_f16x2 : LD_VEC; defm LDV_f32 : LD_VEC; defm LDV_f64 : LD_VEC; } @@ -2742,8 +2740,6 @@ defm STV_i16 : ST_VEC; defm STV_i32 : ST_VEC; defm STV_i64 : ST_VEC; - defm STV_f16 : ST_VEC; - defm STV_f16x2 : ST_VEC; defm STV_f32 : ST_VEC; defm STV_f64 : ST_VEC; } @@ -3056,6 +3052,10 @@ (ins Int32Regs:$s), "{{ .reg .b16 tmp; mov.b32 {tmp, $high}, $s; }}", []>; + def I32toI16L : NVPTXInst<(outs Int16Regs:$low), + (ins Int32Regs:$s), + "{{ .reg .b16 tmp; mov.b32 {$low, tmp}, $s; }}", + []>; def I64toI32H : NVPTXInst<(outs Int32Regs:$high), (ins Int64Regs:$s), "{{ .reg .b32 tmp; mov.b64 {tmp, $high}, $s; }}", @@ -3073,47 +3073,12 @@ def : Pat<(i32 (trunc (sra Int64Regs:$s, (i32 32)))), (I64toI32H Int64Regs:$s)>; -let hasSideEffects = false 
in { - // Extract element of f16x2 register. PTX does not provide any way - // to access elements of f16x2 vector directly, so we need to - // extract it using a temporary register. - def F16x2toF16_0 : NVPTXInst<(outs Int16Regs:$dst), - (ins Int32Regs:$src), - "{{ .reg .b16 \t%tmp_hi;\n\t" - " mov.b32 \t{$dst, %tmp_hi}, $src; }}", - [(set Int16Regs:$dst, - (extractelt (v2f16 Int32Regs:$src), 0))]>; - def F16x2toF16_1 : NVPTXInst<(outs Int16Regs:$dst), - (ins Int32Regs:$src), - "{{ .reg .b16 \t%tmp_lo;\n\t" - " mov.b32 \t{%tmp_lo, $dst}, $src; }}", - [(set Int16Regs:$dst, - (extractelt (v2f16 Int32Regs:$src), 1))]>; - - // Coalesce two f16 registers into f16x2 - def BuildF16x2 : NVPTXInst<(outs Int32Regs:$dst), - (ins Int16Regs:$a, Int16Regs:$b), - "mov.b32 \t$dst, {{$a, $b}};", - [(set (v2f16 Int32Regs:$dst), - (build_vector (f16 Int16Regs:$a), (f16 Int16Regs:$b)))]>; - - // Directly initializing underlying the b32 register is one less SASS - // instruction than than vector-packing move. - def BuildF16x2i : NVPTXInst<(outs Int32Regs:$dst), (ins i32imm:$src), - "mov.b32 \t$dst, $src;", - []>; - - // Split f16x2 into two f16 registers. 
- def SplitF16x2 : NVPTXInst<(outs Int16Regs:$lo, Int16Regs:$hi), - (ins Int32Regs:$src), - "mov.b32 \t{{$lo, $hi}}, $src;", - []>; - // Split an i32 into two f16 - def SplitI32toF16x2 : NVPTXInst<(outs Int16Regs:$lo, Int16Regs:$hi), - (ins Int32Regs:$src), - "mov.b32 \t{{$lo, $hi}}, $src;", - []>; -} +def : Pat<(f16 (extractelt (v2f16 Int32Regs:$src), 0)), + (I32toI16L Int32Regs:$src)>; +def : Pat<(f16 (extractelt (v2f16 Int32Regs:$src), 1)), + (I32toI16H Int32Regs:$src)>; +def : Pat<(v2f16 (build_vector (f16 Int16Regs:$a), (f16 Int16Regs:$b))), + (V2I16toI32 Int16Regs:$a, Int16Regs:$b)>; // Count leading zeros let hasSideEffects = false in { diff --git a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll --- a/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll +++ b/llvm/test/CodeGen/NVPTX/f16x2-instructions.ll @@ -40,7 +40,7 @@ ; CHECK-LABEL: test_extract_0( ; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_extract_0_param_0]; -; CHECK: mov.b32 {[[R:%rs[0-9]+]], %tmp_hi}, [[A]]; +; CHECK: mov.b32 {[[R:%rs[0-9]+]], tmp}, [[A]]; ; CHECK: st.param.b16 [func_retval0+0], [[R]]; ; CHECK: ret; define half @test_extract_0(<2 x half> %a) #0 { @@ -50,7 +50,7 @@ ; CHECK-LABEL: test_extract_1( ; CHECK: ld.param.b32 [[A:%r[0-9]+]], [test_extract_1_param_0]; -; CHECK: mov.b32 {%tmp_lo, [[R:%rs[0-9]+]]}, [[A]]; +; CHECK: mov.b32 {tmp, [[R:%rs[0-9]+]]}, [[A]]; ; CHECK: st.param.b16 [func_retval0+0], [[R]]; ; CHECK: ret; define half @test_extract_1(<2 x half> %a) #0 { @@ -1468,7 +1468,7 @@ } ; CHECK-LABEL: test_insertelement( -; CHECK: mov.b32 {%rs2, %tmp_hi}, %r1; +; CHECK: mov.b32 {%rs2, tmp}, %r1; ; CHECK: mov.b32 %r2, {%rs2, %rs1}; define <2 x half> @test_insertelement(<2 x half> %a, half %x) #0 { %i = insertelement <2 x half> %a, half %x, i64 1