diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -380,6 +380,10 @@
                        VectorOfLengthAndType<[4, 16, 32], [F32]>,
                        VectorOfLengthAndType<[4, 16, 32], [I32]>,
                        VectorOfLengthAndType<[4], [F64]>]>;
+// wmma
+def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
+def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
+                              VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;
 
 def AMDGPU_MFMAOp :
     AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
@@ -438,4 +442,41 @@
   let hasVerifier = 1;
 }
 
+def AMDGPU_WMMAOp :
+    AMDGPU_Op<"wmma", [AllTypesMatch<["destC", "destD"]>,
+                       AllTypesMatch<["sourceA", "sourceB"]>,
+                       Pure]>,
+    Arguments<(ins
+                   WMMAInTypes:$sourceA,
+                   WMMAInTypes:$sourceB,
+                   WMMAOutTypes:$destC,
+                   DefaultValuedAttr<ConfinedAttr<I32Attr, [IntMinValue<0>, IntMaxValue<1>]>, "0">:$subwordOffset,
+                   UnitAttr:$unsignedA,
+                   UnitAttr:$unsignedB,
+                   UnitAttr:$clamp)>,
+    Results<(outs WMMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for RDNA3 wmma instructions";
+  let description = [{
+    The `amdgpu.wmma` op is an MLIR wrapper around intrinsics
+    for various `wmma` instructions in the RDNA3 architecture, which perform
+    a 16x16 matrix multiplication for different data types.
+
+    When emitting f16->f16 (or bf16->bf16) wmma the output is a 16xf16 (or 16xbf16) vector
+    containing only 8 valid values:
+      - If `subwordOffset` is 0, then the output is stored at indices 0, 2, 4, ..., 14.
+      - If `subwordOffset` is 1, then the output is stored at indices 1, 3, 5, ..., 15.
+
+    `unsignedA` and `unsignedB` flag that the `int8` LLVM inputs are unsigned.
+
+    The `clamp` flag is used to saturate the output of type T to numeric_limits<T>::max()
+    in case of overflow.
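+
+    Example (a sketch; the operand names are illustrative, and the types
+    mirror the wave32 f16 -> f32 case exercised in the conversion tests):
+
+    ```mlir
+    %d = amdgpu.wmma %a * %b + %c
+      : vector<16xf16>, vector<16xf16>, vector<8xf32>
+    ```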
+  }];
+  let assemblyFormat = [{
+    $sourceA `*` $sourceB `+` $destC
+    attr-dict
+    `:` type($sourceA) `,` type($sourceB) `,` type($destC)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -124,6 +124,7 @@
   let assemblyFormat = "attr-dict";
 }
 
+
 //===---------------------------------------------------------------------===//
 // Xdlops intrinsics
@@ -182,6 +183,26 @@
 def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
 def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
 
+//===---------------------------------------------------------------------===//
+// WMMA intrinsics
+class ROCDL_Wmma_IntrOp<string mnemonic, list<Trait> traits = []> :
+  LLVM_IntrOpBase<ROCDL_Dialect, mnemonic,
+                  "amdgcn_" # !subst(".", "_", mnemonic),
+                  [], [], traits, 1>,
+  Arguments<(ins Variadic<LLVM_Type>:$args)> {
+  let assemblyFormat =
+    "$args attr-dict `:` functional-type($args, $res)";
+}
+
+// Available on RDNA3
+def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">;
+def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">;
+def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16">;
+def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16">;
+def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8">;
+def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4">;
+
+
 //===---------------------------------------------------------------------===//
 // Vector buffer load/store intrinsics
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -32,6 +32,12 @@
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
 }
 
+static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
+                              bool value) {
+  Type llvmI1 = rewriter.getI1Type();
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI1, value);
+}
+
 namespace {
 /// Define lowering patterns for raw buffer ops
 template <typename GpuOp, typename Intrinsic>
@@ -334,6 +340,64 @@
   return input;
 }
 
+/// Push an input operand. If it is a float type, nothing to do. If it is
+/// an integer type, then we need to also push its signedness (1 for signed, 0
+/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
+/// vector.
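+/// For example, a signed vector<16xi8> input becomes the operand pair
+/// (i1 true, vector<4xi32>), where the vector<4xi32> is a bitcast of the
+/// bytes and the i1 constant carries the signedness.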
+static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
+                                 Location loc, TypeConverter *typeConverter,
+                                 bool isUnsigned, Value llvmInput,
+                                 SmallVector<Value, 4> &operands) {
+  Type inputType = llvmInput.getType();
+  auto vectorType = inputType.dyn_cast<VectorType>();
+  Type elemType = vectorType.getElementType();
+
+  if (!elemType.isInteger(8)) {
+    operands.push_back(llvmInput);
+    return;
+  }
+
+  int64_t numBytes = vectorType.getNumElements();
+  Type i32 = rewriter.getI32Type();
+  VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
+  auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);
+
+  Value result = rewriter.createOrFold<LLVM::BitcastOp>(
+      loc, llvmVectorType32bits, llvmInput);
+
+  // If the element type is explicitly signed or unsigned (si8/ui8), it
+  // overrides the isUnsigned flag.
+  bool localIsUnsigned = isUnsigned;
+  if (elemType.isUnsignedInteger(8)) {
+    localIsUnsigned = true;
+  } else if (elemType.isSignedInteger(8)) {
+    localIsUnsigned = false;
+  }
+  Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
+  operands.push_back(sign);
+  operands.push_back(result);
+}
+
+/// Push the output operand. In most cases this simply appends the output to
+/// the operand list, but for the f16 -> f16 and bf16 -> bf16 intrinsics the
+/// result occupies the same number of VGPRs as the inputs, so we must decide
+/// whether the result lives in the upper or the lower 16 bits of each VGPR;
+/// `subwordOffset` selects the half (see the `amdgpu.wmma` description).
+static void wmmaPushOutputOperand(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypeConverter *typeConverter,
+                                  Value output, int32_t subwordOffset,
+                                  bool clamp, SmallVector<Value, 4> &operands) {
+  Type inputType = output.getType();
+  auto vectorType = inputType.dyn_cast<VectorType>();
+  Type elemType = vectorType.getElementType();
+  operands.push_back(output);
+  if (elemType.isF16() || elemType.isBF16()) {
+    operands.push_back(createI1Constant(rewriter, loc, subwordOffset));
+  } else if (elemType.isInteger(32)) {
+    operands.push_back(createI1Constant(rewriter, loc, clamp));
+  }
+}
+
 /// Return the `rocdl` intrinsic corresponding to a MFMA operation `mfma`
 /// if one exists. This includes checking to ensure the intrinsic is supported
 /// on the architecture you are compiling for.
@@ -471,6 +535,31 @@
   return std::nullopt;
 }
 
+/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
+/// if one exists. This includes checking to ensure the intrinsic is supported
+/// on the architecture you are compiling for.
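+/// For example, f16 sources with an f32 accumulator map to
+/// rocdl.wmma.f32.16x16x16.f16, and i8 sources with an i32 accumulator map
+/// to rocdl.wmma.i32.16x16x16.iu8.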
+static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
+                                                  Chipset chipset) {
+
+  auto sourceVectorType = wmma.getSourceA().getType().dyn_cast<VectorType>();
+  auto destVectorType = wmma.getDestC().getType().dyn_cast<VectorType>();
+  auto elemSourceType = sourceVectorType.getElementType();
+  auto elemDestType = destVectorType.getElementType();
+
+  if (elemSourceType.isF16() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
+  } else if (elemSourceType.isBF16() && elemDestType.isF32()) {
+    return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
+  } else if (elemSourceType.isF16() && elemDestType.isF16()) {
+    return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
+  } else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
+    return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
+  } else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
+    return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
+  }
+  return std::nullopt;
+}
+
 namespace {
 struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
   MFMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
       : ConvertOpToLLVMPattern<MFMAOp>(converter), chipset(chipset) {}
@@ -510,6 +599,45 @@
   }
 };
 
+struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
+  WMMAOpLowering(LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.getDestD().getType());
+
+    if (chipset.majorVersion != 11)
+      return op->emitOpError("WMMA only supported on gfx11");
+
+    std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
+
+    if (!maybeIntrinsic.has_value())
+      return op.emitOpError("no intrinsic matching WMMA on the given chipset");
+
+    OperationState loweredOp(loc, *maybeIntrinsic);
+    loweredOp.addTypes(outType);
+
+    SmallVector<Value, 4> operands;
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
+                         adaptor.getSourceA(), operands);
+    wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
+                         adaptor.getSourceB(), operands);
+    wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
+                          op.getSubwordOffset(), op.getClamp(), operands);
+
+    loweredOp.addOperands(operands);
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+
+    return success();
+  }
+};
+
 struct ConvertAMDGPUToROCDLPass
     : public impl::ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
   ConvertAMDGPUToROCDLPass() = default;
@@ -549,7 +677,7 @@
                RawBufferOpLowering<RawBufferAtomicSmaxOp, ROCDL::RawBufferAtomicSMaxOp>,
                RawBufferOpLowering<RawBufferAtomicUminOp, ROCDL::RawBufferAtomicUMinOp>,
-               MFMAOpLowering>(converter, chipset);
+               MFMAOpLowering, WMMAOpLowering>(converter, chipset);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -205,6 +205,34 @@
       context);
 }
 
+//===----------------------------------------------------------------------===//
+// WMMAOp
+//===----------------------------------------------------------------------===//
+LogicalResult WMMAOp::verify() {
+  Type sourceAType = getSourceA().getType();
+  Type destType = getDestC().getType();
+
+  VectorType sourceVectorAType = sourceAType.dyn_cast<VectorType>();
+  VectorType destVectorType = destType.dyn_cast<VectorType>();
+
+  Type sourceAElemType = sourceVectorAType.getElementType();
+  Type destElemType = destVectorType.getElementType();
+
+  bool isDestFloat =
+      (destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
+  bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
+
+  if (isDestFloat && !isSrcFloat) {
+    return emitOpError("Expected float sources with float destination");
+  }
+
+  if (!isDestFloat && isSrcFloat) {
+    return emitOpError("Expected int sources with int destination");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // MFMAOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s
+func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>,
+                         %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>,
+                         %arg6 : vector<16xi8>, %arg7 : vector<4xi32>, %arg8 : vector<8xi32>,
+                         %arg9 : vector<16xui8>) {
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xf32>) -> vector<8xf32>
+  amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32>
+  // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<4xf32>) -> vector<4xf32>
+  amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
+  // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+  amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<16xbf16>, i1) -> vector<16xbf16>
+  amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
+  // CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xbf16>, vector<16xbf16>, vector<8xbf16>, i1) -> vector<8xbf16>
+  amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+  amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>
+  // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+  amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<8xi32>
+  func.return
+}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -103,3 +103,11 @@
                            abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
   func.return %d : vector<32xf32>
 }
+
+// -----
+
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xi32>) -> vector<8xi32> {
+  // expected-error@+1 {{'amdgpu.wmma' op Expected int sources with int destination}}
+  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xi32>
+  func.return %0 : vector<8xi32>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -94,3 +94,10 @@
   %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
   func.return %0 : vector<32xf32>
 }
+
+// CHECK-LABEL: func @wmma
+func.func @wmma(%arg0 : vector<16xf16>, %arg1 : vector<8xf16>) -> vector<8xf16> {
+  // CHECK: amdgpu.wmma
+  %0 = amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf16>
+  func.return %0 : vector<8xf16>
+}
diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir
--- a/mlir/test/Target/LLVMIR/rocdl.mlir
+++ b/mlir/test/Target/LLVMIR/rocdl.mlir
@@ -215,6 +215,66 @@
   llvm.return %r0 : vector<32 x f32>
 }
 
+llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : vector<16 x i16>, %arg3 : vector<8 x i32>,
+                      %arg4 : vector<2xi32>, %arg5 : vector<4xi32>, %arg6 : vector<4xf32>, %arg7 : vector<8xf16>, %arg8 : vector<8xi16>) -> vector<8xf32> {
+  %zero = llvm.mlir.constant(false) : i1
+
+  // ---- Wave32 -----
+
+  // f16 -> f32
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v8f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x float> %{{.*}})
+  %r0 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg0 : (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+
+  // bf16 -> f32
+  // CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v8f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x float> %{{.*}})
+  %r1 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg0 : (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32>
+
+  // f16 -> f16 (OPSEL = {0,1})
+  // CHECK: call <16 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v16f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <16 x half> %{{.*}}, i1 {{.*}})
+  %r2 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg1, %zero : (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16>
+
+  // bf16 -> bf16 (OPSEL = {0,1})
+  // CHECK: call <16 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v16i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <16 x i16> %{{.*}}, i1 {{.*}})
+  %r4 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg2, %zero : (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
+
+  // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v8i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+  %r5 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg3, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+  // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <8 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v8i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <8 x i32> %{{.*}}, i1 {{.*}})
+  %r6 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg3, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32>
+
+  // ---- Wave64 -----
+
+  // f16 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.f16.v4f32(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <4 x float> %{{.*}})
+  %r7 = rocdl.wmma.f32.16x16x16.f16 %arg1, %arg1, %arg6 : (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32>
+
+  // bf16 -> f32
+  // CHECK: call <4 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf16.v4f32(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <4 x float> %{{.*}})
+  %r8 = rocdl.wmma.f32.16x16x16.bf16 %arg2, %arg2, %arg6 : (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32>
+
+  // f16 -> f16 (OPSEL = {0,1})
+  // CHECK: call <8 x half> @llvm.amdgcn.wmma.f16.16x16x16.f16.v8f16(<16 x half> %{{.*}}, <16 x half> %{{.*}}, <8 x half> %{{.*}}, i1 {{.*}})
+  %r9 = rocdl.wmma.f16.16x16x16.f16 %arg1, %arg1, %arg7, %zero : (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
+
+  // bf16 -> bf16 (OPSEL = {0,1})
+  // CHECK: call <8 x i16> @llvm.amdgcn.wmma.bf16.16x16x16.bf16.v8i16(<16 x i16> %{{.*}}, <16 x i16> %{{.*}}, <8 x i16> %{{.*}}, i1 {{.*}})
+  %r11 = rocdl.wmma.bf16.16x16x16.bf16 %arg2, %arg2, %arg8, %zero : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
+
+  // int8 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu8.v4i32(i1 {{.*}}, <4 x i32> %{{.*}}, i1 {{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+  %r12 = rocdl.wmma.i32.16x16x16.iu8 %zero, %arg5, %zero, %arg5, %arg5, %zero : (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+  // int4 -> int32 (signA = {0,1}, signB = {0,1}, clamp = {0,1})
+  // CHECK: call <4 x i32> @llvm.amdgcn.wmma.i32.16x16x16.iu4.v4i32(i1 {{.*}}, <2 x i32> %{{.*}}, i1 {{.*}}, <2 x i32> %{{.*}}, <4 x i32> %{{.*}}, i1 {{.*}})
+  %r13 = rocdl.wmma.i32.16x16x16.iu4 %zero, %arg4, %zero, %arg4, %arg5, %zero : (i1, vector<2xi32>, i1, vector<2xi32>, vector<4xi32>, i1) -> vector<4xi32>
+
+  llvm.return %r0 : vector<8xf32>
+}
+
 llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,
                        %offset : i32, %vdata1 : vector<1xf32>,
                        %vdata2 : vector<2xf32>, %vdata4 : vector<4xf32>) {