diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -17,18 +17,24 @@ #include "mlir/IR/Value.h" #include +namespace llvm { +class StringRef; +} // namespace llvm + namespace mlir { class LLVMTypeConverter; class ConversionTarget; class OpBuilder; class Location; class RewritePatternSet; +class Type; template class OperationPass; namespace gpu { class GPUModuleOp; +class MMAMatrixType; } // namespace gpu #define GEN_PASS_DECL_CONVERTGPUOPSTOROCDLOPS @@ -47,27 +53,50 @@ /// Generate ops to get the laneId of the current lane and return it. Value getLaneId(Location loc, OpBuilder b, unsigned indexBitwidth); + +/// Return the LLVM Type corresponding to the MMAMatrixType. +Type convertWMMAToROCDLLLVMType(gpu::MMAMatrixType matrixType); } // namespace amd /// Collect a set of patterns to convert from the GPU dialect to ROCDL. -/// If `runtime` is Unknown, gpu.printf will not be lowered -/// The resulting pattern set should be run over a gpu.module op -void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns, - gpu::amd::Runtime runtime); +/// If `runtime` is Unknown, gpu.printf will not be lowered. The resulting +/// pattern set should be run over a gpu.module op. `chipset` is the chip we are +/// targeting. `indexBitwidth` is the bitwidth to be used while converting index +/// types. `warpSize` is the warp size to use when generating WMMA intrinsics. +/// `opSelect` is used in the lowering of f16 versions of WMMA ops involving `C` +/// operand. If `opSelect` is true, the upper half of the general purpose 32-bit +/// registers is used for storing the values; if false, the lower half is used.
+void populateGpuToROCDLConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, + gpu::amd::Runtime runtime, llvm::StringRef chipset = "gfx900", + unsigned indexBitwidth = 32, bool opSelect = false, unsigned warpSize = 32); /// Configure target to convert from the GPU dialect to ROCDL. void configureGpuToROCDLConversionLegality(ConversionTarget &target); /// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The /// index bitwidth used for the lowering of the device side index computations -/// is configurable. +/// is configurable. AMD GPUs have a configurable warp size; valid choices are +/// 32 and 64. We choose 32 as the default size. `opSelect` is used in the +/// lowering of f16 versions of WMMA ops involving `C` operand. If `opSelect` is +/// true upper half of the general purpose 32-bit registers is used for storing +/// the values; If false the lower half is used. std::unique_ptr> createLowerGpuOpsToROCDLOpsPass( const std::string &chipset = "gfx900", unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout, bool useBarePtrCallConv = false, - gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown); + gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown, + bool opSelect = false, unsigned warpSize = 32); + +/// Collect a set of patterns to convert WMMA ops from GPU dialect to ROCDL. +/// `chipset` is the target chip for which the IR is being generated. +/// `warpSize` is the warp size to use when generating WMMA intrinsics.
+void populateGpuWMMAToROCDLConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, + llvm::StringRef chipset, + unsigned indexBitwidth, + bool opSelect, unsigned warpSize); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -488,23 +488,36 @@ /*default=*/"\"gfx000\"", "Chipset that these operations will run on">, Option<"indexBitwidth", "index-bitwidth", "unsigned", - /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + /*default=kDeriveIndexBitwidthFromDataLayout*/ "0", "Bitwidth of the index type, 0 to use size of machine word">, Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool", /*default=*/"false", "Replace memref arguments in GPU functions with bare pointers." "All memrefs must have static shape">, Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime", - "::mlir::gpu::amd::Runtime::Unknown", - "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)", - [{::llvm::cl::values( - clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", "Unknown (default)"), - clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"), - clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", "OpenCL") - )}]>, + "::mlir::gpu::amd::Runtime::Unknown", + "Runtime code will be run on (default is Unknown, can also use HIP " + "or OpenCl)", + [{::llvm::cl::values( + clEnumValN(::mlir::gpu::amd::Runtime::Unknown, "unknown", + "Unknown (default)"), + clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"), + clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL", + "OpenCL"))}]>, Option<"useOpaquePointers", "use-opaque-pointers", "bool", - /*default=*/"true", "Generate LLVM IR using opaque pointers " - "instead of typed pointers">, + /*default=*/"true", + "Generate LLVM IR using opaque pointers " + "instead of typed pointers">, + Option<"opSelect", "opselect",
"bool", + /*default=*/"false", + "`opSelect` is used in the lowering of f16 versions of WMMA ops " + "involving `C` operand. If `opSelect` is true, the upper half of the " + "general purpose 32-bit registers is used for storing the values; " + "If false the lower half is used.">, + Option<"warpSize", "warp-size", "unsigned", + /*default=*/"32", + "AMD GPUs have a configurable warp size; valid choices are 32 and " + "64. 32 is used as the default size.">, ]; } diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -61,5 +61,9 @@ add_mlir_dialect(ROCDLOps rocdl) add_mlir_doc(ROCDLOps ROCDLDialect Dialects/ -gen-dialect-doc -dialect=rocdl) set(LLVM_TARGET_DEFINITIONS ROCDLOps.td) +mlir_tablegen(ROCDLOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(ROCDLOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(ROCDLOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=rocdl) +mlir_tablegen(ROCDLOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=rocdl) mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRROCDLConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h @@ -28,6 +28,8 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Dialect/LLVMIR/ROCDLOpsEnums.h.inc" + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/ROCDLOps.h.inc" diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -14,6 +14,7 @@ #define ROCDLIR_OPS include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include
"mlir/IR/EnumAttr.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -212,6 +213,18 @@ "$args attr-dict `:` functional-type($args, $res)"; } +def ROCDLWMMAFragA : I32EnumAttrCase<"a", 0>; +def ROCDLWMMAFragB : I32EnumAttrCase<"b", 1>; +def ROCDLWMMAFragC : I32EnumAttrCase<"c", 2>; + +/// Enum attribute of the different frag types. +def ROCDLWMMAFrag + : I32EnumAttr<"ROCDLWMMAFrag", "ROCDL WMMA frag type", + [ROCDLWMMAFragA, ROCDLWMMAFragB, ROCDLWMMAFragC]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::ROCDL"; +} + // Available on RDNA3 def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16">; def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16">; diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms LowerGpuOpsToROCDLOps.cpp + WmmaOpsToROCDL.cpp DEPENDS MLIRConversionPassIncGen diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -24,13 +25,13 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include
"mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/Pass/Pass.h" @@ -108,8 +109,8 @@ : public impl::ConvertGpuOpsToROCDLOpsBase { LowerGpuOpsToROCDLOpsPass() = default; LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, - bool useBarePtrCallConv, - gpu::amd::Runtime runtime) { + bool useBarePtrCallConv, gpu::amd::Runtime runtime, + bool opSelect, unsigned warpSize) { if (this->chipset.getNumOccurrences() == 0) this->chipset = chipset; if (this->indexBitwidth.getNumOccurrences() == 0) @@ -118,6 +119,10 @@ this->useBarePtrCallConv = useBarePtrCallConv; if (this->runtime.getNumOccurrences() == 0) this->runtime = runtime; + if (this->warpSize.getNumOccurrences() == 0) + this->warpSize = warpSize; + if (this->opSelect.getNumOccurrences() == 0) + this->opSelect = opSelect; } void runOnOperation() override { @@ -192,7 +197,9 @@ cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); - populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime); + populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime, + this->chipset, this->indexBitwidth, + this->opSelect, this->warpSize); LLVMConversionTarget target(getContext()); configureGpuToROCDLConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) @@ -245,9 +252,15 @@ void mlir::populateGpuToROCDLConversionPatterns( LLVMTypeConverter &converter, RewritePatternSet &patterns, - mlir::gpu::amd::Runtime runtime) { +
mlir::gpu::amd::Runtime runtime, StringRef chipset, unsigned indexBitwidth, + bool opSelect, unsigned warpSize) { using mlir::gpu::amd::Runtime; + // Lowering for MMAMatrixType. + converter.addConversion([](gpu::MMAMatrixType type) -> Type { + return amd::convertWMMAToROCDLLLVMType(type); + }); + populateWithGenerated(patterns); patterns .add(converter); + // Collect the patterns that convert WMMA ops from the GPU dialect to ROCDL. + populateGpuWMMAToROCDLConversionPatterns(converter, patterns, chipset, + indexBitwidth, opSelect, warpSize); + populateOpPatterns(converter, patterns, "__ocml_fabs_f32", "__ocml_fabs_f64"); populateOpPatterns(converter, patterns, "__ocml_atan_f32", @@ -325,7 +342,8 @@ mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, bool useBarePtrCallConv, - gpu::amd::Runtime runtime) { + gpu::amd::Runtime runtime, bool opSelect, + unsigned warpSize) { return std::make_unique( - chipset, indexBitwidth, useBarePtrCallConv, runtime); + chipset, indexBitwidth, useBarePtrCallConv, runtime, opSelect, warpSize); } diff --git a/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp b/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUToROCDL/WmmaOpsToROCDL.cpp @@ -0,0 +1,515 @@ +//===--------- WmmaOpsToROCDL.cpp - GPU WMMA ops to ROCDL lowering --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions of patterns to lower GPU Subgroup MMA ops to +// ROCDL Dialect.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; + +namespace { + +/// Checks if all the operands of the op being lowered are of LLVM Types. The +/// types are expected to be converted by the `LLVMTypeConverter` before the op +/// is actually lowered. If the type of an operands is not already converted it +/// hints a missing typeConversion and failure is returned in that case. +static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter) { + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) { + return rewriter.notifyMatchFailure( + op, "cannot convert if operands aren't of LLVM type."); + } + + return success(); +} + +/// Return the WMMA operand corresponding to `operandName`. +static ROCDL::ROCDLWMMAFrag convertOperand(StringRef operandName) { + if (operandName.equals("AOp")) + return ROCDL::ROCDLWMMAFrag::a; + if (operandName.equals("BOp")) + return ROCDL::ROCDLWMMAFrag::b; + if (operandName.equals("COp")) + return ROCDL::ROCDLWMMAFrag::c; + llvm_unreachable("Unknown operand name"); +} + +/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting +/// from which values will be loaded. `laneId` lane ID of the thread loading the +/// values. `vecType` is the vector type of the values that will be loaded. The +/// loaded values are returned in `loadedValues`. 
The address for loading the +/// values is generated in the following manner: +/// +/// wrappedLaneId = laneId % 16 +/// for i in vectorSize { +/// loadedValues[i] = dataPtr + ((wrappedLaneId * leadingDim) + i); +/// } +static void generateAbLoadOpsVecFirst(Location loc, Value dataPtr, Value laneId, + Value leadingDim, VectorType vecType, + PatternRewriter &rewriter, + Value &loadedValues) { + // We wrap the laneId to 16 because of matrix replication in RDNA 3. + Value wrapSize = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + mlir::TypedAttr x; + Value wrappedLaneId = rewriter.create(loc, laneId, wrapSize); + loadedValues = rewriter.create(loc, vecType); + Value laneIdLdm = + rewriter.create(loc, wrappedLaneId, leadingDim); + for (unsigned i = 0; i < vecType.getNumElements(); ++i) { + Value iter = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/i); + Value curInx = rewriter.create(loc, laneIdLdm, iter); + Value curAddress = rewriter.create( + loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx); + // Load the value from the current index. + Value loaded = rewriter.create(loc, vecType.getElementType(), + curAddress); + loadedValues = rewriter.create( + loc, vecType, loadedValues, loaded, iter); + } +} + +/// Generate load ops for `AOp` or `BOp`. `dataPtr` is the base address starting +/// from which values will be loaded. `laneId` is the lane ID of the thread +/// loading the values. `vecType` is the vector type of the values that will be +/// loaded. The loaded values are returned in `loadedValues`. 
The address for +/// loading the values is generated in the following manner: +/// +/// wrappedLaneId = laneId % 16 +/// for i in vectorSize { +/// loadedValues[i] = dataPtr + ((i * leadingDim) + wrappedLaneId); +/// } +static void generateAbLoadOpsLaneFirst(Location loc, Value dataPtr, + Value laneId, Value leadingDim, + VectorType vecType, + PatternRewriter &rewriter, + Value &loadedValues) { + // We wrap the laneId to 16 because of matrix replication in RDNA 3. + Value wrapSize = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + Value wrappedLaneId = rewriter.create(loc, laneId, wrapSize); + loadedValues = rewriter.create(loc, vecType); + for (unsigned i = 0; i < vecType.getNumElements(); ++i) { + Value iter = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/i); + Value iterLdm = rewriter.create(loc, iter, leadingDim); + Value curInx = rewriter.create(loc, iterLdm, wrappedLaneId); + Value curAddress = rewriter.create( + loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx); + // Load the value from the current index. + Value loaded = rewriter.create(loc, vecType.getElementType(), + curAddress); + loadedValues = rewriter.create( + loc, vecType, loadedValues, loaded, iter); + } +} + +/// Generate load ops for `COp`. `dataPtr` is the base address starting +/// from which values will be loaded. `laneId` is the lane ID of the +/// thread loading the values. `vecType` is the vector type of the values that +/// will be loaded. The loaded values are returned in `loadedValues`. 
The +/// address for loading the values is generated in the following manner: +/// +/// wrappedLaneId = laneId % 16 +/// for i in vectorSize { +/// row = i * 2 + (laneId / 16) +/// if opSelect +/// loadedValues[i * 2 + 1] = dataPtr + ((row * leadingDim) + +/// wrappedLaneId); +/// else +/// loadedValues[i * 2] = dataPtr + ((row * leadingDim) + wrappedLaneId); +/// } +static void generateCLoadOpsLaneFirst(bool opSelect, Location loc, + Value dataPtr, Value laneId, + Value leadingDim, VectorType vecType, + PatternRewriter &rewriter, + Value &loadedValues) { + // We wrap the laneId to 16 because of matrix replication in RDNA 3. + Value wrapSize = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + Value wrappedLaneId = rewriter.create(loc, laneId, wrapSize); + loadedValues = rewriter.create(loc, vecType); + Value constTwo = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/2); + Value sixteen = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + Value laneIdHalf = rewriter.create(loc, laneId, sixteen); + for (unsigned i = 0; i < vecType.getNumElements(); ++i) { + Value iter = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/i); + Value iterTwo = rewriter.create(loc, iter, constTwo); + Value row = rewriter.create(loc, iterTwo, laneIdHalf); + Value rowLdm = rewriter.create(loc, row, leadingDim); + Value curInx = rewriter.create(loc, rowLdm, wrappedLaneId); + Value curAddress = rewriter.create( + loc, dataPtr.getType(), vecType.getElementType(), dataPtr, curInx); + // Load the value from the current index. + Value loaded = rewriter.create(loc, vecType.getElementType(), + curAddress); + // We have to skip every second element if opselect is true. 
+ Value inx = iter; + if (vecType.getElementType().isF16()) { + if (opSelect) { + Value constOne = + rewriter.create(loc, rewriter.getI32Type(), + /*value=*/1); + inx = rewriter.create(loc, iterTwo, constOne); + } else { + inx = iterTwo; + } + } + loadedValues = rewriter.create( + loc, vecType, loadedValues, loaded, inx); + } +} + +/// Generate load ops for `AOp`, `BOp`, or `COp`. `opSelect` is the opSelect bit +/// governing how to store/load half precision `COp` values. `transpose` tells +/// if the matrix has to be loaded in a transposed manner. `frag` is the type of +/// the WMMA operand being loaded. `dataPtr` is the base address starting from +/// which values will be loaded. `vecType` is the vector type of the values that +/// will be loaded. The loaded values are returned in `loadedValues`. +static LogicalResult generateLoadOps(bool opSelect, bool transpose, + Location loc, ROCDL::ROCDLWMMAFrag frag, + unsigned indexBitwidth, Value dataPtr, + Value leadingDim, VectorType vecType, + PatternRewriter &rewriter, + Value &loadedValues) { + Value laneId = amd::getLaneId(loc, rewriter, indexBitwidth); + Type eltType = vecType.getElementType(); + if (frag == ROCDL::ROCDLWMMAFrag::a && !transpose && eltType.isF16()) { + generateAbLoadOpsVecFirst(loc, dataPtr, laneId, leadingDim, vecType, + rewriter, loadedValues); + return success(); + } + if (frag == ROCDL::ROCDLWMMAFrag::a && transpose && eltType.isF16()) { + generateAbLoadOpsLaneFirst(loc, dataPtr, laneId, leadingDim, vecType, + rewriter, loadedValues); + return success(); + } + if (frag == ROCDL::ROCDLWMMAFrag::b && transpose && eltType.isF16()) { + generateAbLoadOpsVecFirst(loc, dataPtr, laneId, leadingDim, vecType, + rewriter, loadedValues); + return success(); + } + if (frag == ROCDL::ROCDLWMMAFrag::b && !transpose && eltType.isF16()) { + generateAbLoadOpsLaneFirst(loc, dataPtr, laneId, leadingDim, vecType, + rewriter, loadedValues); + return success(); + } + if (frag == ROCDL::ROCDLWMMAFrag::c && !transpose 
&& + (eltType.isF32() || eltType.isF16())) { + generateCLoadOpsLaneFirst(opSelect, loc, dataPtr, laneId, leadingDim, + vecType, rewriter, loadedValues); + return success(); + } + + return failure(); +} + +/// This class implements the conversion of GPU MMA loadOp to wmma.load op +/// in the ROCDL dialect. The conversion not only emits the ROCDL op but also +/// emits code that is necessary to store the data in the destination memref +/// after it has been loaded. +struct WmmaLoadOpToROCDLLowering + : public ConvertOpToLLVMPattern { + WmmaLoadOpToROCDLLowering(LLVMTypeConverter &typeConverter, StringRef chip, + unsigned indexBitwidth, bool opSelect, + unsigned warpSize) + : ConvertOpToLLVMPattern(typeConverter), + indexBitwidth(indexBitwidth), warpSize(warpSize), opSelect(opSelect), + chip(chip){}; + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(areAllLLVMTypes(subgroupMmaLoadMatrixOp.getOperation(), + adaptor.getOperands(), rewriter))) + return failure(); + + if (chip != "gfx1100") + return subgroupMmaLoadMatrixOp->emitError( + "wmma lowering is supported for gfx1100 only"); + + if (warpSize != amd::kWaveFrontSize32) + return subgroupMmaLoadMatrixOp->emitError( + "wavefront of size 32 only supported"); + + auto transpose = subgroupMmaLoadMatrixOp.getTranspose(); + gpu::MMAMatrixType retType = + subgroupMmaLoadMatrixOp.getRes().getType().cast(); + SmallVector retTypeShape(retType.getShape()); + + if (!llvm::all_of(retTypeShape, [](int dim) { return dim == 16; })) + return subgroupMmaLoadMatrixOp->emitError( + "wmma ops of shape 16x16x16 are only supported."); + + auto srcMemrefType = + subgroupMmaLoadMatrixOp.getSrcMemref().getType().cast(); + + if (srcMemrefType.getElementType() != retType.getElementType()) + return subgroupMmaLoadMatrixOp->emitError( + "src memref type and mma matrix element type must be same"); + + // Get the LLVM 
type of corresponding to the result MMAMatrixType. + Type llvmRetType = amd::convertWMMAToROCDLLLVMType(retType); + + // We need to declare a vector type and then emit instructions to load the + // elements into the vector type. + Location loc = subgroupMmaLoadMatrixOp.getLoc(); + Value dataPtr = + getStridedElementPtr(loc, srcMemrefType, adaptor.getSrcMemref(), + adaptor.getIndices(), rewriter); + + Value leadingDim = rewriter.create( + loc, rewriter.getI32Type(), + subgroupMmaLoadMatrixOp.getLeadDimensionAttr()); + + Value loadedValues; + ROCDL::ROCDLWMMAFrag operand = convertOperand(retType.getOperand()); + if (auto vecType = dyn_cast(llvmRetType)) { + if (failed(generateLoadOps(opSelect, + transpose.has_value() && transpose.value(), + loc, operand, indexBitwidth, dataPtr, + leadingDim, vecType, rewriter, loadedValues))) + return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp, + "unsupported load op variant."); + rewriter.replaceOp(subgroupMmaLoadMatrixOp, loadedValues); + return success(); + } + return rewriter.notifyMatchFailure(subgroupMmaLoadMatrixOp, + "unsupported load op variant."); + } + + /// Index bitwidth to use in any index calculation. + unsigned indexBitwidth; + + /// `warpSize` is the warp size to use when generating WMMA intrinsics. + unsigned warpSize; + + /// `opSelect` is used to decide whether to use lower half or upper half of + /// the 32-bit registers to use for storing half precision C operand. + bool opSelect; + + /// The target chip for which to generate the lowering. + std::string chip; +}; + +/// Generate store ops for `COp`. `dataPtr` is the base address starting +/// to which the values will be stored. `laneId` is the lane ID of the +/// thread loading the values. `vecType` is the vector type of the values that +/// are being stored. The values to be stored are supplied in `toStore`. 
The +/// address for storing the values is generated in the following manner: +/// +/// wrappedLaneId = laneId % 16 +/// for i in vectorSize { +/// row = i * 2 + (laneId / 16) +/// if opSelect +/// store toStore[i * 2 + 1], dataPtr + ((row * leadingDim) + wrappedLaneId) +/// else +/// store toStore[i * 2], dataPtr + ((row * leadingDim) + wrappedLaneId) +/// } +static void generateCStoreOpsLaneFirst(bool opSelect, Location loc, + Value dataPtr, Value laneId, + Value leadingDim, VectorType vecType, + Value toStore, + PatternRewriter &rewriter) { + // We wrap the laneId to 16 because of matrix replication in RDNA 3. + Value wrapSize = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + Value wrappedLaneId = rewriter.create(loc, laneId, wrapSize); + Value constSixteen = + rewriter.create(loc, rewriter.getI32Type(), + /*value=*/16); + Value constTwo = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/2); + Value laneIdHalf = rewriter.create(loc, laneId, constSixteen); + for (int i = 0; i < vecType.getNumElements(); ++i) { + Value inx = rewriter.create(loc, rewriter.getI32Type(), + /*value=*/i); + Value inxTimesTwo = rewriter.create(loc, inx, constTwo); + Value row = rewriter.create(loc, laneIdHalf, inxTimesTwo); + Value rowLdm = rewriter.create(loc, row, leadingDim); + Value offset = rewriter.create(loc, rowLdm, wrappedLaneId); + Value storeAddress = rewriter.create( + loc, dataPtr.getType(), vecType.getElementType(), dataPtr, offset); + Value toStoreAtInx; + if (vecType.getElementType().isF16()) { + if (!opSelect) { + toStoreAtInx = rewriter.create( + loc, vecType.getElementType(), toStore, inxTimesTwo); + + } else { + Value constOne = + rewriter.create(loc, rewriter.getI32Type(), + /*value=*/1); + Value inxTimesTwoAddOne = + rewriter.create(loc, inxTimesTwo, constOne); + toStoreAtInx = rewriter.create( + loc, vecType.getElementType(), toStore, inxTimesTwoAddOne); + } + } else if (vecType.getElementType().isF32()) { + toStoreAtInx = rewriter.create( + 
loc, vecType.getElementType(), toStore, inx); + } + rewriter.create(loc, toStoreAtInx, storeAddress); + } +} + +/// Generate store ops for `COp`. `opSelect` is the opSelect bit governing how +/// to store half precision `COp` values. `frag` is the type of the WMMA +/// operand being stored. `dataPtr` is the base address starting from which +/// starting from which the values will be stored. `vecType` is the vector type +/// of the values being stored. `toStore` contains the values to be stored. +static LogicalResult generateStoreOps(bool opSelect, Location loc, + ROCDL::ROCDLWMMAFrag frag, Value dataPtr, + unsigned indexBitwidth, Value leadingDim, + VectorType vecType, Value toStore, + PatternRewriter &rewriter) { + // Store ops can only be generated for C operands. + if (frag != ROCDL::ROCDLWMMAFrag::c) + return emitError(toStore.getLoc(), "only COp can be stored"); + + // Get the laneID. + Value laneId = amd::getLaneId(loc, rewriter, indexBitwidth); + Type eltType = vecType.getElementType(); + if (eltType.isF16() || eltType.isF32()) { + generateCStoreOpsLaneFirst(opSelect, loc, dataPtr, laneId, leadingDim, + vecType, toStore, rewriter); + return success(); + } + + return failure(); +} + +/// This class implements the conversion of GPU MMA storeOp to wmma.store op +/// in the ROCDL dialect. 
+struct WmmaStoreOpToROCDLowering + : public ConvertOpToLLVMPattern { + WmmaStoreOpToROCDLowering(LLVMTypeConverter &typeConverter, StringRef chip, + unsigned indexBitwidth, bool opSelect, + unsigned warpSize) + : ConvertOpToLLVMPattern(typeConverter), + indexBitwidth(indexBitwidth), warpSize(warpSize), opSelect(opSelect), + chip(chip) {} + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(areAllLLVMTypes(subgroupMmaStoreMatrixOp.getOperation(), + adaptor.getOperands(), rewriter))) + return failure(); + + if (chip != "gfx1100") + return subgroupMmaStoreMatrixOp->emitError( + "wmma lowering is supported for gfx1100 only"); + + if (warpSize != amd::kWaveFrontSize32) + return subgroupMmaStoreMatrixOp->emitError( + "wavefront of size 32 only supported"); + + Location loc = subgroupMmaStoreMatrixOp->getLoc(); + + auto transpose = subgroupMmaStoreMatrixOp.getTranspose(); + if (transpose.has_value() && transpose.value()) + return subgroupMmaStoreMatrixOp->emitError( + "lowering with transpose is not supported."); + + gpu::MMAMatrixType srcType = + subgroupMmaStoreMatrixOp.getSrc().getType().cast(); + ArrayRef<int64_t> srcTypeShape = srcType.getShape(); + + if (!llvm::all_of(srcTypeShape, [](int64_t dim) { return dim == 16; })) + return subgroupMmaStoreMatrixOp->emitError( + "wmma ops of shape 16x16x16 are only supported."); + + auto dstMemrefType = + subgroupMmaStoreMatrixOp.getDstMemref().getType().cast(); + + if (dstMemrefType.getElementType() != srcType.getElementType()) + return subgroupMmaStoreMatrixOp->emitError( + "dst memref type and mma matrix element type must be same"); + + Value dataPtr = getStridedElementPtr( + loc, + dstMemrefType, + adaptor.getDstMemref(), adaptor.getIndices(), rewriter); + Value leadingDim = rewriter.create( + loc, rewriter.getI32Type(),
subgroupMmaStoreMatrixOp.getLeadDimensionAttr()); + + // Get the LLVM type corresponding to the source MMAMatrixType. + Type llvmSrcType = amd::convertWMMAToROCDLLLVMType(srcType); + + Value toStore = adaptor.getSrc(); + + // Only erase the op once the stores that replace it have been emitted. + auto vecType = dyn_cast<VectorType>(llvmSrcType); + if (!vecType || failed(generateStoreOps( + opSelect, loc, convertOperand(srcType.getOperand()), dataPtr, + indexBitwidth, leadingDim, vecType, toStore, rewriter))) + return rewriter.notifyMatchFailure(subgroupMmaStoreMatrixOp, + "unsupported store op variant."); + rewriter.eraseOp(subgroupMmaStoreMatrixOp); + return success(); + } + + /// Index bitwidth to use in any index calculation. + unsigned indexBitwidth; + + /// `warpSize` is the warp size to use when generating WMMA intrinsics. + unsigned warpSize; + + /// `opSelect` is used to decide whether to use lower half or upper half of + /// the 32-bit registers to use for storing half precision C operand. + bool opSelect; + + /// The target chip for which to generate the lowering. + std::string chip; +}; +} // namespace + +// Convert the MMAMatrix type to an LLVM vector type based on the element type +// and operand kind of the MMAMatrixType.
+Type mlir::amd::convertWMMAToROCDLLLVMType( + mlir::gpu::MMAMatrixType matrixType) { + Type eltType = matrixType.getElementType(); + ROCDL::ROCDLWMMAFrag frag = convertOperand(matrixType.getOperand()); + if (eltType.isF16() && + (frag == ROCDL::ROCDLWMMAFrag::a || frag == ROCDL::ROCDLWMMAFrag::b || + frag == ROCDL::ROCDLWMMAFrag::c)) + return VectorType::get({16}, eltType); + if (eltType.isF32() && frag == ROCDL::ROCDLWMMAFrag::c) + return VectorType::get({8}, eltType); + + llvm_unreachable("Unsupported data type"); +} + +void mlir::populateGpuWMMAToROCDLConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns, StringRef chip, + unsigned indexBitwidth, bool opSelect, unsigned warpSize) { + patterns.add( + converter, chip, indexBitwidth, opSelect, warpSize); +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -35,6 +35,15 @@ option(MLIR_RUN_ARM_SVE_TESTS "Run Arm SVE tests.") option(MLIR_RUN_ARM_SME_TESTS "Run Arm SME tests.") + set(GFX_WMMA_TARGET "gfx1100") + + if (ROCM_TEST_CHIPSET STREQUAL ${GFX_WMMA_TARGET}) + message(STATUS "Enabling MLIR_RUN_ROCM_WMMA_TESTS") + set(MLIR_RUN_ROCM_WMMA_TESTS TRUE) + else () + message(STATUS "Disabling MLIR_RUN_ROCM_WMMA_TESTS") + set(MLIR_RUN_ROCM_WMMA_TESTS FALSE) + endif () # The native target may not be enabled when cross compiling, raise an error. 
if(NOT MLIR_ENABLE_EXECUTION_ENGINE) @@ -67,6 +76,7 @@ MLIR_INCLUDE_INTEGRATION_TESTS MLIR_RUN_AMX_TESTS MLIR_RUN_CUDA_TENSOR_CORE_TESTS + MLIR_RUN_ROCM_WMMA_TESTS MLIR_RUN_X86VECTOR_TESTS MLIR_RUN_ARM_SVE_TESTS MLIR_RUN_ARM_SME_TESTS diff --git a/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl-unsupported.mlir @@ -0,0 +1,181 @@ +// This file tests that we error out properly when unsupported ops are +// encountered for GPU wmma ops to ROCDL conversion. +// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx1100 index-bitwidth=32' -split-input-file -verify-diagnostics + +gpu.module @main { + // CHECK-LABEL: load_a_op_16_16_16_no_transpose_invalid_shape + func.func @load_a_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "AOp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "AOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_a_op_16_16_16_transpose_invalid_shape + func.func @load_a_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "AOp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed
to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "AOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_a_op_16_16_16_no_transpose_invalid_types + func.func @load_a_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + // expected-error@-1 {{src memref type and mma matrix element type must be same}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_b_op_16_16_16_no_transpose_invalid_shape + func.func @load_b_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "BOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "BOp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "BOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_b_op_16_16_16_transpose_invalid_shape + func.func @load_b_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "BOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> 
!gpu.mma_matrix<32x8xf16, "BOp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "BOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_b_op_16_16_16_no_transpose_invalid_types + func.func @load_b_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "BOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "BOp"> + // expected-error@-1 {{src memref type and mma matrix element type must be same}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<16x16xf16, "BOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_c_op_16_16_16_no_transpose_invalid_shape + func.func @load_c_op_16_16_16_no_transpose()->(!gpu.mma_matrix<32x8xf16, "COp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "COp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "COp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_c_op_16_16_16_transpose_invalid_shape + func.func @load_c_op_16_16_16_transpose()->(!gpu.mma_matrix<32x8xf16, "COp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = 
arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<32x8xf16, "COp"> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<32x8xf16, "COp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_c_op_16_16_16_no_transpose_invalid_types + func.func @load_c_op_16_16_16_no_transpose_invalid_types()->(!gpu.mma_matrix<16x16xf16, "COp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf16, "COp"> + // expected-error@-1 {{src memref type and mma matrix element type must be same}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_load_matrix' that was explicitly marked illegal}} + return %0 : !gpu.mma_matrix<16x16xf16, "COp"> + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: store_cop_f32 + func.func @store_cop_f32(%arg0: !gpu.mma_matrix<32x8xf32, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<32x8xf32, "COp">, memref<32x32xf32, 3> + // expected-error@-1 {{wmma ops of shape 16x16x16 are only supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}} + return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: store_cop_f32 + func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + 
%i = arith.constant 16 : index + %j = arith.constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index, transpose} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32, 3> + // expected-error@-1 {{lowering with transpose is not supported.}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}} + return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: store_cop_f32 + func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf16, 3> + // expected-error@-1 {{dst memref type and mma matrix element type must be same}} + // expected-error@-2 {{failed to legalize operation 'gpu.subgroup_mma_store_matrix' that was explicitly marked illegal}} + return + } +} diff --git a/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToROCDL/wmma-ops-to-rocdl.mlir @@ -0,0 +1,443 @@ +// This file tests the conversion of GPU WMMA ops to ROCDL dialect. 
+// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx1100 index-bitwidth=32' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-rocdl='chipset=gfx1100 index-bitwidth=32 opselect' -split-input-file | FileCheck %s --check-prefix=OPSEL + +gpu.module @main { + // CHECK-LABEL: load_a_op_16_16_16_no_transpose + func.func @load_a_op_16_16_16_no_transpose()->(!gpu.mma_matrix<16x16xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK-NEXT: %[[C0_I32:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[C_1_I32:.*]] = llvm.mlir.constant(-1 : i32) : i32 + // CHECK-NEXT: %[[MBCNT_LO:.*]] = rocdl.mbcnt.lo %[[C_1_I32]], %[[C0_I32]] : (i32, i32) -> i32 + // CHECK-NEXT: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %[[C_1_I32]], %[[MBCNT_LO]] : (i32, i32) -> i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // The part checked up to this point will be common in most of the WMMA op + // lowerings. Checking all of these lines will be skipped in the subsequent + // tests as the same utility emits the IR up to this point. Only some + // values which are used later will be matched. 
+ // CHECK-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> + // CHECK-NEXT: %[[WRAPPEDTID32:.*]] = llvm.mul %[[WRAPPEDTID]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C0]] : i32 + // CHECK-NEXT: %[[LOADADDR0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[LOADADDR0]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C1]] : i32 + // CHECK-NEXT: %[[LOADADDR1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[LOADADDR1]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16> + // We just check the loading and insertion of two values only, rest of the + // values need not be checked as they are emitted in a loop just with + // different parameters. 
+ // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: %[[OFFSET15:.*]] = llvm.add %[[WRAPPEDTID32]], %{{.*}} : i32 + // CHECK-NEXT: %[[LOADADDR15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADEDVALS15:.*]] = llvm.load %[[LOADADDR15]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[LOADEDVALS15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_a_op_16_16_16_transpose + func.func @load_a_op_16_16_16_transpose()->(!gpu.mma_matrix<16x16xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, {{.*}} : i32 + // CHECK-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.mul %[[C0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROW0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16> + // 
CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.mul %[[C1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROW1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16> + // We just check the loading and insertion of two values only, rest of the + // values need not be checked as they are emitted in a loop just with + // different parameters. + // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK-NEXT: %[[ROW15:.*]] = llvm.mul %[[C15]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET15:.*]] = llvm.add %[[ROW15]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED15:.*]] = llvm.load %[[ADDRESS15]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[LOADED15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_b_op_16_16_16_no_transpose + func.func @load_b_op_16_16_16_no_transpose()->(!gpu.mma_matrix<16x16xf16, "BOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // 
CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, {{.*}} : i32 + // CHECK: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.mul %[[C0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROW0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.mul %[[C1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROW1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16> + // We just check the loading and insertion of two values only, rest of the + // values need not be checked as they are emitted in a loop just with + // different parameters. 
+ // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK-NEXT: %[[ROW15:.*]] = llvm.mul %[[C15]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET15:.*]] = llvm.add %[[ROW15]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED15:.*]] = llvm.load %[[ADDRESS15]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[LOADED15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "BOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_b_op_16_16_16_transpose + func.func @load_b_op_16_16_16_transpose()->(!gpu.mma_matrix<16x16xf16, "BOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "BOp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WRAPPEDTID:.*]] = llvm.srem %{{.*}}, %{{.*}} : i32 + // CHECK-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> + // CHECK-NEXT: %[[WRAPPEDTID32:.*]] = llvm.mul %[[WRAPPEDTID]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C0]] : i32 + // CHECK-NEXT: %[[LOADADDR0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[LOADADDR0]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS0:.*]] = 
llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<16xf16> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[WRAPPEDTID32]], %[[C1]] : i32 + // CHECK-NEXT: %[[LOADADDR1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[LOADADDR1]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[C1]] : i32] : vector<16xf16> + // We just check the loading and insertion of two values only, rest of the + // values need not be checked as they are emitted in a loop just with + // different parameters. + // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: %[[OFFSET15:.*]] = llvm.add %[[WRAPPEDTID32]], %{{.*}} : i32 + // CHECK-NEXT: %[[LOADADDR15:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET15]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADEDVALS15:.*]] = llvm.load %[[LOADADDR15]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[LOADEDVALS15]], %{{.*}}[%[[C15]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "BOp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_c_op_16_16_16_no_opselect + func.func @load_c_op_16_16_16_no_opselect()->(!gpu.mma_matrix<16x16xf32, "COp">) { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg_1[%i, %j] {leadDimension = 32 : index} : memref<32x32xf32, 3> -> !gpu.mma_matrix<16x16xf32, "COp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, 
%{{.*}} : (i32, i32) -> i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<8xf32> + // CHECK-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]] : i32 + // CHECK-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f32 + // CHECK-NEXT: %[[LOADEDVAL0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[C0]] : i32] : vector<8xf32> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]] : i32 + // CHECK-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f32 + // CHECK-NEXT: %[[LOADEDVAL1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVAL0]][%[[C1]] : i32] : vector<8xf32> + // We just check the loading and insertion of two values only, rest of the + // values need not be checked as they are emitted in a loop just with + // different parameters. 
+ // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK: %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%[[C7]] : i32] : vector<8xf32> + // CHECK-NEXT: llvm.return %[[RES]] : vector<8xf32> + return %0 : !gpu.mma_matrix<16x16xf32, "COp"> + } +} + +// ----- + +gpu.module @main { + // CHECK-LABEL: load_c_op_16_16_16_no_opselect + func.func @load_c_op_16_16_16_no_opselect()->(!gpu.mma_matrix<16x16xf16, "COp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp"> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> + // CHECK-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]] : i32 + // CHECK-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: 
%[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[ITER0]] : i32] : vector<16xf16> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]] : i32 + // CHECK-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16 + // CHECK-NEXT: %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[ITER1]] : i32] : vector<16xf16> + // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : i32] : vector<16xf16> + // CHECK-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "COp"> + } +} + +// ----- + +gpu.module @main { + // OPSEL-LABEL: load_c_op_16_16_16_opselect + func.func @load_c_op_16_16_16_opselect()->(!gpu.mma_matrix<16x16xf16, "COp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "COp"> + // OPSEL: llvm.mlir.constant(32 : index) : i32 + // OPSEL: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // OPSEL: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // OPSEL-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // OPSEL-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // OPSEL-NEXT: %[[LOADEDVALS:.*]] = llvm.mlir.undef : vector<16xf16> 
+ // OPSEL-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // OPSEL-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // OPSEL-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // OPSEL-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // OPSEL-NEXT: %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // OPSEL-NEXT: %[[ROW0:.*]] = llvm.add %[[ITER0]], %[[WTIDDIV16]] : i32 + // OPSEL-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // OPSEL-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // OPSEL-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL-NEXT: %[[LOADED0:.*]] = llvm.load %[[ADDRESS0]] : !llvm.ptr<3> -> f16 + // OPSEL-NEXT: %[[C1C:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[VECOFFSET0:.*]] = llvm.add %[[ITER0]], %[[C1C]] : i32 + // OPSEL-NEXT: %[[LOADEDVALS0:.*]] = llvm.insertelement %[[LOADED0]], %[[LOADEDVALS]][%[[VECOFFSET0]] : i32] : vector<16xf16> + // OPSEL-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32 + // OPSEL-NEXT: %[[ROW1:.*]] = llvm.add %[[ITER1]], %[[WTIDDIV16]] : i32 + // OPSEL-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // OPSEL-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // OPSEL-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL-NEXT: %[[LOADED1:.*]] = llvm.load %[[ADDRESS1]] : !llvm.ptr<3> -> f16 + // OPSEL-NEXT: %[[C1C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[VECOFFSET1:.*]] = llvm.add %[[ITER1]], %[[C1C1]] : i32 + // OPSEL-NEXT: %[[LOADEDVALS1:.*]] = llvm.insertelement %[[LOADED1]], %[[LOADEDVALS0]][%[[VECOFFSET1]] : i32] : vector<16xf16> + // OPSEL: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // OPSEL: %[[RES:.*]] = llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : 
i32] : vector<16xf16> + // OPSEL-NEXT: llvm.return %[[RES]] : vector<16xf16> + return %0 : !gpu.mma_matrix<16x16xf16, "COp"> + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: store_cop_f32 + // CHECK-SAME: (%[[SRC:.*]]: vector<8xf32>) + func.func @store_cop_f32(%arg0: !gpu.mma_matrix<16x16xf32, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf32, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<32x32xf32, 3> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]] : i32 + // CHECK-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK-NEXT: %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[C0]] : i32] : vector<8xf32> + // CHECK-NEXT: llvm.store %[[ELE0]], %[[ADDRESS0]] : f32, !llvm.ptr<3> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ITER1:.*]] = 
llvm.mul %[[C1]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]] : i32 + // CHECK-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32 + // CHECK-NEXT: %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[C1]] : i32] : vector<8xf32> + // CHECK-NEXT: llvm.store %[[ELE1]], %[[ADDRESS1]] : f32, !llvm.ptr<3> + // CHECK: %[[C7:.*]] = llvm.mlir.constant(7 : i32) : i32 + // CHECK: %[[ELE7:.*]] = llvm.extractelement %[[SRC]][%[[C7]] : i32] : vector<8xf32> + // CHECK-NEXT: llvm.store %[[ELE7]], %{{.*}} : f32, !llvm.ptr<3> + return + } +} + +// ----- + +gpu.module @test_module { + // CHECK-LABEL: store_cop_f16_no_opsel + // CHECK-SAME: (%[[SRC:.*]]: vector<16xf16>) + func.func @store_cop_f16_no_opsel(%arg0: !gpu.mma_matrix<16x16xf16, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3> + // CHECK: llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // CHECK-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // CHECK-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK-NEXT: 
%[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]] : i32 + // CHECK-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[ITER0]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.store %[[ELE0]], %[[ADDRESS0]] : f16, !llvm.ptr<3> + // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // CHECK-NEXT: %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32 + // CHECK-NEXT: %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]] : i32 + // CHECK-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // CHECK-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // CHECK-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[ITER1]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.store %[[ELE1]], %[[ADDRESS1]] : f16, !llvm.ptr<3> + // CHECK: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // CHECK: %[[ITER15:.*]] = llvm.mul %[[C15]], %[[C2]] : i32 + // CHECK: %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // CHECK-NEXT: %[[ELE15:.*]] = llvm.extractelement %[[SRC]][%[[ITER15]] : i32] : vector<16xf16> + // CHECK-NEXT: llvm.store %[[ELE15]], %[[ADDRESS15]] : f16, !llvm.ptr<3> + return + } +} + +// ----- + +gpu.module @test_module { + // OPSEL-LABEL: store_cop_f16_opsel + // OPSEL-SAME: (%[[SRC:.*]]: vector<16xf16>) + func.func @store_cop_f16_opsel(%arg0: !gpu.mma_matrix<16x16xf16, "COp">) -> () { + %wg_1 = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + 
gpu.subgroup_mma_store_matrix %arg0, %wg_1[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3> + // OPSEL: llvm.mlir.constant(32 : index) : i32 + // OPSEL: %[[BASE:.*]] = llvm.getelementptr %{{.*}} : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL: %[[C32_0:.*]] = llvm.mlir.constant(32 : index) : i32 + // OPSEL: %[[WARPLOCALTID:.*]] = rocdl.mbcnt.hi %{{.*}}, %{{.*}} : (i32, i32) -> i32 + // OPSEL-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // OPSEL-NEXT: %[[WRAPPEDTID:.*]] = llvm.srem %[[WARPLOCALTID]], %[[C16]] : i32 + // OPSEL-NEXT: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32 + // OPSEL-NEXT: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32 + // OPSEL-NEXT: %[[WTIDDIV16:.*]] = llvm.sdiv %[[WARPLOCALTID]], %[[C16]] : i32 + // OPSEL-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // OPSEL-NEXT: %[[ITER0:.*]] = llvm.mul %[[C0]], %[[C2]] : i32 + // OPSEL-NEXT: %[[ROW0:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER0]] : i32 + // OPSEL-NEXT: %[[ROWLDM0:.*]] = llvm.mul %[[ROW0]], %[[C32_0]] : i32 + // OPSEL-NEXT: %[[OFFSET0:.*]] = llvm.add %[[ROWLDM0]], %[[WRAPPEDTID]] : i32 + // OPSEL-NEXT: %[[ADDRESS0:.*]] = llvm.getelementptr %[[BASE]][%[[OFFSET0]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL-NEXT: %[[C01:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[INX0:.*]] = llvm.add %[[ITER0]], %[[C01]] : i32 + // OPSEL-NEXT: %[[ELE0:.*]] = llvm.extractelement %[[SRC]][%[[INX0]] : i32] : vector<16xf16> + // OPSEL-NEXT: llvm.store %[[ELE0]], %[[ADDRESS0]] : f16, !llvm.ptr<3> + // OPSEL-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[ITER1:.*]] = llvm.mul %[[C1]], %[[C2]] : i32 + // OPSEL-NEXT: %[[ROW1:.*]] = llvm.add %[[WTIDDIV16]], %[[ITER1]] : i32 + // OPSEL-NEXT: %[[ROWLDM1:.*]] = llvm.mul %[[ROW1]], %[[C32_0]] : i32 + // OPSEL-NEXT: %[[OFFSET1:.*]] = llvm.add %[[ROWLDM1]], %[[WRAPPEDTID]] : i32 + // OPSEL-NEXT: %[[ADDRESS1:.*]] = llvm.getelementptr 
%[[BASE]][%[[OFFSET1]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL-NEXT: %[[C11:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[INX1:.*]] = llvm.add %[[ITER1]], %[[C11]] : i32 + // OPSEL-NEXT: %[[ELE1:.*]] = llvm.extractelement %[[SRC]][%[[INX1]] : i32] : vector<16xf16> + // OPSEL-NEXT: llvm.store %[[ELE1]], %[[ADDRESS1]] : f16, !llvm.ptr<3> + // OPSEL: %[[C15:.*]] = llvm.mlir.constant(15 : i32) : i32 + // OPSEL-NEXT: %[[ITER15:.*]] = llvm.mul %[[C15]], %[[C2]] : i32 + // OPSEL: %[[ADDRESS15:.*]] = llvm.getelementptr %[[BASE]][%{{.*}}] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 + // OPSEL-NEXT: %[[C151:.*]] = llvm.mlir.constant(1 : i32) : i32 + // OPSEL-NEXT: %[[INX15:.*]] = llvm.add %[[ITER15]], %[[C151]] : i32 + // OPSEL-NEXT: %[[ELE15:.*]] = llvm.extractelement %[[SRC]][%[[INX15]] : i32] : vector<16xf16> + // OPSEL-NEXT: llvm.store %[[ELE15]], %[[ADDRESS15]] : f16, !llvm.ptr<3> + return + } +} diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg b/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/GPU/ROCM/WMMA/lit.local.cfg @@ -0,0 +1,5 @@ +import sys + +# WMMA tests must be enabled via build flag. 
+if not config.mlir_run_rocm_wmma_tests:
+    config.unsupported = True
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts,gpu-to-hsaco{chip=%chip}))' \
+// RUN: | mlir-opt -gpu-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf16>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with in[i][j] = i + j; it feeds both A and B.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf16>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf16>, memref<16x16xf16>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf16>, memref<16x16xf16>
+  gpu.wait [%stream]
+
+  %res_f32 = memref.alloc() : memref<16x16xf32>
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %load = memref.load %22[%arg0, %arg1] : memref<16x16xf16>
+      %ext = arith.extf %load : f16 to f32
+      memref.store %ext, %res_f32[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+  %res_f32_cast = memref.cast %res_f32 : memref<16x16xf32> to memref<*xf32>
+
+  // Print the memref after computation. The product is accumulated in f16, so some large entries are rounded (e.g. 5100 where exact arithmetic gives 5096).
+  call @printMemrefF32(%res_f32_cast) : (memref<*xf32>) -> ()
+  // CHECK:      [1240, 1360, 1480, 1600, 1720, 1840, 1960, 2080, 2200, 2320, 2440, 2560, 2680, 2800, 2920, 3040],
+  // CHECK-NEXT: [1360, 1496, 1632, 1768, 1904, 2040, 2176, 2312, 2448, 2584, 2720, 2856, 2992, 3128, 3264, 3400],
+  // CHECK-NEXT: [1480, 1632, 1784, 1936, 2088, 2240, 2392, 2544, 2696, 2848, 3000, 3152, 3304, 3456, 3608, 3760],
+  // CHECK-NEXT: [1600, 1768, 1936, 2104, 2272, 2440, 2608, 2776, 2944, 3112, 3280, 3448, 3616, 3784, 3952, 4120],
+  // CHECK-NEXT: [1720, 1904, 2088, 2272, 2456, 2640, 2824, 3008, 3192, 3376, 3560, 3744, 3928, 4112, 4296, 4480],
+  // CHECK-NEXT: [1840, 2040, 2240, 2440, 2640, 2840, 3040, 3240, 3440, 3640, 3840, 4040, 4240, 4440, 4640, 4840],
+  // CHECK-NEXT: [1960, 2176, 2392, 2608, 2824, 3040, 3256, 3472, 3688, 3904, 4120, 4336, 4552, 4768, 4984, 5200],
+  // CHECK-NEXT: [2080, 2312, 2544, 2776, 3008, 3240, 3472, 3704, 3936, 4168, 4400, 4632, 4864, 5100, 5328, 5556],
+  // CHECK-NEXT: [2200, 2448, 2696, 2944, 3192, 3440, 3688, 3936, 4184, 4432, 4680, 4928, 5172, 5424, 5676, 5920],
+  // CHECK-NEXT: [2320, 2584, 2848, 3112, 3376, 3640, 3904, 4168, 4432, 4696, 4960, 5228, 5488, 5748, 6016, 6284],
+  // CHECK-NEXT: [2440, 2720, 3000, 3280, 3560, 3840, 4120, 4400, 4680, 4960, 5236, 5520, 5804, 6080, 6356, 6640],
+  // CHECK-NEXT: [2560, 2856, 3152, 3448, 3744, 4040, 4336, 4632, 4928, 5228, 5520, 5812, 6112, 6412, 6704, 6996],
+  // CHECK-NEXT: [2680, 2992, 3304, 3616, 3928, 4240, 4552, 4864, 5172, 5488, 5804, 6112, 6420, 6736, 7052, 7360],
+  // CHECK-NEXT: [2800, 3128, 3456, 3784, 4112, 4440, 4768, 5100, 5424, 5748, 6080, 6412, 6736, 7060, 7392, 7724],
+  // CHECK-NEXT: [2920, 3264, 3608, 3952, 4296, 4640, 4984, 5328, 5676, 6016, 6356, 6704, 7052, 7392, 7732, 8080],
+  // CHECK-NEXT: [3040, 3400, 3760, 4120, 4480, 4840, 5200, 5556, 5920, 6284, 6640, 6996, 7360, 7724, 8080, 8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f16_16_16_16_f16_opselect.mlir
@@ -0,0 +1,96 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts,gpu-to-hsaco{chip=%chip}))' \
+// RUN: | mlir-opt -gpu-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+// NOTE(review): the pipeline above is identical to wmma_f16_16_16_16_f16.mlir and passes no op-select option to convert-gpu-to-rocdl (the opSelect parameter defaults to false), so this test does not yet exercise the op-select lowering. TODO: enable op-select once the pass option name is confirmed.
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf16>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f16
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with in[i][j] = i + j; it feeds both A and B.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf16>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf16>, memref<16x16xf16>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf16>, memref<16x16xf16>
+  gpu.wait [%stream]
+
+  %res_f32 = memref.alloc() : memref<16x16xf32>
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %load = memref.load %22[%arg0, %arg1] : memref<16x16xf16>
+      %ext = arith.extf %load : f16 to f32
+      memref.store %ext, %res_f32[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+  %res_f32_cast = memref.cast %res_f32 : memref<16x16xf32> to memref<*xf32>
+
+  // Print the memref after computation. The product is accumulated in f16, so some large entries are rounded (e.g. 5100 where exact arithmetic gives 5096).
+  call @printMemrefF32(%res_f32_cast) : (memref<*xf32>) -> ()
+  // CHECK:      [1240, 1360, 1480, 1600, 1720, 1840, 1960, 2080, 2200, 2320, 2440, 2560, 2680, 2800, 2920, 3040],
+  // CHECK-NEXT: [1360, 1496, 1632, 1768, 1904, 2040, 2176, 2312, 2448, 2584, 2720, 2856, 2992, 3128, 3264, 3400],
+  // CHECK-NEXT: [1480, 1632, 1784, 1936, 2088, 2240, 2392, 2544, 2696, 2848, 3000, 3152, 3304, 3456, 3608, 3760],
+  // CHECK-NEXT: [1600, 1768, 1936, 2104, 2272, 2440, 2608, 2776, 2944, 3112, 3280, 3448, 3616, 3784, 3952, 4120],
+  // CHECK-NEXT: [1720, 1904, 2088, 2272, 2456, 2640, 2824, 3008, 3192, 3376, 3560, 3744, 3928, 4112, 4296, 4480],
+  // CHECK-NEXT: [1840, 2040, 2240, 2440, 2640, 2840, 3040, 3240, 3440, 3640, 3840, 4040, 4240, 4440, 4640, 4840],
+  // CHECK-NEXT: [1960, 2176, 2392, 2608, 2824, 3040, 3256, 3472, 3688, 3904, 4120, 4336, 4552, 4768, 4984, 5200],
+  // CHECK-NEXT: [2080, 2312, 2544, 2776, 3008, 3240, 3472, 3704, 3936, 4168, 4400, 4632, 4864, 5100, 5328, 5556],
+  // CHECK-NEXT: [2200, 2448, 2696, 2944, 3192, 3440, 3688, 3936, 4184, 4432, 4680, 4928, 5172, 5424, 5676, 5920],
+  // CHECK-NEXT: [2320, 2584, 2848, 3112, 3376, 3640, 3904, 4168, 4432, 4696, 4960, 5228, 5488, 5748, 6016, 6284],
+  // CHECK-NEXT: [2440, 2720, 3000, 3280, 3560, 3840, 4120, 4400, 4680, 4960, 5236, 5520, 5804, 6080, 6356, 6640],
+  // CHECK-NEXT: [2560, 2856, 3152, 3448, 3744, 4040, 4336, 4632, 4928, 5228, 5520, 5812, 6112, 6412, 6704, 6996],
+  // CHECK-NEXT: [2680, 2992, 3304, 3616, 3928, 4240, 4552, 4864, 5172, 5488, 5804, 6112, 6420, 6736, 7052, 7360],
+  // CHECK-NEXT: [2800, 3128, 3456, 3784, 4112, 4440, 4768, 5100, 5424, 5748, 6080, 6412, 6736, 7060, 7392, 7724],
+  // CHECK-NEXT: [2920, 3264, 3608, 3952, 4296, 4640, 4984, 5328, 5676, 6016, 6356, 6704, 7052, 7392, 7732, 8080],
+  // CHECK-NEXT: [3040, 3400, 3760, 4120, 4480, 4840, 5200, 5556, 5920, 6284, 6640, 6996, 7360, 7724, 8080, 8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts,gpu-to-hsaco{chip=%chip}))' \
+// RUN: | mlir-opt -gpu-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf32>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with in[i][j] = i + j; it feeds both A and B.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %cast_r = arith.index_cast %arg0 : index to i16
+      %add = arith.addi %cast_r, %cast_c : i16
+      %float = arith.sitofp %add : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+  %33 = memref.cast %22 : memref<16x16xf32> to memref<*xf32>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf32>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf32>, memref<16x16xf32>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf32>, memref<16x16xf32>
+  gpu.wait [%stream]
+
+  // Print the host copy (%33 aliases %22, which now holds the result).
+  call @printMemrefF32(%33) : (memref<*xf32>) -> ()
+  // CHECK:      [1240, 1360, 1480, 1600, 1720, 1840, 1960, 2080, 2200, 2320, 2440, 2560, 2680, 2800, 2920, 3040],
+  // CHECK-NEXT: [1360, 1496, 1632, 1768, 1904, 2040, 2176, 2312, 2448, 2584, 2720, 2856, 2992, 3128, 3264, 3400],
+  // CHECK-NEXT: [1480, 1632, 1784, 1936, 2088, 2240, 2392, 2544, 2696, 2848, 3000, 3152, 3304, 3456, 3608, 3760],
+  // CHECK-NEXT: [1600, 1768, 1936, 2104, 2272, 2440, 2608, 2776, 2944, 3112, 3280, 3448, 3616, 3784, 3952, 4120],
+  // CHECK-NEXT: [1720, 1904, 2088, 2272, 2456, 2640, 2824, 3008, 3192, 3376, 3560, 3744, 3928, 4112, 4296, 4480],
+  // CHECK-NEXT: [1840, 2040, 2240, 2440, 2640, 2840, 3040, 3240, 3440, 3640, 3840, 4040, 4240, 4440, 4640, 4840],
+  // CHECK-NEXT: [1960, 2176, 2392, 2608, 2824, 3040, 3256, 3472, 3688, 3904, 4120, 4336, 4552, 4768, 4984, 5200],
+  // CHECK-NEXT: [2080, 2312, 2544, 2776, 3008, 3240, 3472, 3704, 3936, 4168, 4400, 4632, 4864, 5096, 5328, 5560],
+  // CHECK-NEXT: [2200, 2448, 2696, 2944, 3192, 3440, 3688, 3936, 4184, 4432, 4680, 4928, 5176, 5424, 5672, 5920],
+  // CHECK-NEXT: [2320, 2584, 2848, 3112, 3376, 3640, 3904, 4168, 4432, 4696, 4960, 5224, 5488, 5752, 6016, 6280],
+  // CHECK-NEXT: [2440, 2720, 3000, 3280, 3560, 3840, 4120, 4400, 4680, 4960, 5240, 5520, 5800, 6080, 6360, 6640],
+  // CHECK-NEXT: [2560, 2856, 3152, 3448, 3744, 4040, 4336, 4632, 4928, 5224, 5520, 5816, 6112, 6408, 6704, 7000],
+  // CHECK-NEXT: [2680, 2992, 3304, 3616, 3928, 4240, 4552, 4864, 5176, 5488, 5800, 6112, 6424, 6736, 7048, 7360],
+  // CHECK-NEXT: [2800, 3128, 3456, 3784, 4112, 4440, 4768, 5096, 5424, 5752, 6080, 6408, 6736, 7064, 7392, 7720],
+  // CHECK-NEXT: [2920, 3264, 3608, 3952, 4296, 4640, 4984, 5328, 5672, 6016, 6360, 6704, 7048, 7392, 7736, 8080],
+  // CHECK-NEXT: [3040, 3400, 3760, 4120, 4480, 4840, 5200, 5560, 5920, 6280, 6640, 7000, 7360, 7720, 8080, 8440]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/GPU/ROCM/WMMA/wmma_f32_16_16_16_f16_a_b_transpose.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -convert-scf-to-cf \
+// RUN: | mlir-opt -gpu-kernel-outlining \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-amdgpu{chipset=%chip index-bitwidth=32},convert-gpu-to-rocdl{chipset=%chip index-bitwidth=32},reconcile-unrealized-casts,gpu-to-hsaco{chip=%chip}))' \
+// RUN: | mlir-opt -gpu-to-llvm \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_rocm_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+func.func @main() {
+  %0 = memref.alloc() : memref<16x16xf16>
+  %22 = memref.alloc() : memref<16x16xf32>
+
+  %f1 = arith.constant 1.0e+00 : f16
+  %f0 = arith.constant 0.0e+00 : f32
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %c1 = arith.constant 1 : index
+
+  // Initialize the input matrix with in[i][j] = j (the column index).
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      %cast_c = arith.index_cast %arg1 : index to i16
+      %float = arith.sitofp %cast_c : i16 to f16
+      memref.store %float, %0[%arg0, %arg1] : memref<16x16xf16>
+    }
+  }
+  // Initialize the accumulator matrix with zeros.
+  scf.for %arg0 = %c0 to %c16 step %c1 {
+    scf.for %arg1 = %c0 to %c16 step %c1 {
+      memref.store %f0, %22[%arg0, %arg1] : memref<16x16xf32>
+    }
+  }
+
+  %2 = memref.cast %0 : memref<16x16xf16> to memref<*xf16>
+  %33 = memref.cast %22 : memref<16x16xf32> to memref<*xf32>
+
+  %stream = gpu.wait async
+  %gpu_in, %asyncToken_0 = gpu.alloc async [%stream] () : memref<16x16xf16>
+  %gpu_out, %asyncToken_1 = gpu.alloc async [%stream] () : memref<16x16xf32>
+
+  %asyncToken_2 = gpu.memcpy async [%stream] %gpu_in, %0 : memref<16x16xf16>, memref<16x16xf16>
+  %asyncToken_3 = gpu.memcpy async [%stream] %gpu_out, %22 : memref<16x16xf32>, memref<16x16xf32>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %gpu_in[%c0, %c0] {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %gpu_out[%c0, %c0] {leadDimension = 16 : index} : memref<16x16xf32> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf32, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %gpu_out[%c0, %c0] {leadDimension = 16 : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
+    gpu.terminator
+  }
+
+  %asyncToken_4 = gpu.memcpy async [%stream] %22, %gpu_out : memref<16x16xf32>, memref<16x16xf32>
+  gpu.wait [%stream]
+
+  // Print the result: with both operands transposed, row i is i * (0 + 1 + ... + 15) = 120 * i.
+  call @printMemrefF32(%33) : (memref<*xf32>) -> ()
+  // CHECK:      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+  // CHECK-NEXT: [120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120, 120],
+  // CHECK-NEXT: [240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240, 240],
+  // CHECK-NEXT: [360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360, 360],
+  // CHECK-NEXT: [480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480, 480],
+  // CHECK-NEXT: [600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600, 600],
+  // CHECK-NEXT: [720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720, 720],
+  // CHECK-NEXT: [840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840, 840],
+  // CHECK-NEXT: [960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960, 960],
+  // CHECK-NEXT: [1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080, 1080],
+  // CHECK-NEXT: [1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200, 1200],
+  // CHECK-NEXT: [1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320, 1320],
+  // CHECK-NEXT: [1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440, 1440],
+  // CHECK-NEXT: [1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560, 1560],
+  // CHECK-NEXT: [1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680, 1680],
+  // CHECK-NEXT: [1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800, 1800]
+  return
+}
+
+func.func private @printMemrefF32(memref<*xf32>)
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -46,6 +46,7 @@
 config.mlir_run_x86vector_tests = @MLIR_RUN_X86VECTOR_TESTS@
 config.mlir_run_riscv_vector_tests = "@MLIR_RUN_RISCV_VECTOR_TESTS@"
 config.mlir_run_cuda_tensor_core_tests = @MLIR_RUN_CUDA_TENSOR_CORE_TESTS@
+config.mlir_run_rocm_wmma_tests = @MLIR_RUN_ROCM_WMMA_TESTS@
 config.mlir_run_cuda_sm80_tests = @MLIR_RUN_CUDA_SM80_TESTS@
 config.mlir_run_cuda_sm80_lt_tests = @MLIR_RUN_CUDA_SM80_LT_TESTS@
 config.mlir_run_cuda_sm90_tests = @MLIR_RUN_CUDA_SM90_TESTS@