diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -47,10 +47,11 @@
                    DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
                    OptionalAttr<I32Attr>:$indexOffset,
                    Optional<I32>:$sgprOffset)>,
-  Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8,
+  Results<(outs AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
                            VectorOfLengthAndType<[2, 4], [F32, I32]>,
                            VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
-                           VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value)> {
+                           VectorOfLengthAndType<[2, 4, 8, 16],
+                                                 [I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value)> {
   let summary = "Raw Buffer load, exposing GCN features";
   let description = [{
@@ -96,10 +97,11 @@
 def AMDGPU_RawBufferStoreOp :
     AMDGPU_Op<"raw_buffer_store", [AllElementTypesMatch<["value", "memref"]>,
       AttrSizedOperandSegments]>,
-  Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8,
+  Arguments<(ins AnyTypeOf<[BF16, F16, F32, I32, I8, F8E5M2FNUZ, F8E4M3FNUZ,
                            VectorOfLengthAndType<[2, 4], [F32, I32]>,
                            VectorOfLengthAndType<[2, 4, 8], [F16, BF16]>,
-                           VectorOfLengthAndType<[2, 4, 8, 16], [I8]>]>:$value,
+                           VectorOfLengthAndType<[2, 4, 8, 16],
+                                                 [I8, F8E5M2FNUZ, F8E4M3FNUZ]>]>:$value,
                    Arg<AnyMemRef, "buffer to store to", [MemWrite]>:$memref,
                    Variadic<I32>:$indices,
                    DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
@@ -215,15 +217,15 @@
                        VectorOfLengthAndType<[2], [F32]>,
                        VectorOfLengthAndType<[4], [F16]>,
                        VectorOfLengthAndType<[2, 4], [BF16]>,
-                       VectorOfLengthAndType<[4, 8], [I8]>]>;
+                       VectorOfLengthAndType<[4, 8], [I8]>,
+                       VectorOfLengthAndType<[8], [F8E5M2FNUZ, F8E4M3FNUZ]>]>;
 def MFMAOutTypes : AnyTypeOf<[F64,
                    VectorOfLengthAndType<[4, 16, 32], [F32]>,
                    VectorOfLengthAndType<[4, 16, 32], [I32]>,
                    VectorOfLengthAndType<[4], [F64]>]>;
 
 def AMDGPU_MFMAOp :
-    AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
-                       AllTypesMatch<["destC", "destD"]>,
+    AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
                        Pure]>,
     Arguments<(ins
                    I32Attr:$m,
@@ -274,7 +276,7 @@
     $sourceA `*` $sourceB `+` $destC
     attr-dict
     `blgp` `=` $blgp
-    `:` type($sourceA) `,` type($destC)
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
   }];
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -172,6 +172,15 @@
 def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
 def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
 def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
+// fp8, only on gfx940
+def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
+def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
+def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;
+def ROCDL_mfma_f32_16x16x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.fp8">;
+def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.bf8">;
+def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
+def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
+def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
 
 //===---------------------------------------------------------------------===//
 // Vector buffer load/store intrinsics
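Illustration (not part of the patch; the same ops appear in the mfma.mlir test below): with the widened assembly format both source types are spelled out, so a mixed-type fp8 MFMA is written

  amdgpu.mfma %a * %b + %c { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>

and lowers to rocdl.mfma.f32.16x16x32.bf8.fp8 on gfx940; in the intrinsic names, "bf8" corresponds to f8E5M2FNUZ and "fp8" to f8E4M3FNUZ.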
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -404,6 +404,45 @@
     if (m == 4 && n == 4 && k == 4 && b == 4)
       return ROCDL::mfma_f64_4x4x4f64::getOperationName();
   }
+
+  if (sourceElem.isFloat8E5M2FNUZ() && destElem.isF32() &&
+      chipset.minorVersion >= 0x40) {
+    // Known to be correct because there are no scalar f8 instructions and
+    // because a length mismatch will have been caught by the verifier.
+    Type sourceBElem =
+        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+    if (m == 16 && n == 16 && k == 32 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 16 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName();
+    }
+  }
+
+  if (sourceElem.isFloat8E4M3FNUZ() && destElem.isF32() &&
+      chipset.minorVersion >= 0x40) {
+    Type sourceBElem =
+        mfma.getSourceB().getType().cast<VectorType>().getElementType();
+    if (m == 16 && n == 16 && k == 32 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName();
+    }
+    if (m == 32 && n == 32 && k == 16 && b == 1) {
+      if (sourceBElem.isFloat8E5M2FNUZ())
+        return ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName();
+      if (sourceBElem.isFloat8E4M3FNUZ())
+        return ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName();
+    }
+  }
+
   return std::nullopt;
 }
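Aside on the guard above (an inference from Chipset's parsing, which reads the trailing stepping digits as hex; not stated in the patch itself):

  gfx90a  ->  Chipset{majorVersion = 9, minorVersion = 0x0a}
  gfx940  ->  Chipset{majorVersion = 9, minorVersion = 0x40}

so `chipset.minorVersion >= 0x40` reads as "gfx940 or a later gfx9 stepping", matching the "fp8, only on gfx940" comment on the ROCDL intrinsic definitions.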
@@ -475,6 +514,14 @@
 void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
                                                    RewritePatternSet &patterns,
                                                    Chipset chipset) {
+  // ROCDL supports fp8 types in some contexts, but there is no LLVM-level f8
+  // type. Therefore, for this target, declare f8 to be equal to i8.
+  converter.addConversion([](FloatType type) -> std::optional<Type> {
+    if (type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ())
+      return IntegerType::get(type.getContext(), 8);
+    return std::nullopt;
+  });
+
   patterns.add<LDSBarrierOpLowering>(converter);
   patterns.add<
       RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -189,6 +189,24 @@
     destElem = destVector.getElementType();
   }
 
+  Type sourceBType = getSourceB().getType();
+  if (sourceElem.isFloat8E5M2FNUZ() || sourceElem.isFloat8E4M3FNUZ()) {
+    int64_t sourceBLen = 1;
+    Type sourceBElem = sourceBType;
+    if (auto sourceBVector = sourceBType.dyn_cast<VectorType>()) {
+      sourceBLen = sourceBVector.getNumElements();
+      sourceBElem = sourceBVector.getElementType();
+    }
+    if (!sourceBElem.isFloat8E5M2FNUZ() && !sourceBElem.isFloat8E4M3FNUZ())
+      return emitOpError("expected both source operands to have f8 elements");
+    if (sourceLen != sourceBLen)
+      return emitOpError(
+          "expected both f8 source vectors to have the same length");
+  } else {
+    if (sourceType != sourceBType)
+      return emitOpError(
+          "expected both non-f8 source operand types to match exactly");
+  }
   // Normalize the wider integer types the compiler expects to i8
   if (sourceElem.isInteger(32)) {
     sourceLen *= 4;
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -49,6 +49,7 @@
   %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi8>, i32 -> i8
   func.return %0 : i8
 }
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
 func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vector<2xi8> {
   // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
@@ -69,6 +70,29 @@
   func.return %0 : vector<16xi8>
 }
 
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ
+func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx: i32) -> f8E5M2FNUZ {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: llvm.insertelement{{.*}}%[[numRecords]]
+  // CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i8
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[loaded]] : i8 to f8E5M2FNUZ
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E5M2FNUZ>, i32 -> f8E5M2FNUZ
+  func.return %0 : f8E5M2FNUZ
+}
+
+// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ
+func.func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ(%buf: memref<64xf8E4M3FNUZ>, %idx: i32) -> vector<4xf8E4M3FNUZ> {
+  // CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
+  // CHECK: llvm.insertelement{{.*}}%[[numRecords]]
+  // CHECK: %[[loaded:.*]] = rocdl.raw.buffer.load %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32
+  // CHECK: %[[cast:.*]] = llvm.bitcast %[[loaded]] : i32 to vector<4xi8>
+  // CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[cast]] : vector<4xi8> to vector<4xf8E4M3FNUZ>
+  // CHECK: return %[[ret]]
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xf8E4M3FNUZ>, i32 -> vector<4xf8E4M3FNUZ>
+  func.return %0 : vector<4xf8E4M3FNUZ>
+}
+
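Sketch (hypothetical; f8 stores are not exercised by this patch, but the store path shares the load lowering, so the analogous form would be expected to work):

  amdgpu.raw_buffer_store {boundsCheck = true} %val -> %buf[%idx] : f8E5M2FNUZ -> memref<64xf8E5M2FNUZ>, i32

with %val first rewritten to i8 by the type converter added above.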
 // Since the lowering logic is shared with loads, only bitcasts need to be rechecked
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_store_i32
 func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -6,68 +6,86 @@
                    %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>,
                    %arg10 : vector<4xbf16>, %arg11 : f64,
                    %arg12 : vector<4xf64>, %arg13 : vector<8xi8>,
-                   %arg14 : vector<2xf32>) {
+                   %arg14 : vector<2xf32>, %arg15 : vector<8xf8E5M2FNUZ>,
+                   %arg16 : vector<8xf8E4M3FNUZ>) {
   // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, vector<32xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, vector<16xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, vector<4xf32>
+  amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<32xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<16xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf32>
+  amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
   // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<32xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
   // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<16xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi32>
+  amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<32xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<16xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<4xf32>
+  amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<32xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<16xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xf32>
+  amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
   // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
-  amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, vector<4xf64>
+  amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>
   // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
-  amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64
+  amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
   // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
-  amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<4xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
-  amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<16xi32>
+  amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
-  amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<4xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
-  amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<16xf32>
+  amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.16x16x32.fp8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  amdgpu.mfma %arg16 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg15 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.bf8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg15 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.fp8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg16 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+  // CHECK: rocdl.mfma.f32.32x32x16.fp8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
+  amdgpu.mfma %arg16 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+
   func.return
 }
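Note on the (i64, i64, ...) signatures in the CHECK lines above (an inference, mirroring the existing vector<8xi8> cases): the gfx940 fp8 intrinsics take each 8-element f8 source vector packed into a single i64, so after the f8-to-i8 type conversion the lowering is expected to emit something like

  %packed = llvm.bitcast %src : vector<8xi8> to i64

before the intrinsic call.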
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -2,12 +2,34 @@
 
 // -----
 
+func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
+                            %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error@+1 {{'amdgpu.mfma' op expected both non-f8 source operand types to match exactly}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<4xf16>, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @bad_source_types_f8(%a: vector<8xf8E5M2FNUZ>, %b: vector<8xi8>,
+                               %c: vector<32xf32>) -> vector<32xf32> {
+  // expected-error@+1 {{'amdgpu.mfma' op expected both source operands to have f8 elements}}
+  %d = amdgpu.mfma %a * %b + %c {
+    m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xf8E5M2FNUZ>, vector<8xi8>, vector<32xf32>
+  func.return %d : vector<32xf32>
+}
+
+// -----
+
 func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
                                 %c: vector<32xf32>) -> vector<32xf32> {
   // expected-error@+1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<32xf32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<2xf32>, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -18,7 +40,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<4xi32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
   func.return %d : vector<4xi32>
 }
 
@@ -28,7 +50,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, vector<16xf32>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<16xf32>
   return %d : vector<16xf32>
 }
 
@@ -38,7 +60,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
-    abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, vector<4xf64>
+    abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, f64, vector<4xf64>
   return %d : vector<4xf64>
 }
 
@@ -48,7 +70,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
-    abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, vector<4xf64>
+    abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, f64, vector<4xf64>
   return %d : vector<4xf64>
 }
 
@@ -58,7 +80,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, vector<32xf32>
+    abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -68,7 +90,7 @@
   // expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, vector<32xf32>
+    abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
 
@@ -78,6 +100,6 @@
   // expected-error@+1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}}
   %d = amdgpu.mfma %a * %b + %c {
     m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
-    abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, vector<32xf32>
+    abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -70,6 +70,6 @@
 // CHECK-LABEL: func @mfma
 func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
   // CHECK: amdgpu.mfma
-  %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, vector<32xf32>
+  %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
   func.return %0 : vector<32xf32>
 }
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -69,7 +69,7 @@
                    %arg4 : vector<16 x f32>, %arg5 : vector<4xf32>,
                    %arg6 : vector<4xf16>, %arg7 : vector<32 x i32>,
                    %arg8 : vector<16 x i32>, %arg9 : vector<4xi32>,
-                   %arg10 : vector<2xi16>) -> vector<32 x f32> {
+                   %arg10 : vector<2xi16>, %arg11 : i64) -> vector<32 x f32> {
   %csti32 = llvm.mlir.constant(42 : i32) : i32
 
   // CHECK-LABEL: rocdl.xdlops
@@ -173,6 +173,45 @@
                             (vector<2xi16>, vector<2xi16>, vector<4xf32>,
                             i32, i32, i32) -> vector<4xf32>
 
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r20 = rocdl.mfma.f32.16x16x32.bf8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r21 = rocdl.mfma.f32.16x16x32.bf8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r22 = rocdl.mfma.f32.16x16x32.fp8.bf8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.fp8.fp8(i64 %{{.*}}, i64 %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r23 = rocdl.mfma.f32.16x16x32.fp8.fp8 %arg11, %arg11, %arg5, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<4xf32>,
+                            i32, i32, i32) -> vector<4xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r24 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf8.fp8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r25 = rocdl.mfma.f32.32x32x16.bf8.fp8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r26 = rocdl.mfma.f32.32x32x16.fp8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
+
+  // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.fp8(i64 %{{.*}}, i64 %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
+  %r27 = rocdl.mfma.f32.32x32x16.fp8.fp8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
+                            (i64, i64, vector<16xf32>,
+                            i32, i32, i32) -> vector<16xf32>
 
   llvm.return %r0 : vector<32 x f32>
 }