diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
--- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPU.td
@@ -10,6 +10,7 @@
 #define AMDGPU
 
 include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/EnumAttr.td"
 include "mlir/IR/OpBase.td"
 
 def AMDGPU_Dialect : Dialect {
@@ -163,4 +164,98 @@
   let hasVerifier = 1;
 }
 
+// Available MFMA intrinsics.
+// Keep up to date with llvm/include/llvm/IR/IntrinsicsAMDGPU.td
+// Generated by: perl -ne 'BEGIN { $i = 0; } if (/amdgcn_mfma_(\w+)\s*:\s*AMDGPUMfmaIntrinsic/) { print "I32EnumAttrCase<\"$1\", $i>,\n"; $i += 1; }' l
+def AMDGPU_MFMAInstr : I32EnumAttr<"MFMAInstr",
+    "Any of the possible MFMA instructions available on AMD GPUs.",
+    [
+      I32EnumAttrCase<"f32_32x32x1f32", 0>,
+      I32EnumAttrCase<"f32_16x16x1f32", 1>,
+      I32EnumAttrCase<"f32_4x4x1f32", 2>,
+      I32EnumAttrCase<"f32_32x32x2f32", 3>,
+      I32EnumAttrCase<"f32_16x16x4f32", 4>,
+      I32EnumAttrCase<"f32_32x32x4f16", 5>,
+      I32EnumAttrCase<"f32_16x16x4f16", 6>,
+      I32EnumAttrCase<"f32_4x4x4f16", 7>,
+      I32EnumAttrCase<"f32_32x32x8f16", 8>,
+      I32EnumAttrCase<"f32_16x16x16f16", 9>,
+      I32EnumAttrCase<"i32_32x32x4i8", 10>,
+      I32EnumAttrCase<"i32_16x16x4i8", 11>,
+      I32EnumAttrCase<"i32_4x4x4i8", 12>,
+      I32EnumAttrCase<"i32_32x32x8i8", 13>,
+      I32EnumAttrCase<"i32_16x16x16i8", 14>,
+      I32EnumAttrCase<"f32_32x32x2bf16", 15>,
+      I32EnumAttrCase<"f32_16x16x2bf16", 16>,
+      I32EnumAttrCase<"f32_4x4x2bf16", 17>,
+      I32EnumAttrCase<"f32_32x32x4bf16", 18>,
+      I32EnumAttrCase<"f32_16x16x8bf16", 19>,
+      I32EnumAttrCase<"f32_32x32x4bf16_1k", 20>,
+      I32EnumAttrCase<"f32_16x16x4bf16_1k", 21>,
+      I32EnumAttrCase<"f32_4x4x4bf16_1k", 22>,
+      I32EnumAttrCase<"f32_32x32x8bf16_1k", 23>,
+      I32EnumAttrCase<"f32_16x16x16bf16_1k", 24>,
+      I32EnumAttrCase<"f64_16x16x4f64", 25>,
+      I32EnumAttrCase<"f64_4x4x4f64", 26>,
+      I32EnumAttrCase<"i32_16x16x32_i8", 27>,
+      I32EnumAttrCase<"i32_32x32x16_i8", 28>,
+      I32EnumAttrCase<"f32_16x16x8_xf32", 29>,
+      I32EnumAttrCase<"f32_32x32x4_xf32", 30>
+    ]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::amdgpu";
+}
+
+def AMDGPU_MFMAInstrAttr : EnumAttr<AMDGPU_Dialect, AMDGPU_MFMAInstr, "mfma_instr">;
+
+// Operand and result types for amdgpu.mfma.
+def MFMAInTypes : AnyTypeOf<[F32, F64, I32, I64,
+                             VectorOfLengthAndType<[2], [F32]>,
+                             VectorOfLengthAndType<[4], [F16]>,
+                             VectorOfLengthAndType<[2, 4], [BF16]>,
+                             VectorOfLengthAndType<[4, 8], [I8]>]>;
+def MFMAOutTypes : AnyTypeOf<[F64,
+                              VectorOfLengthAndType<[4, 16, 32], [F32]>,
+                              VectorOfLengthAndType<[4, 16, 32], [I32]>,
+                              VectorOfLengthAndType<[4], [F64]>]>;
+
+def AMDGPU_MFMAOp :
+    AMDGPU_Op<"mfma", [AllTypesMatch<["sourceA", "sourceB"]>,
+                       AllTypesMatch<["destC", "destD"]>]>,
+    Arguments<(ins AMDGPU_MFMAInstrAttr:$instr,
+                   MFMAInTypes:$sourceA,
+                   MFMAInTypes:$sourceB,
+                   MFMAOutTypes:$destC,
+                   I32Attr:$cbsz,
+                   I32Attr:$abid,
+                   I32Attr:$blgp)>,
+    Results<(outs MFMAOutTypes: $destD)> {
+  let summary = "MLIR wrapper for CDNA mfma instructions";
+  let description = [{
+    The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
+    for various `mfma` instructions in the CDNA architecture, which perform
+    multiple outer products in order to allow fast matrix multiplication.
+
+    The `instr` enum specifies which mfma instruction to use, while the
+    `cbsz`, `abid`, and `blgp` attributes provide that instruction's
+    immediate arguments.
+
+    Note that this wrapper allows specifying a `vector<Kxi8>` argument to MFMA
+    intrinsics that take a single integer of width `8 * K`. 
For example, + one can provide a vector<4xi8> as an argument to an MFMA instruction that + logically takes 4 i8s but whose intrinsics are specified to take an i32. + In these cases, the bytes in the vector will be concatenated in little-endian + order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on). + + The `cbsz`, `abid`, and `blgp` attributes control broadcast and swizzling + during the computation. + }]; + let assemblyFormat = [{ + $instr attr-dict $sourceA `*` $sourceB `+` $destC + `cbsz` `=` $cbsz `abid` `=` $abid `blgp` `=` $blgp + `:` type($sourceA) `,` type($destC) + }]; + let hasVerifier = 1; +} + #endif // AMDGPU diff --git a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h --- a/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h +++ b/mlir/include/mlir/Dialect/AMDGPU/AMDGPUDialect.h @@ -14,13 +14,17 @@ #ifndef MLIR_DIALECT_AMDGPU_AMDGPUDIALECT_H_ #define MLIR_DIALECT_AMDGPU_AMDGPUDIALECT_H_ -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h.inc" +#include "mlir/Dialect/AMDGPU/AMDGPUEnums.h.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.h.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/AMDGPU.h.inc" diff --git a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt --- a/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/AMDGPU/CMakeLists.txt @@ -2,3 +2,11 @@ add_mlir_doc(AMDGPU AMDGPU Dialects/ -gen-dialect-doc) set(LLVM_TARGET_DEFINITIONS AMDGPU.td) +mlir_tablegen(AMDGPUEnums.h.inc -gen-enum-decls) +mlir_tablegen(AMDGPUEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRAMDGPUEnumsGen) + +set(LLVM_TARGET_DEFINITIONS AMDGPU.td) +mlir_tablegen(AMDGPUAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=amdgpu) +mlir_tablegen(AMDGPUAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=amdgpu) +add_public_tablegen_target(MLIRAMDGPUAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -122,26 +122,42 @@ "$args attr-dict `:` functional-type($args, $res)"; } +// Available on all CDNA. 
def ROCDL_mfma_f32_32x32x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x1f32">; +def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">; +def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">; def ROCDL_mfma_f32_32x32x2f32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2f32">; def ROCDL_mfma_f32_16x16x4f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f32">; -def ROCDL_mfma_f32_16x16x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x1f32">; def ROCDL_mfma_f32_32x32x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4f16">; -def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">; def ROCDL_mfma_f32_16x16x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4f16">; -def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">; -def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">; -def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">; -def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">; -def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">; -def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">; -def ROCDL_mfma_f32_4x4x1f32 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x1f32">; def ROCDL_mfma_f32_4x4x4f16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4f16">; +def ROCDL_mfma_f32_32x32x8f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8f16">; +def ROCDL_mfma_f32_16x16x16f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16f16">; def ROCDL_mfma_i32_32x32x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x4i8">; def ROCDL_mfma_i32_16x16x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x4i8">; def ROCDL_mfma_i32_4x4x4i8 : ROCDL_Mfma_IntrOp<"mfma.i32.4x4x4i8">; def ROCDL_mfma_i32_32x32x8i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x8i8">; def ROCDL_mfma_i32_16x16x16i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x16i8">; +def ROCDL_mfma_f32_32x32x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x2bf16">; +def ROCDL_mfma_f32_16x16x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x2bf16">; +def ROCDL_mfma_f32_4x4x2bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x2bf16">; +def ROCDL_mfma_f32_32x32x4bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16">; +def ROCDL_mfma_f32_16x16x8bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8bf16">; +// New in gfx90a. +def ROCDL_mfma_f32_32x32x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4bf16.1k">; +def ROCDL_mfma_f32_16x16x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x4bf16.1k">; +def ROCDL_mfma_f32_4x4x4bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.4x4x4bf16.1k">; +def ROCDL_mfma_f32_32x32x8bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x8bf16.1k">; +def ROCDL_mfma_f32_16x16x16bf16_1k : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x16bf16.1k">; +// Note: in gfx940, unlike in gfx90a, the f64 xdlops use the "blgp" argument as a +// NEG bitfield. See IntrinsicsAMDGPU.td for more info. +def ROCDL_mfma_f64_16x16x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.16x16x4f64">; +def ROCDL_mfma_f64_4x4x4f64 : ROCDL_Mfma_IntrOp<"mfma.f64.4x4x4f64">; +// New in gfx940. 
+def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
+def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
+def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
+def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
 
 //===---------------------------------------------------------------------===//
 // Vector buffer load/store intrinsics
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -11,15 +11,18 @@
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
+using namespace mlir::amdgpu;
 
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
   IntegerAttr valAttr = rewriter.getI32IntegerAttr(value);
   Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, valAttr);
+  return rewriter.createOrFold<LLVM::ConstantOp>(loc, llvmI32, valAttr);
 }
 
 namespace {
@@ -233,6 +236,105 @@
     return success();
   }
 };
+} // end anonymous namespace
+
+/// If `input` is a vector of bytes, concatenate those bytes in little-endian
+/// order to form a single integer of size 8 * [vector length]. This works
+/// around a wart in the AMDGPU intrinsics where operations that logically take
+/// vectors of bytes are declared to take integers instead. Since we do not
+/// want to expose this implementation detail to MLIR, we correct for it here.
+static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
+                                Location loc, Value input) {
+  Type inputType = input.getType();
+  if (auto vectorType = inputType.dyn_cast<VectorType>()) {
+    if (vectorType.getElementType() != rewriter.getI8Type())
+      return input;
+    int64_t numBytes = vectorType.getNumElements();
+    Type destType = rewriter.getIntegerType(numBytes * 8);
+    Value result = rewriter.createOrFold<LLVM::ConstantOp>(
+        loc, destType, rewriter.getIntegerAttr(destType, 0));
+    for (int64_t i = 0; i < numBytes; ++i) {
+      // OR each byte into the accumulator at its little-endian bit offset.
+      Value idxConst = createI32Constant(rewriter, loc, i);
+      Value element =
+          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
+      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
+      Value shiftConst = rewriter.createOrFold<LLVM::ConstantOp>(
+          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
+      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
+      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    }
+    return result;
+  }
+  return input;
+}
+
+/// Return the `rocdl` intrinsic corresponding to a `MFMAInstr` value.
+/// This conversion happens here to allow code up the stack to handle the
+/// choice of mfma by picking between enum variants, which is much more
+/// ergonomic than picking between ops, at the cost of some long switch
+/// statements in this pass.
+static StringRef mfmaInstrToIntrinsicName(MFMAInstr instr) {
+#define LOWERING_CASE(type)                                                    \
+  case MFMAInstr::type:                                                        \
+    return ROCDL::mfma_##type::getOperationName();
+  switch (instr) {
+    LOWERING_CASE(f32_32x32x1f32)
+    LOWERING_CASE(f32_16x16x1f32)
+    LOWERING_CASE(f32_4x4x1f32)
+    LOWERING_CASE(f32_32x32x2f32)
+    LOWERING_CASE(f32_16x16x4f32)
+    LOWERING_CASE(f32_32x32x4f16)
+    LOWERING_CASE(f32_16x16x4f16)
+    LOWERING_CASE(f32_4x4x4f16)
+    LOWERING_CASE(f32_32x32x8f16)
+    LOWERING_CASE(f32_16x16x16f16)
+    LOWERING_CASE(i32_32x32x4i8)
+    LOWERING_CASE(i32_16x16x4i8)
+    LOWERING_CASE(i32_4x4x4i8)
+    LOWERING_CASE(i32_32x32x8i8)
+    LOWERING_CASE(i32_16x16x16i8)
+    LOWERING_CASE(f32_32x32x2bf16)
+    LOWERING_CASE(f32_16x16x2bf16)
+    LOWERING_CASE(f32_4x4x2bf16)
+    LOWERING_CASE(f32_32x32x4bf16)
+    LOWERING_CASE(f32_16x16x8bf16)
+    LOWERING_CASE(f32_32x32x4bf16_1k)
+    LOWERING_CASE(f32_16x16x4bf16_1k)
+    LOWERING_CASE(f32_4x4x4bf16_1k)
+    LOWERING_CASE(f32_32x32x8bf16_1k)
+    LOWERING_CASE(f32_16x16x16bf16_1k)
+    LOWERING_CASE(f64_16x16x4f64)
+    LOWERING_CASE(f64_4x4x4f64)
+    LOWERING_CASE(i32_16x16x32_i8)
+    LOWERING_CASE(i32_32x32x16_i8)
+    LOWERING_CASE(f32_16x16x8_xf32)
+    LOWERING_CASE(f32_32x32x4_xf32)
+  }
+#undef LOWERING_CASE
+  llvm_unreachable("unknown MFMA instruction");
+}
+
+namespace {
+struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
+  using ConvertOpToLLVMPattern<MFMAOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(MFMAOp op, MFMAOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Type outType = typeConverter->convertType(op.destD().getType());
+
+    OperationState loweredOp(loc, mfmaInstrToIntrinsicName(op.instr()));
+    loweredOp.addTypes(outType);
+    loweredOp.addOperands({mfmaConcatIfNeeded(rewriter, loc, adaptor.sourceA()),
+                           mfmaConcatIfNeeded(rewriter, loc, adaptor.sourceB()),
+                           adaptor.destC(),
+                           createI32Constant(rewriter, loc, op.cbsz()),
+                           createI32Constant(rewriter, loc, op.abid()),
+                           createI32Constant(rewriter, loc, op.blgp())});
+    Operation *lowered = rewriter.create(loweredOp);
+    rewriter.replaceOp(op, lowered->getResults());
+    return success();
+  }
+};
 
 struct ConvertAMDGPUToROCDLPass
     : public ConvertAMDGPUToROCDLBase<ConvertAMDGPUToROCDLPass> {
@@ -243,6 +345,7 @@
     LLVMTypeConverter converter(&getContext());
     populateAMDGPUToROCDLConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
+    target.addIllegalDialect<::mlir::amdgpu::AMDGPUDialect>();
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
     target.addLegalDialect<::mlir::ROCDL::ROCDLDialect>();
     if (failed(applyPartialConversion(getOperation(), target,
@@ -250,15 +353,15 @@
       signalPassFailure();
   }
 };
-} // namespace
+} // end anonymous namespace
 
 void mlir::populateAMDGPUToROCDLConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<
-      RawBufferOpLowering<amdgpu::RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
-      RawBufferOpLowering<amdgpu::RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
-      RawBufferOpLowering<amdgpu::RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>>(converter);
+      RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawBufferLoadOp>,
+      RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawBufferStoreOp>,
+      RawBufferOpLowering<RawBufferAtomicFaddOp, ROCDL::RawBufferAtomicFAddOp>,
+      MFMAOpLowering>(converter);
 }
 
 std::unique_ptr<Pass> mlir::createConvertAMDGPUToROCDLPass() {
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -11,19 +11,27 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/AMDGPU/AMDGPUDialect.h"
+
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/TypeUtilities.h"
+#include 
"llvm/ADT/TypeSwitch.h" using namespace mlir; +using namespace mlir::amdgpu; #include "mlir/Dialect/AMDGPU/AMDGPUDialect.cpp.inc" -void amdgpu::AMDGPUDialect::initialize() { +void AMDGPUDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc" + >(); } //===----------------------------------------------------------------------===// @@ -44,17 +52,145 @@ return success(); } -LogicalResult amdgpu::RawBufferLoadOp::verify() { - return verifyRawBufferOp(*this); -} +LogicalResult RawBufferLoadOp::verify() { return verifyRawBufferOp(*this); } -LogicalResult amdgpu::RawBufferStoreOp::verify() { +LogicalResult RawBufferStoreOp::verify() { return verifyRawBufferOp(*this); } + +LogicalResult RawBufferAtomicFaddOp::verify() { return verifyRawBufferOp(*this); } -LogicalResult amdgpu::RawBufferAtomicFaddOp::verify() { - return verifyRawBufferOp(*this); +//===----------------------------------------------------------------------===// +// MFMAOp +//===----------------------------------------------------------------------===// +LogicalResult MFMAOp::verify() { + Builder b(getOperation()); + StringRef instrName = stringifyMFMAInstr(instr()); + + Type inType = sourceA().getType(); + switch (instr()) { + case MFMAInstr::f32_32x32x1f32: + case MFMAInstr::f32_16x16x1f32: + case MFMAInstr::f32_4x4x1f32: + case MFMAInstr::f32_32x32x2f32: + case MFMAInstr::f32_16x16x4f32: + if (inType != b.getF32Type()) + return emitOpError(instrName + " requires f32 inputs"); + break; + case MFMAInstr::f32_32x32x4f16: + case MFMAInstr::f32_16x16x4f16: + case MFMAInstr::f32_4x4x4f16: + case MFMAInstr::f32_32x32x8f16: + case MFMAInstr::f32_16x16x16f16: + if (inType != VectorType::get(4, b.getF16Type())) + return emitOpError(instrName + " requires vector<4xf16> inputs"); + break; + case MFMAInstr::i32_32x32x4i8: + case MFMAInstr::i32_16x16x4i8: + case MFMAInstr::i32_4x4x4i8: + case MFMAInstr::i32_32x32x8i8: + case MFMAInstr::i32_16x16x16i8: + if (inType != b.getI32Type() && inType != VectorType::get(4, b.getI8Type())) + return emitOpError(instrName + " requires i32 or vector<4xi8> inputs"); + break; + case MFMAInstr::f32_32x32x2bf16: + case MFMAInstr::f32_16x16x2bf16: + case MFMAInstr::f32_4x4x2bf16: + case MFMAInstr::f32_32x32x4bf16: + case MFMAInstr::f32_16x16x8bf16: + if (inType != VectorType::get(2, b.getBF16Type())) + return emitOpError(instrName + " requires vector<2xbf16> inputs"); + break; + case MFMAInstr::f32_32x32x4bf16_1k: + case MFMAInstr::f32_16x16x4bf16_1k: + case MFMAInstr::f32_4x4x4bf16_1k: + case MFMAInstr::f32_32x32x8bf16_1k: + case MFMAInstr::f32_16x16x16bf16_1k: + if (inType != VectorType::get(4, b.getBF16Type())) + return emitOpError(instrName + " requires vector<4xbf16> inputs"); + break; + case MFMAInstr::f64_16x16x4f64: + case MFMAInstr::f64_4x4x4f64: + if (inType != b.getF64Type()) + return emitOpError(instrName + " requires f64 inputs"); + break; + case MFMAInstr::i32_16x16x32_i8: + case MFMAInstr::i32_32x32x16_i8: + if (inType != b.getI64Type() && inType != VectorType::get(8, b.getI8Type())) + return emitOpError(instrName + " requires i64 or vector<8xi8> inputs"); + break; + case MFMAInstr::f32_16x16x8_xf32: + case MFMAInstr::f32_32x32x4_xf32: + if (inType != VectorType::get(2, b.getF32Type())) + return emitOpError(instrName + " requires vector<2xf32> inputs"); + break; + } + + Type outType = destC().getType(); + switch (instr()) { + case 
MFMAInstr::f32_32x32x1f32: + case MFMAInstr::f32_32x32x4f16: + case MFMAInstr::f32_32x32x2bf16: + case MFMAInstr::f32_32x32x4bf16_1k: + if (outType != VectorType::get(32, b.getF32Type())) + return emitOpError(instrName + " must have vector<32xf32> outputs"); + break; + case MFMAInstr::f32_16x16x1f32: + case MFMAInstr::f32_32x32x2f32: + case MFMAInstr::f32_16x16x4f16: + case MFMAInstr::f32_32x32x8f16: + case MFMAInstr::f32_16x16x2bf16: + case MFMAInstr::f32_32x32x4bf16: + case MFMAInstr::f32_16x16x4bf16_1k: + case MFMAInstr::f32_32x32x8bf16_1k: + case MFMAInstr::f32_32x32x4_xf32: + if (outType != VectorType::get(16, b.getF32Type())) + return emitOpError(instrName + " must have vector<16xf32> outputs"); + break; + case MFMAInstr::f32_4x4x1f32: + case MFMAInstr::f32_16x16x4f32: + case MFMAInstr::f32_4x4x4f16: + case MFMAInstr::f32_16x16x16f16: + case MFMAInstr::f32_4x4x2bf16: + case MFMAInstr::f32_16x16x8bf16: + case MFMAInstr::f32_4x4x4bf16_1k: + case MFMAInstr::f32_16x16x16bf16_1k: + case MFMAInstr::f32_16x16x8_xf32: + if (outType != VectorType::get(4, b.getF32Type())) + return emitOpError(instrName + " must have vector<4xf32> outputs"); + break; + case MFMAInstr::i32_32x32x4i8: + + if (outType != VectorType::get(32, b.getI32Type())) + return emitOpError(instrName + " must have vector<32xi32> outputs"); + break; + case MFMAInstr::i32_16x16x4i8: + case MFMAInstr::i32_32x32x8i8: + case MFMAInstr::i32_32x32x16_i8: + if (outType != VectorType::get(16, b.getI32Type())) + return emitOpError(instrName + " must have vector<16xi32> outputs"); + break; + case MFMAInstr::i32_4x4x4i8: + case MFMAInstr::i32_16x16x16i8: + case MFMAInstr::i32_16x16x32_i8: + if (outType != VectorType::get(4, b.getI32Type())) + return emitOpError(instrName + " must have vector<4xi32> outputs"); + break; + case MFMAInstr::f64_16x16x4f64: + if (outType != VectorType::get(4, b.getF64Type())) + return emitOpError(instrName + " must have vector<4xf64> outputs"); + break; + case MFMAInstr::f64_4x4x4f64: + if (outType != b.getF64Type()) + return emitOpError(instrName + " must have f64 outputs"); + } + return success(); } +#include "mlir/Dialect/AMDGPU/AMDGPUEnums.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/AMDGPU/AMDGPUAttributes.cpp.inc" + #define GET_OP_CLASSES #include "mlir/Dialect/AMDGPU/AMDGPU.cpp.inc" diff --git a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt --- a/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/IR/CMakeLists.txt @@ -5,6 +5,8 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU DEPENDS + MLIRAMDGPUEnumsGen + MLIRAMDGPUAttributesIncGen MLIRAMDGPUIncGen LINK_LIBS PUBLIC diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir --- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir @@ -108,3 +108,76 @@ amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, targetIsRDNA = false} %value -> %buf[%idx] : f32 -> memref<64xf32>, i32 func.return } + +func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>, + %arg2 : vector<16xf32>, %arg3 : vector<4xf32>, + %arg4 : vector<4xf16>, %arg5 : vector<4xi8>, + %arg6 : vector<32xi32>, %arg7 : vector<16xi32>, + %arg8 : vector<4xi32>, %arg9 : vector<2xbf16>, + %arg10 : vector<4xbf16>, %arg11 : f64, + %arg12 : vector<4xf64>, %arg13 : vector<8xi8>, + %arg14 : vector<2xf32>) { + // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> 
vector<32xf32> + amdgpu.mfma f32_32x32x1f32 %arg0 * %arg0 + %arg1 cbsz = 0 abid = 0 blgp = 0 : f32, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_16x16x1f32 %arg0 * %arg0 + %arg2 cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_4x4x1f32 %arg0 * %arg0 + %arg3 cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_32x32x2f32 %arg0 * %arg0 + %arg2 cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_16x16x4f32 %arg0 * %arg0 + %arg3 cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma f32_32x32x4f16 %arg4 * %arg4 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_16x16x4f16 %arg4 * %arg4 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_4x4x4f16 %arg4 * %arg4 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_32x32x8f16 %arg4 * %arg4 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_16x16x16f16 %arg4 * %arg4 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32> + amdgpu.mfma i32_32x32x4i8 %arg5 * %arg5 + %arg6 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<32xi32> + // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma i32_16x16x4i8 %arg5 * %arg5 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + // CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma i32_4x4x4i8 %arg5 * %arg5 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + // CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma i32_32x32x8i8 %arg5 * %arg5 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma i32_16x16x16i8 %arg5 * %arg5 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma f32_32x32x2bf16 %arg9 * %arg9 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_16x16x2bf16 %arg9 * 
%arg9 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_4x4x2bf16 %arg9 * %arg9 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_32x32x4bf16 %arg9 * %arg9 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xbf16>, vector<2xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_16x16x8bf16 %arg9 * %arg9 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + amdgpu.mfma f32_32x32x4bf16_1k %arg10 * %arg10 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<32xf32> + // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_16x16x4bf16_1k %arg10 * %arg10 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_4x4x4bf16_1k %arg10 * %arg10 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_32x32x8bf16_1k %arg10 * %arg10 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xbf16>, vector<4xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_16x16x16bf16_1k %arg10 * %arg10 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + // CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64> + amdgpu.mfma f64_16x16x4f64 %arg11 * %arg11 + %arg12 cbsz = 0 abid = 0 blgp = 0 : f64, vector<4xf64> + // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64 + amdgpu.mfma f64_4x4x4f64 %arg11 * %arg11 + %arg11 cbsz = 0 abid = 0 blgp = 0 : f64, f64 + // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + amdgpu.mfma i32_16x16x32_i8 %arg13 * %arg13 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<4xi32> + // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + amdgpu.mfma i32_32x32x16_i8 %arg13 * %arg13 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<16xi32> + // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + amdgpu.mfma f32_16x16x8_xf32 %arg14 * %arg14 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + amdgpu.mfma f32_32x32x4_xf32 %arg14 * %arg14 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<16xf32> + func.return +} diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir --- a/mlir/test/Dialect/AMDGPU/ops.mlir +++ b/mlir/test/Dialect/AMDGPU/ops.mlir @@ -59,3 +59,76 @@ amdgpu.raw_buffer_atomic_fadd {boundsCheck = true, 
indexOffset = 1 : i32, targetIsRDNA = false} %value -> %dst[%idx0, %idx1, %idx2, %idx3] sgprOffset %offset : f32 -> memref<128x64x32x16xf32>, i32, i32, i32, i32 func.return } + +// CHECK-LABEL: func @mfma +func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>, %arg2 : vector<16xf32>, + %arg3 : vector<4xf32>, %arg4 : vector<4xf16>, + %arg5 : vector<4xi8>, %arg6 : vector<32xi32>, + %arg7 : vector<16xi32>, %arg8 : vector<4xi32>, + %arg9 : vector<2xbf16>, %arg10 : vector<4xbf16>, %arg11 : f64, + %arg12 : vector<4xf64>, %arg13 : vector<8xi8>, + %arg14 : vector<2xf32>) { + // CHECK: amdgpu.mfma f32_32x32x1f32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f32, vector<32xf32> + amdgpu.mfma f32_32x32x1f32 %arg0 * %arg0 + %arg1 cbsz = 0 abid = 0 blgp = 0 : f32, vector<32xf32> + // CHECK: amdgpu.mfma f32_16x16x1f32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + amdgpu.mfma f32_16x16x1f32 %arg0 * %arg0 + %arg2 cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + // CHECK: amdgpu.mfma f32_4x4x1f32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + amdgpu.mfma f32_4x4x1f32 %arg0 * %arg0 + %arg3 cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x2f32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + amdgpu.mfma f32_32x32x2f32 %arg0 * %arg0 + %arg2 cbsz = 0 abid = 0 blgp = 0 : f32, vector<16xf32> + // CHECK: amdgpu.mfma f32_16x16x4f32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + amdgpu.mfma f32_16x16x4f32 %arg0 * %arg0 + %arg3 cbsz = 0 abid = 0 blgp = 0 : f32, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x4f16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<32xf32> + amdgpu.mfma f32_32x32x4f16 %arg4 * %arg4 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<32xf32> + // CHECK: amdgpu.mfma f32_16x16x4f16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + amdgpu.mfma f32_16x16x4f16 %arg4 * %arg4 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_4x4x4f16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + amdgpu.mfma f32_4x4x4f16 %arg4 * %arg4 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x8f16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + amdgpu.mfma f32_32x32x8f16 %arg4 * %arg4 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_16x16x16f16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + amdgpu.mfma f32_16x16x16f16 %arg4 * %arg4 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xf16>, vector<4xf32> + // CHECK: amdgpu.mfma i32_32x32x4i8 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<32xi32> + amdgpu.mfma i32_32x32x4i8 %arg5 * %arg5 + %arg6 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<32xi32> + // CHECK: amdgpu.mfma i32_16x16x4i8 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + amdgpu.mfma i32_16x16x4i8 %arg5 * %arg5 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + // CHECK: amdgpu.mfma i32_4x4x4i8 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + amdgpu.mfma i32_4x4x4i8 %arg5 * %arg5 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + // CHECK: amdgpu.mfma i32_32x32x8i8 %{{.*}} * 
%{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + amdgpu.mfma i32_32x32x8i8 %arg5 * %arg5 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<16xi32> + // CHECK: amdgpu.mfma i32_16x16x16i8 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + amdgpu.mfma i32_16x16x16i8 %arg5 * %arg5 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<4xi8>, vector<4xi32> + // CHECK: amdgpu.mfma f32_32x32x2bf16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<32xf32> + amdgpu.mfma f32_32x32x2bf16 %arg9 * %arg9 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<32xf32> + // CHECK: amdgpu.mfma f32_16x16x2bf16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + amdgpu.mfma f32_16x16x2bf16 %arg9 * %arg9 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_4x4x2bf16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + amdgpu.mfma f32_4x4x2bf16 %arg9 * %arg9 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x4bf16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + amdgpu.mfma f32_32x32x4bf16 %arg9 * %arg9 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_16x16x8bf16 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + amdgpu.mfma f32_16x16x8bf16 %arg9 * %arg9 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xbf16>, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x4bf16_1k %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<32xf32> + amdgpu.mfma f32_32x32x4bf16_1k %arg10 * %arg10 + %arg1 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<32xf32> + // CHECK: amdgpu.mfma f32_16x16x4bf16_1k %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + amdgpu.mfma f32_16x16x4bf16_1k %arg10 * %arg10 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_4x4x4bf16_1k %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + amdgpu.mfma f32_4x4x4bf16_1k %arg10 * %arg10 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x8bf16_1k %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + amdgpu.mfma f32_32x32x8bf16_1k %arg10 * %arg10 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<16xf32> + // CHECK: amdgpu.mfma f32_16x16x16bf16_1k %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + amdgpu.mfma f32_16x16x16bf16_1k %arg10 * %arg10 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<4xbf16>, vector<4xf32> + // CHECK: amdgpu.mfma f64_16x16x4f64 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f64, vector<4xf64> + amdgpu.mfma f64_16x16x4f64 %arg11 * %arg11 + %arg12 cbsz = 0 abid = 0 blgp = 0 : f64, vector<4xf64> + // CHECK: amdgpu.mfma f64_4x4x4f64 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : f64, f64 + amdgpu.mfma f64_4x4x4f64 %arg11 * %arg11 + %arg11 cbsz = 0 abid = 0 blgp = 0 : f64, f64 + // CHECK: amdgpu.mfma i32_16x16x32_i8 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<4xi32> + amdgpu.mfma i32_16x16x32_i8 %arg13 * %arg13 + %arg8 cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<4xi32> + // CHECK: amdgpu.mfma i32_32x32x16_i8 %{{.*}} * %{{.*}} + 
%{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<16xi32> + amdgpu.mfma i32_32x32x16_i8 %arg13 * %arg13 + %arg7 cbsz = 0 abid = 0 blgp = 0 : vector<8xi8>, vector<16xi32> + // CHECK: amdgpu.mfma f32_16x16x8_xf32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<4xf32> + amdgpu.mfma f32_16x16x8_xf32 %arg14 * %arg14 + %arg3 cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<4xf32> + // CHECK: amdgpu.mfma f32_32x32x4_xf32 %{{.*}} * %{{.*}} + %{{.*}} cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<16xf32> + amdgpu.mfma f32_32x32x4_xf32 %arg14 * %arg14 + %arg2 cbsz = 0 abid = 0 blgp = 0 : vector<2xf32>, vector<16xf32> + func.return +} diff --git a/mlir/test/Dialect/LLVMIR/rocdl.mlir b/mlir/test/Dialect/LLVMIR/rocdl.mlir --- a/mlir/test/Dialect/LLVMIR/rocdl.mlir +++ b/mlir/test/Dialect/LLVMIR/rocdl.mlir @@ -40,7 +40,9 @@ %arg4 : vector<16xf32>, %arg5 : vector<4xf32>, %arg6 : vector<4xf16>, %arg7 : vector<32xi32>, %arg8 : vector<16xi32>, %arg9 : vector<4xi32>, - %arg10 : vector<2xi16>) -> vector<32xf32> { + %arg10 : vector<2xi16>, %arg11 : vector<4xi16>, + %arg12 : vector<4xf64>, %arg13 : f64, + %arg14 : i64, %arg15 : vector<2xf32>) { // CHECK-LABEL: rocdl.xdlops // CHECK: rocdl.mfma.f32.32x32x1f32 {{.*}} : (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32> %r0 = rocdl.mfma.f32.32x32x1f32 %arg0, %arg1, %arg2, %arg3, %arg3, %arg3 : @@ -52,21 +54,21 @@ (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - %r2 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : - (f32, f32, vector<4xf32>, - i32, i32, i32) -> vector<4xf32> - // CHECK: rocdl.mfma.f32.4x4x1f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - %r3 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : + %r2 = rocdl.mfma.f32.4x4x1f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> // CHECK: rocdl.mfma.f32.32x32x2f32 {{.*}} : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - %r4= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 : + %r3= rocdl.mfma.f32.32x32x2f32 %arg0, %arg1, %arg4, %arg3, %arg3, %arg3 : (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + // CHECK: rocdl.mfma.f32.16x16x4f32 {{.*}} : (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r4 = rocdl.mfma.f32.16x16x4f32 %arg0, %arg1, %arg5, %arg3, %arg3, %arg3 : + (f32, f32, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + // CHECK: rocdl.mfma.f32.32x32x4f16 {{.*}} : (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> %r5 = rocdl.mfma.f32.32x32x4f16 %arg6, %arg6, %arg2, %arg3, %arg3, %arg3 : (vector<4xf16>, vector<4xf16>, vector<32xf32>, @@ -142,7 +144,63 @@ (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - llvm.return %r0 : vector<32xf32> + + // CHECK: rocdl.mfma.f32.32x32x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> + %r20 = rocdl.mfma.f32.32x32x4bf16.1k %arg11, %arg11, %arg2, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<32xf32>, + i32, i32, i32) -> vector<32xf32> + + // CHECK: rocdl.mfma.f32.16x16x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r21 = rocdl.mfma.f32.16x16x4bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<16xf32>, + i32, i32, 
i32) -> vector<16xf32> + + // CHECK: rocdl.mfma.f32.4x4x4bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r22 = rocdl.mfma.f32.4x4x4bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f32.32x32x8bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r23 = rocdl.mfma.f32.32x32x8bf16.1k %arg11, %arg11, %arg4, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + // CHECK: rocdl.mfma.f32.16x16x16bf16.1k {{.*}} : (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r24 = rocdl.mfma.f32.16x16x16bf16.1k %arg11, %arg11, %arg5, %arg3, %arg3, %arg3 : + (vector<4xi16>, vector<4xi16>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f64.16x16x4f64 {{.*}} : (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64> + %r25 = rocdl.mfma.f64.16x16x4f64 %arg13, %arg13, %arg12, %arg3, %arg3, %arg3 : + (f64, f64, vector<4xf64>, + i32, i32, i32) -> vector<4xf64> + + // CHECK: rocdl.mfma.f64.4x4x4f64 {{.*}} : (f64, f64, f64, i32, i32, i32) -> f64 + %r26 = rocdl.mfma.f64.4x4x4f64 %arg13, %arg13, %arg13, %arg3, %arg3, %arg3 : + (f64, f64, f64, + i32, i32, i32) -> f64 + + // CHECK: rocdl.mfma.i32.16x16x32.i8 {{.*}} : (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32> + %r27 = rocdl.mfma.i32.16x16x32.i8 %arg14, %arg14, %arg9, %arg3, %arg3, %arg3 : + (i64, i64, vector<4xi32>, + i32, i32, i32) -> vector<4xi32> + + // CHECK: rocdl.mfma.i32.32x32x16.i8 {{.*}} : (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32> + %r28 = rocdl.mfma.i32.32x32x16.i8 %arg14, %arg14, %arg8, %arg3, %arg3, %arg3 : + (i64, i64, vector<16xi32>, + i32, i32, i32) -> vector<16xi32> + + // CHECK: rocdl.mfma.f32.16x16x8.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> + %r29 = rocdl.mfma.f32.16x16x8.xf32 %arg15, %arg15, %arg5, %arg3, %arg3, %arg3 : + (vector<2xf32>, vector<2xf32>, vector<4xf32>, + i32, i32, i32) -> vector<4xf32> + + // CHECK: rocdl.mfma.f32.32x32x4.xf32 {{.*}} : (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> + %r30 = rocdl.mfma.f32.32x32x4.xf32 %arg15, %arg15, %arg4, %arg3, %arg3, %arg3 : + (vector<2xf32>, vector<2xf32>, vector<16xf32>, + i32, i32, i32) -> vector<16xf32> + + llvm.return } llvm.func @rocdl.mubuf(%rsrc : vector<4xi32>, %vindex : i32,