This is an archive of the discontinued LLVM Phabricator instance.

[NVPTX] Allow using v4i32 for memcpy lowering.
ClosedPublic

Authored by tra on Jun 6 2023, 4:06 PM.

Diff Detail

Event Timeline

tra created this revision. (Jun 6 2023, 4:06 PM)
Herald added a project: Restricted Project. · View Herald Transcript (Jun 6 2023, 4:06 PM)
tra updated this revision to Diff 529079. (Jun 6 2023, 4:16 PM)

Simplified. Added 8-byte aligned memcpy test.

tra published this revision for review. (Jun 7 2023, 1:17 PM)
tra added a reviewer: jlebar.
Herald added a project: Restricted Project. · View Herald Transcript (Jun 7 2023, 1:17 PM)
tra updated this revision to Diff 529416. (Jun 7 2023, 1:22 PM)

Removed unnecessary lowering customization -- we already set it for all fixed
vector types.

tra updated this revision to Diff 529417. (Jun 7 2023, 1:29 PM)

Consolidate the changes in one place.

jlebar accepted this revision. (Jun 7 2023, 7:40 PM)

Thank you, Art!

This revision is now accepted and ready to land. (Jun 7 2023, 7:40 PM)
This revision was automatically updated to reflect the committed changes.

Hi Artem,

It looks like this commit is causing a massive compile-time regression (from a few seconds to about 1.5 hours) for one of our models in IREE.
I'm attaching a repro that you can pass to llc.
Could you please take a look or revert? Thanks!

; Reproducer module attached to the report above; the reporter's instructions
; are to pass it to llc.
; Target: 64-bit NVPTX (triple nvptx64-nvidia-cuda). The datalayout is
; little-endian ("e"), with 64-bit-aligned i64, 128-bit-aligned i128, and
; native integer widths of 16/32/64 bits ("n16:32:64").
; ModuleID = 'main_dispatch_394'
source_filename = "main_dispatch_394"
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

; Kernel from the compile-time regression report: per the reporter, llc on this
; module went from a few seconds to ~1.5h after the v4i32 memcpy-lowering change.
; The function is registered as a kernel (maxntid 64 x 1 x 1) through the
; !nvvm.annotations metadata at the end of the module.
; Arguments: %0 and %1 are readonly inputs and %2 is written; all three are
; generic pointers marked noalias and align 16.
; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(read, argmem: readwrite, inaccessiblemem: readwrite)
define void @main_dispatch_394_generic_16x16x512x512_i32xi32xi32xf32(ptr noalias readonly align 16 %0, ptr noalias readonly align 16 %1, ptr noalias align 16 %2) local_unnamed_addr #0 {
  ; Cast the three generic argument pointers to addrspace(1); every later
  ; memory access in this function goes through these casts.
  %4 = addrspacecast ptr %2 to ptr addrspace(1)
  %5 = addrspacecast ptr %1 to ptr addrspace(1)
  %6 = addrspacecast ptr %0 to ptr addrspace(1)
  ; Fixed element offsets into the buffers (float elements for %0/%2,
  ; i32 elements for %1).
  %7 = getelementptr float, ptr addrspace(1) %6, i64 399581184
  %8 = getelementptr i32, ptr addrspace(1) %5, i64 8388608
  %9 = getelementptr i32, ptr addrspace(1) %5, i64 41943040
  %10 = getelementptr float, ptr addrspace(1) %4, i64 58720768
  ; Block index (!range !4 = [0, 2147483647)) and thread index
  ; (!range !5 = [0, 64)).
  %11 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4
  %12 = zext i32 %11 to i64
  %13 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !5
  ; Index arithmetic: the thread index is scaled by 4 (%17) and combined with
  ; bit fields of the block index to form the element offsets %21 and %23.
  %14 = lshr i64 %12, 1
  %15 = and i64 %14, 511
  %16 = shl nuw nsw i64 %12, 8
  %17 = shl nuw nsw i32 %13, 2
  %18 = zext i32 %17 to i64
  %19 = or i64 %16, %18
  %20 = shl nuw nsw i64 %14, 9
  %21 = sub nsw i64 %19, %20
  %22 = shl nuw nsw i64 %15, 9
  %23 = add nsw i64 %22, %21
  ; Input loads: a <4 x i32> from %8, a single i32 from %1's buffer at
  ; element %15 + 41943040, and a <4 x i32> from %9.
  %24 = getelementptr i32, ptr addrspace(1) %8, i64 %23
  %25 = load <4 x i32>, ptr addrspace(1) %24, align 16
  %26 = getelementptr i32, ptr addrspace(1) %5, i64 %15
  %27 = getelementptr i32, ptr addrspace(1) %26, i64 41943040
  %28 = load i32, ptr addrspace(1) %27, align 4
  %29 = getelementptr i32, ptr addrspace(1) %9, i64 %21
  %30 = load <4 x i32>, ptr addrspace(1) %29, align 16
  ; Splat the scalar %28 and compare it lane-wise against %30; %.not later
  ; selects the additive term %91.
  %31 = insertelement <4 x i32> undef, i32 %28, i64 0
  %32 = shufflevector <4 x i32> %31, <4 x i32> undef, <4 x i32> zeroinitializer
  %.not = icmp slt <4 x i32> %32, %30
  ; %34 = 0 - smin(%25, 0); %35 is its float value and %36 = %35 * 1/16.
  %33 = tail call <4 x i32> @llvm.smin.v4i32(<4 x i32> %25, <4 x i32> zeroinitializer)
  %34 = sub <4 x i32> zeroinitializer, %33
  %35 = sitofp <4 x i32> %34 to <4 x float>
  %36 = fmul <4 x float> %35, <float 6.250000e-02, float 6.250000e-02, float 6.250000e-02, float 6.250000e-02>
  ; %37..%64: clamp %36 from below to 0x3810000000000000 (2^-126, the smallest
  ; normal float). Then split the bits into a mantissa in [0.5, 1) (%41) and an
  ; exponent (%44), and evaluate an FMA polynomial. The last FMA adds
  ; exponent * ~ln(2) (0x3FE62E43...).
  ; NOTE(review): inferred from the constants, this looks like an inlined
  ; natural-log expansion.
  %.inv = fcmp ole <4 x float> %36, <float 0x3810000000000000, float 0x3810000000000000, float 0x3810000000000000, float 0x3810000000000000>
  %37 = select <4 x i1> %.inv, <4 x float> <float 0x3810000000000000, float 0x3810000000000000, float 0x3810000000000000, float 0x3810000000000000>, <4 x float> %36
  %38 = bitcast <4 x float> %37 to <4 x i32>
  %39 = and <4 x i32> %38, <i32 -2139095041, i32 -2139095041, i32 -2139095041, i32 -2139095041>
  %40 = or <4 x i32> %39, <i32 1056964608, i32 1056964608, i32 1056964608, i32 1056964608>
  %41 = bitcast <4 x i32> %40 to <4 x float>
  %42 = lshr <4 x i32> %38, <i32 23, i32 23, i32 23, i32 23>
  %43 = sitofp <4 x i32> %42 to <4 x float>
  %44 = fadd <4 x float> %43, <float -1.260000e+02, float -1.260000e+02, float -1.260000e+02, float -1.260000e+02>
  %45 = fcmp olt <4 x float> %41, <float 0x3FE6A09E60000000, float 0x3FE6A09E60000000, float 0x3FE6A09E60000000, float 0x3FE6A09E60000000>
  %46 = select <4 x i1> %45, <4 x float> %41, <4 x float> zeroinitializer
  %47 = fadd <4 x float> %41, <float -1.000000e+00, float -1.000000e+00, float -1.000000e+00, float -1.000000e+00>
  %48 = select <4 x i1> %45, <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, <4 x float> zeroinitializer
  %49 = fsub <4 x float> %44, %48
  %50 = fadd <4 x float> %47, %46
  %51 = fmul <4 x float> %50, %50
  %52 = fmul <4 x float> %50, %51
  %53 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %50, <4 x float> <float 0x3FB2043760000000, float 0x3FB2043760000000, float 0x3FB2043760000000, float 0x3FB2043760000000>, <4 x float> <float 0xBFBD7A3700000000, float 0xBFBD7A3700000000, float 0xBFBD7A3700000000, float 0xBFBD7A3700000000>)
  %54 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %50, <4 x float> <float 0xBFBFCBA9E0000000, float 0xBFBFCBA9E0000000, float 0xBFBFCBA9E0000000, float 0xBFBFCBA9E0000000>, <4 x float> <float 0x3FC23D37E0000000, float 0x3FC23D37E0000000, float 0x3FC23D37E0000000, float 0x3FC23D37E0000000>)
  %55 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %50, <4 x float> <float 0x3FC999D580000000, float 0x3FC999D580000000, float 0x3FC999D580000000, float 0x3FC999D580000000>, <4 x float> <float 0xBFCFFFFF80000000, float 0xBFCFFFFF80000000, float 0xBFCFFFFF80000000, float 0xBFCFFFFF80000000>)
  %56 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %53, <4 x float> %50, <4 x float> <float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000, float 0x3FBDE4A340000000>)
  %57 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %54, <4 x float> %50, <4 x float> <float 0xBFC555CA00000000, float 0xBFC555CA00000000, float 0xBFC555CA00000000, float 0xBFC555CA00000000>)
  %58 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %55, <4 x float> %50, <4 x float> <float 0x3FD5555540000000, float 0x3FD5555540000000, float 0x3FD5555540000000, float 0x3FD5555540000000>)
  %59 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %56, <4 x float> %52, <4 x float> %57)
  %60 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %59, <4 x float> %52, <4 x float> %58)
  %61 = fmul <4 x float> %52, %60
  %62 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %51, <4 x float> <float -5.000000e-01, float -5.000000e-01, float -5.000000e-01, float -5.000000e-01>, <4 x float> %61)
  %63 = fadd <4 x float> %50, %62
  %64 = tail call <4 x float> @llvm.fma.v4f32(<4 x float> %49, <4 x float> <float 0x3FE62E4300000000, float 0x3FE62E4300000000, float 0x3FE62E4300000000, float 0x3FE62E4300000000>, <4 x float> %63)
  ; Special cases for %36: < 0 or NaN -> NaN, == 0 -> -inf, == +inf -> +inf.
  ; Otherwise %70 = (%64 / ~ln(8)) * 16 + 16 (0x4000A2B24... ~= 2.0794).
  %65 = fcmp ult <4 x float> %36, zeroinitializer
  %66 = fcmp oeq <4 x float> %36, zeroinitializer
  %67 = fcmp oeq <4 x float> %36, <float 0x7FF0000000000000, float 0x7FF0000000000000, float 0x7FF0000000000000, float 0x7FF0000000000000>
  %68 = fdiv <4 x float> %64, <float 0x4000A2B240000000, float 0x4000A2B240000000, float 0x4000A2B240000000, float 0x4000A2B240000000>
  %69 = fmul <4 x float> %68, <float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01>
  %70 = fadd <4 x float> %69, <float 1.600000e+01, float 1.600000e+01, float 1.600000e+01, float 1.600000e+01>
  %71 = select <4 x i1> %67, <4 x float> <float 0x7FF0000000000000, float 0x7FF0000000000000, float 0x7FF0000000000000, float 0x7FF0000000000000>, <4 x float> %70
  %72 = select <4 x i1> %65, <4 x float> <float 0x7FF8000000000000, float 0x7FF8000000000000, float 0x7FF8000000000000, float 0x7FF8000000000000>, <4 x float> %71
  %73 = select <4 x i1> %66, <4 x float> <float 0xFFF0000000000000, float 0xFFF0000000000000, float 0xFFF0000000000000, float 0xFFF0000000000000>, <4 x float> %72
  ; Cap the value at 31.0. Lanes where %34 < 16 use %35 directly instead.
  %.inv1 = fcmp oge <4 x float> %73, <float 3.100000e+01, float 3.100000e+01, float 3.100000e+01, float 3.100000e+01>
  %74 = select <4 x i1> %.inv1, <4 x float> <float 3.100000e+01, float 3.100000e+01, float 3.100000e+01, float 3.100000e+01>, <4 x float> %73
  %75 = icmp slt <4 x i32> %34, <i32 16, i32 16, i32 16, i32 16>
  %76 = select <4 x i1> %75, <4 x float> %35, <4 x float> %74
  ; Convert back to i32; negative results get 32 added (%81).
  %77 = fadd <4 x float> %76, zeroinitializer
  %78 = fptosi <4 x float> %77 to <4 x i32>
  %79 = add <4 x i32> %78, <i32 32, i32 32, i32 32, i32 32>
  %80 = icmp slt <4 x i32> %78, zeroinitializer
  %81 = select <4 x i1> %80, <4 x i32> %79, <4 x i32> %78
  ; Per-lane gather offsets: %89 = (sext(%81) << 4) + splat(%.scalar), where
  ; %.scalar is derived from bits of the block index.
  %82 = lshr i64 %12, 10
  %83 = lshr i64 %12, 14
  %84 = shl nuw nsw i64 %83, 4
  %.scalar = sub nsw i64 %82, %84
  %85 = insertelement <4 x i64> undef, i64 %.scalar, i64 0
  %86 = shufflevector <4 x i64> %85, <4 x i64> poison, <4 x i32> zeroinitializer
  %87 = sext <4 x i32> %81 to <4 x i64>
  %88 = shl nsw <4 x i64> %87, <i64 4, i64 4, i64 4, i64 4>
  %89 = add <4 x i64> %88, %86
  ; A vector GEP produces four addresses. They are loaded one at a time as
  ; scalar floats and reassembled into %Res3. This is an unmasked, scalarized
  ; gather pattern; the llvm.masked.gather intrinsics declared below are not
  ; called here.
  %90 = getelementptr float, ptr addrspace(1) %7, <4 x i64> %89
  %Ptr0 = extractelement <4 x ptr addrspace(1)> %90, i64 0
  %Load0 = load float, ptr addrspace(1) %Ptr0, align 4
  %Res0 = insertelement <4 x float> poison, float %Load0, i64 0
  %Ptr1 = extractelement <4 x ptr addrspace(1)> %90, i64 1
  %Load1 = load float, ptr addrspace(1) %Ptr1, align 4
  %Res1 = insertelement <4 x float> %Res0, float %Load1, i64 1
  %Ptr2 = extractelement <4 x ptr addrspace(1)> %90, i64 2
  %Load2 = load float, ptr addrspace(1) %Ptr2, align 4
  %Res2 = insertelement <4 x float> %Res1, float %Load2, i64 2
  %Ptr3 = extractelement <4 x ptr addrspace(1)> %90, i64 3
  %Load3 = load float, ptr addrspace(1) %Ptr3, align 4
  %Res3 = insertelement <4 x float> %Res2, float %Load3, i64 3
  ; %91 is -FLT_MAX (0xC7EFFFFFE0000000) where %.not holds, else 0.
  ; %93 keeps the gathered value only where %81 <u 32, else NaN.
  %91 = select <4 x i1> %.not, <4 x float> <float 0xC7EFFFFFE0000000, float 0xC7EFFFFFE0000000, float 0xC7EFFFFFE0000000, float 0xC7EFFFFFE0000000>, <4 x float> zeroinitializer
  %92 = icmp ult <4 x i32> %81, <i32 32, i32 32, i32 32, i32 32>
  %93 = select <4 x i1> %92, <4 x float> %Res3, <4 x float> <float 0x7FF8000000000000, float 0x7FF8000000000000, float 0x7FF8000000000000, float 0x7FF8000000000000>
  %94 = fadd <4 x float> %91, %93
  ; Output offset is built from block-index bits plus %22 and %21; the
  ; <4 x float> result is stored (align 16) relative to %10.
  %95 = shl nuw nsw i64 %83, 22
  %96 = shl nuw nsw i64 %82, 18
  %97 = and i64 %96, 3932160
  %98 = or i64 %97, %95
  %99 = or i64 %98, %22
  %100 = add nsw i64 %99, %21
  %101 = getelementptr float, ptr addrspace(1) %10, i64 %100
  store <4 x float> %94, ptr addrspace(1) %101, align 16
  ret void
}

; Intrinsic declarations. The kernel above calls only ctaid.x, tid.x,
; smin.v4i32 and fma.v4f32. llvm.assume and both llvm.masked.gather
; declarations are declared but not referenced by the kernel body.
; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite)
declare void @llvm.assume(i1 noundef) #1

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #2

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #2

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare <4 x i32> @llvm.smin.v4i32(<4 x i32>, <4 x i32>) #2

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare <4 x float> @llvm.fma.v4f32(<4 x float>, <4 x float>, <4 x float>) #2

; Unreferenced: gather over generic (address space 0) pointers.
; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(read)
declare <4 x float> @llvm.masked.gather.v4f32.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x float>) #3

; Unreferenced: gather over addrspace(1) pointers. The kernel performs the
; equivalent loads as four scalar loads instead.
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(read)
declare <4 x float> @llvm.masked.gather.v4f32.v4p1(<4 x ptr addrspace(1)>, i32 immarg, <4 x i1>, <4 x float>) #4

; Attribute groups referenced by the definitions and declarations above.
attributes #0 = { mustprogress nofree nosync nounwind willreturn memory(read, argmem: readwrite, inaccessiblemem: readwrite) }
attributes #1 = { mustprogress nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) }
attributes #2 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #3 = { mustprogress nocallback nofree nosync nounwind willreturn memory(read) }
attributes #4 = { nocallback nofree nosync nounwind willreturn memory(read) }

; NVVM annotations: !0 marks the function as a kernel; !1..!3 set its
; maximum thread-block dimensions to 64 x 1 x 1.
!nvvm.annotations = !{!0, !1, !2, !3}

!0 = !{ptr @main_dispatch_394_generic_16x16x512x512_i32xi32xi32xf32, !"kernel", i32 1}
!1 = !{ptr @main_dispatch_394_generic_16x16x512x512_i32xi32xi32xf32, !"maxntidx", i32 64}
!2 = !{ptr @main_dispatch_394_generic_16x16x512x512_i32xi32xi32xf32, !"maxntidy", i32 1}
!3 = !{ptr @main_dispatch_394_generic_16x16x512x512_i32xi32xi32xf32, !"maxntidz", i32 1}
; !range nodes (half-open [lo, hi)): !4 = [0, 2147483647) is attached to the
; ctaid.x read; !5 = [0, 64) is attached to the tid.x read.
!4 = !{i32 0, i32 2147483647}
!5 = !{i32 0, i32 64}

Oops, sorry, I just saw the revert :)