diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -124,26 +124,42 @@ "$args attr-dict `:` functional-type($args, $res)"; } +// Available on all CDNA. def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">; +def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">; +def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">; def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">; def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">; -def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">; def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">; -def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">; def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">; -def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">; -def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">; -def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">; -def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">; -def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">; -def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">; -def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">; def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">; +def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">; +def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">; def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">; def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">; def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">; def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">; def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">; +def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">; +def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">; +def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">; +def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">; +def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">; +// New in gfx90a. +def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k">; +def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k">; +def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k">; +def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k">; +def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k">; +// Note: in gfx940, unlike in gfx90a, the f64 xdlops use the "blgp" argument as a +// NEG bitfield. See IntrinsicsAMDGPU.td for more info. +def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64">; +def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64">; +// New in gfx940. +def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">; +def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">; +def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">; +def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">; //===---------------------------------------------------------------------===// // Vector buffer load/store intrinsics diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -40,7 +40,9 @@ %arg4 : vector<16xf32>, %arg5 : vector<4xf32>, %arg6 : vector<4xf16>, %arg7 : vector<32xi32>, %arg8 : vector<16xi32>, %arg9 : vector<4xi32>, - %arg10 : vector<2xi16>) -> vector<32xf32> { + %arg10 : vector<2xi16>, %arg11 : vector<4xi16>, + %arg12 : vector<4xf64>, %arg13 : f64, + %arg14 : i64, %arg15 : vector<2xf32>) { // CHECK-LABEL: rocdl.xdlops // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32> %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 : @@ -52,21 +54,21 @@ (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : - (f32, f32, vector<4xf32>, - i32, i32, i32) -> vector<4xf32> - // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : + %r2 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 : + %r3= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r4 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : + (f32, f32, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 : (vector<4xf16>, vector<4xf16>, vector<32xf32>, @@ -142,7 +144,63 @@ (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - llvm.return %r0 : vector<32xf32> + + // CHECK: rocdl.mfma.f32.32x32x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + %r20 = rocdl.mfma.f32.32x32x4bf16.1k %arg11, %arg11, %arg2, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<32xf32>, + i32, i32, i32) -> vector<32xf32> + + // CHECK: rocdl.mfma.f32.16x16x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r21 = rocdl.mfma.f32.16x16x4bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.mfma.f32.4x4x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r22 = rocdl.mfma.f32.4x4x4bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r23 = rocdl.mfma.f32.32x32x8bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r24 = rocdl.mfma.f32.16x16x16bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f64.16x16x4f64 {{.*}} : (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64> + %r25 = rocdl.mfma.f64.16x16x4f64 %arg13, %arg13, %arg12, %arg3, %arg3, %arg3 : + (f64, f64, vector<4xf64>, + i32, i32, i32) -> vector<4xf64> + + // CHECK: rocdl.mfma.f64.4x4x4f64 {{.*}} : (f64, f64, f64, i32, i32, i32) -> f64 + %r26 = rocdl.mfma.f64.4x4x4f64 %arg13, %arg13, %arg13, %arg3, %arg3, %arg3 : + (f64, f64, f64, + i32, i32, i32) -> f64 + + // CHECK: rocdl.mfma.i32.16x16x32.i8 {{.*}} : (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + %r27 = rocdl.mfma.i32.16x16x32.i8 %arg14, %arg14, %arg9, %arg3, %arg3, %arg3 : + (i64, i64, vector<4xi32>, + i32, i32, i32) -> vector<4xi32> + + // CHECK: rocdl.mfma.i32.32x32x16.i8 {{.*}} : (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + %r28 = rocdl.mfma.i32.32x32x16.i8 %arg14, %arg14, %arg8, %arg3, %arg3, %arg3 : + (i64, i64, vector<16xi32>, + i32, i32, i32) -> vector<16xi32> + + // CHECK: rocdl.mfma.f32.16x16x8.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r29 = rocdl.mfma.f32.16x16x8.xf32 %arg15, %arg15, %arg5, %arg3, %arg3, %arg3 : + (vector<2xf32>, vector<2xf32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f32.32x32x4.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r30 = rocdl.mfma.f32.32x32x4.xf32 %arg15, %arg15, %arg4, %arg3, %arg3, %arg3 : + (vector<2xf32>, vector<2xf32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + llvm.return } llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,