diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir
--- a/mlir/test/Target/nvvmir.mlir
+++ b/mlir/test/Target/nvvmir.mlir
@@ -73,6 +73,37 @@
   llvm.return %0 : !llvm.struct<(float, float, float, float, float, float, float, float)>
 }
 
+llvm.func @gpu_wmma_load_op(%arg0: !llvm.ptr<i32, 3>, %arg1: !llvm.i32) {
+  // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, i32 %{{.*}})
+  %0 = nvvm.wmma.m16n16k16.load %arg0, %arg1 {ldm = 32 : i64, operand = "AOp"} : !llvm.ptr<i32, 3>, !llvm.i32 -> !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
+
+  llvm.return
+}
+
+llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: !llvm.vec<2 x half>,
+                             %arg2: !llvm.vec<2 x half>, %arg3: !llvm.vec<2 x half>,
+                             %arg4: !llvm.vec<2 x half>, %arg5: !llvm.i32) {
+  // CHECK: call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f16.p3i32(i32 addrspace(3)* %{{.*}}, <2 x half> {{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, <2 x half> %{{.*}}, i32 %{{.*}})
+  nvvm.wmma.m16n16k16.store %arg0, %arg1, %arg2, %arg3, %arg4, %arg5 : !llvm.ptr<i32, 3>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.i32
+  llvm.return
+}
+
+llvm.func @gpu_wmma_mma_op(%arg0: !llvm.vec<2 x half>, %arg1: !llvm.vec<2 x half>,
+                           %arg2: !llvm.vec<2 x half>, %arg3: !llvm.vec<2 x half>,
+                           %arg4: !llvm.vec<2 x half>, %arg5: !llvm.vec<2 x half>,
+                           %arg6: !llvm.vec<2 x half>, %arg7: !llvm.vec<2 x half>,
+                           %arg8: !llvm.vec<2 x half>, %arg9: !llvm.vec<2 x half>,
+                           %arg10: !llvm.vec<2 x half>, %arg11: !llvm.vec<2 x half>,
+                           %arg12: !llvm.vec<2 x half>, %arg13: !llvm.vec<2 x half>,
+                           %arg14: !llvm.vec<2 x half>, %arg15: !llvm.vec<2 x half>,
+                           %arg16: !llvm.vec<2 x half>, %arg17: !llvm.vec<2 x half>,
+                           %arg18: !llvm.vec<2 x half>, %arg19: !llvm.vec<2 x half>) {
+  // CHECK: call { <2 x half>, <2 x half>, <2 x half>, <2 x half> } @llvm.nvvm.wmma.m16n16k16.mma.row.row.f16.f16(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
+  %0 = nvvm.wmma.m16n16k16.mma %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19 : !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half> -> !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)>
+
+  llvm.return
+}
+
 // This function has the "kernel" attribute attached and should appear in the
 // NVVM annotations after conversion.
 llvm.func @kernel_func() attributes {gpu.kernel} {