diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -102,4 +102,42 @@
   let assemblyFormat = "attr-dict";
 }
 
+//===---------------------------------------------------------------------===//
+// Xdlops intrinsics
+
+// Base class for the MFMA (xdlops) intrinsic ops. Each op takes a variadic
+// operand list and yields one result; it is translated to the LLVM intrinsic
+// llvm.amdgcn.<mnemonic>, whose enum name is "amdgcn_" followed by the
+// mnemonic with '.' replaced by '_'.
+class ROCDL_Mfma_IntrOp<string mnemonic, list<OpTrait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".","_", mnemonic),
+                  [], [], traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
+def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">;
+def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">;
+def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">;
+def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">;
+def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">;
+def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">;
+def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">;
+def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">;
+def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">;
+def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">;
+def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">;
+def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">;
+def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">;
+def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">;
+def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">;
+def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">;
+def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">;
+def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">;
+def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">;
+def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">;
+
+
 #endif // ROCDLIR_OPS
diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir
--- a/mlir/test/Dialect/LLVMIR/rocdl.mlir
+++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir
@@ -34,3 +34,113 @@
   rocdl.barrier
   llvm.return
 }
+
+func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+                   %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+                   %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+                   %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+                   %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+                   %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+  // CHECK-LABEL: rocdl.xdlops
+  // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (!llvm.float, !llvm.float, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r4 = rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x4f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x8f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x16f16 {{.*}} : (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.i32.32x32x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+  %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+  // CHECK: rocdl.mfma.i32.16x16x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+  %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: rocdl.mfma.i32.4x4x4i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+  %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: rocdl.mfma.i32.32x32x8i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+  %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: rocdl.mfma.i32.16x16x16i8 {{.*}} : (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+  %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: rocdl.mfma.f32.32x32x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+  %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.4x4x2bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: rocdl.mfma.f32.32x32x4bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+  %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: rocdl.mfma.f32.16x16x8bf16 {{.*}} : (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">, !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+  %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  llvm.return %r0 : !llvm<"<32 x float>">
+}
diff --git a/mlir/test/Target/rocdl.mlir b/mlir/test/Target/rocdl.mlir
--- a/mlir/test/Target/rocdl.mlir
+++ b/mlir/test/Target/rocdl.mlir
@@ -41,3 +41,113 @@
   rocdl.barrier
   llvm.return
 }
+
+llvm.func @rocdl.xdlops(%arg0 : !llvm.float, %arg1 : !llvm.float,
+                        %arg2 : !llvm<"<32 x float>">, %arg3 : !llvm.i32,
+                        %arg4 : !llvm<"<16 x float>">, %arg5 : !llvm<"<4 x float>">,
+                        %arg6 : !llvm<"<4 x half>">, %arg7 : !llvm<"<32 x i32>">,
+                        %arg8 : !llvm<"<16 x i32>">, %arg9 : !llvm<"<4 x i32>">,
+                        %arg10 : !llvm<"<2 x i16>">) -> !llvm<"<32 x float>"> {
+  // CHECK-LABEL: rocdl.xdlops
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float %{{.*}}, float %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x1f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r1 = rocdl.mfma.f32.16x16x1f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x4f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x1f32(float %{{.*}}, float %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x2f32(float %{{.*}}, float %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r4 = rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm.float, !llvm.float, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r6 = rocdl.mfma.f32.16x16x4f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x4f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r7 = rocdl.mfma.f32.4x4x4f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r8 = rocdl.mfma.f32.32x32x8f16 %arg6, %arg6, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x16f16(<4 x half> %{{.*}}, <4 x half> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r9 = rocdl.mfma.f32.16x16x16f16 %arg6, %arg6, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<4 x half>">, !llvm<"<4 x half>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <32 x i32> @llvm.amdgcn.mfma.i32.32x32x4i8(i32 %{{.*}}, i32 %{{.*}}, <32 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r10 = rocdl.mfma.i32.32x32x4i8 %arg3, %arg3, %arg7, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<32 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x i32>">
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.16x16x4i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r11 = rocdl.mfma.i32.16x16x4i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.4x4x4i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r12 = rocdl.mfma.i32.4x4x4i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x8i8(i32 %{{.*}}, i32 %{{.*}}, <16 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r13 = rocdl.mfma.i32.32x32x8i8 %arg3, %arg3, %arg8, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<16 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x i32>">
+
+  // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x16i8(i32 %{{.*}}, i32 %{{.*}}, <4 x i32> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r14 = rocdl.mfma.i32.16x16x16i8 %arg3, %arg3, %arg9, %arg3, %arg3, %arg3 :
+      (!llvm.i32, !llvm.i32, !llvm<"<4 x i32>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x i32>">
+
+  // CHECK: call <32 x float> @llvm.amdgcn.mfma.f32.32x32x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <32 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r15 = rocdl.mfma.f32.32x32x2bf16 %arg10, %arg10, %arg2, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<32 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<32 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.16x16x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r16 = rocdl.mfma.f32.16x16x2bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.4x4x2bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r17 = rocdl.mfma.f32.4x4x2bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x4bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r18 = rocdl.mfma.f32.32x32x4bf16 %arg10, %arg10, %arg4, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<16 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<16 x float>">
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x8bf16(<2 x i16> %{{.*}}, <2 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+  %r19 = rocdl.mfma.f32.16x16x8bf16 %arg10, %arg10, %arg5, %arg3, %arg3, %arg3 :
+      (!llvm<"<2 x i16>">, !llvm<"<2 x i16>">, !llvm<"<4 x float>">,
+       !llvm.i32, !llvm.i32, !llvm.i32) -> !llvm<"<4 x float>">
+
+  llvm.return %r0 : !llvm<"<32 x float>">
+}