diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -31,6 +31,10 @@
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
 
+/// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                             RewritePatternSet &patterns);
+
 /// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The
 /// index bitwidth used for the lowering of the device side index computations
 /// is configurable.
diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
--- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt
@@ -4,6 +4,7 @@
 add_mlir_conversion_library(MLIRGPUToNVVMTransforms
   LowerGpuOpsToNVVMOps.cpp
+  WmmaOpsToNvvm.cpp
 
   DEPENDS
   MLIRConversionPassIncGen
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -126,6 +126,38 @@
     return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
   });
 
+  // Lowering for MMAMatrixType.
+  converter.addConversion([&](gpu::MMAMatrixType type) -> Type {
+    // The number of elements in structToReturn depends on the data type and
+    // the MMA operand that this operation is associated with.
+    llvm::DenseMap<StringRef, int64_t> numElemsPerThreadF16,
+        numElemsPerThreadF32;
+    numElemsPerThreadF16["AOp"] = 8;
+    numElemsPerThreadF16["BOp"] = 8;
+    numElemsPerThreadF16["COp"] = 4;
+    numElemsPerThreadF16["DOp"] = 4;
+    numElemsPerThreadF32["AOp"] = 8;
+    numElemsPerThreadF32["BOp"] = 8;
+    numElemsPerThreadF32["COp"] = 8;
+    numElemsPerThreadF32["DOp"] = 8;
+    Type structToReturn;
+    if (type.getElementType().isF16()) {
+      // Number of f16's in 32-bit.
+      unsigned vecSize = 2;
+      Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
+      unsigned size = numElemsPerThreadF16[type.getOperand()];
+      SmallVector<Type> elements(size, vec);
+      structToReturn =
+          LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+    } else if (type.getElementType().isF32()) {
+      unsigned size = numElemsPerThreadF32[type.getOperand()];
+      SmallVector<Type> elements(size, FloatType::getF32(&getContext()));
+      structToReturn =
+          LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+    }
+    return structToReturn;
+  });
+
   RewritePatternSet patterns(m.getContext());
   RewritePatternSet llvmPatterns(m.getContext());
 
@@ -137,6 +169,7 @@
   populateStdToLLVMConversionPatterns(converter, llvmPatterns);
   populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
+  populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
   LLVMConversionTarget target(getContext());
   configureGpuToNVVMConversionLegality(target);
   if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp
@@ -0,0 +1,451 @@
+//===------ WmmaOpsToNvvm.cpp - WMMA LD/ST/Compute to NVVM lowering -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains definitions of patterns to lower GPU Subgroup MMA ops to
+// NVVM Dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+using namespace mlir;
+
+namespace {
+
+/// Contains all the common LLVM types which are used across the lowerings of
+/// GPU subgroup MMA ops to NVVM dialect.
+struct CommonLLVMAndBuiltInMLIRTypes {
+public:
+  CommonLLVMAndBuiltInMLIRTypes(MLIRContext *context) {
+    numHalfsInOpFrags.resize(4);
+    numHalfsInOpFrags[A] = 8;
+    numHalfsInOpFrags[B] = 8;
+    numHalfsInOpFrags[C] = 4;
+    numHalfsInOpFrags[D] = 4;
+    i32Ty = IntegerType::get(context, 32);
+    f16Ty = FloatType::getF16(context);
+    f32Ty = FloatType::getF32(context);
+    f16x2Ty = VectorType::get(2, f16Ty);
+    fragArrayABTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(8, f16x2Ty));
+    fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(4, f16x2Ty));
+    fragArrayCDF32Ty =
+        LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
+  }
+
+  Type i32Ty;
+  Type f16Ty;
+  Type f32Ty;
+  Type f16x2Ty;
+  /// Type for the fragment of A and B operands that a single thread holds for
+  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C).
+  Type fragArrayABTy;
+  /// Type for the fragment of C and D operands that a single thread holds for
+  /// fp16 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C).
+  Type fragArrayCDTy;
+  /// Type for the fragment of C and D operands that a single thread holds for
+  /// fp32 data type in a WMMA operation of the form D = (alpha*(A*B)) +
+  /// (beta*C).
+  Type fragArrayCDF32Ty;
+  /// Represents the number of f16 elements a single thread holds in a WMMA
+  /// operation of the form D = (alpha*(A*B)) + (beta*C).
+  SmallVector<unsigned, 4> numHalfsInOpFrags;
+  /// Represents the operands of an MMA operation of the form
+  /// D = (alpha*(A*B)) + (beta*C).
+  enum OperandMap { A, B, C, D };
+};
+
+/// Checks if all the operands of the op being lowered are of LLVM types. The
+/// types are expected to have been converted by the `LLVMTypeConverter`
+/// before the op is actually lowered. If the type of an operand is not
+/// already converted, that hints at a missing type conversion, and failure is
+/// returned in that case.
+static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                     ConversionPatternRewriter &rewriter) {
+  if (!llvm::all_of(operands, [](Value value) {
+        return LLVM::isCompatibleType(value.getType());
+      })) {
+    return rewriter.notifyMatchFailure(
+        op, "cannot convert if operands aren't of LLVM type.");
+  }
+
+  return success();
+}
+
+/// Error string to emit when an unimplemented WMMA variant is encountered.
+static constexpr StringRef kInvalidCaseStr =
+    "Unimplemented WMMA variant; only the M16N16K16 version is implemented.";
+
+/// This class implements the conversion of GPU MMA loadOp to wmma.load op
+/// in the NVVM dialect. Besides the nvvm.wmma.load op itself, the conversion
+/// emits the address computation needed to index into the source memref at
+/// the requested position, so the data can be loaded into a fragment.
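+///
+/// As a rough illustration (types abbreviated, assuming an f16 "AOp" fragment
+/// loaded from workgroup memory with a 32-bit index width), a load such as
+///
+///   %0 = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
+///       : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+///
+/// lowers to address arithmetic on the memref descriptor followed by
+///
+///   %frag = nvvm.wmma.m16n16k16.load.a.f16.row.stride %ptr, %ldm
+///       : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, ...)>
+///
+/// See the accompanying wmma-ops-to-nvvm.mlir test for the exact output.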
+struct WmmaLoadOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+public:
+  explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    unsigned indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+
+    // The corresponding intrinsic expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the load address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "expected indices to the memref to be 32-bit wide.");
+
+    // Source memref of the original op.
+    MemRefType srcMemrefType =
+        subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
+    Location loc = op->getLoc();
+
+    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
+
+    // MemRefDescriptor to extract alignedPtr and offset.
+    MemRefDescriptor promotedSrcOp(
+        gpu::SubgroupMmaLoadMatrixOpAdaptor(operands).srcMemref());
+
+    // Emit ops which compute the load offset using `srcOffsetI` and
+    // `srcOffsetJ`. The actual offset into the aligned pointer is
+    // (memrefOffset + (leadDimension * srcOffsetI) + srcOffsetJ). The memrefs
+    // here are assumed to be normalized and hence the simple conversion works.
+    SmallVector<Value> indices(subgroupMmaLoadMatrixOp.indices());
+    Value srcOffsetIVal = indices[0];
+    Value srcOffsetJVal = indices[1];
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value numElemsLeadDim =
+        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, srcOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
+                                                    srcOffsetJVal);
+
+    Value promotedSrcOpToUse = promotedSrcOp.offset(rewriter, loc);
+    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
+                                                      promotedSrcOpToUse);
+    Value loadAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(f16Ty, srcMemrefType.getMemorySpaceAsInt()),
+        promotedSrcOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
+
+    // Bitcast the base address pointer of the source memref, so that values
+    // can be loaded in chunks of 32 bits and the semantics match the
+    // intrinsic exposed by the NVPTX backend.
+    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(i32Ty, srcMemrefType.getMemorySpaceAsInt()),
+        loadAddress);
+
+    // Get the shape of the MMAMatrix type being returned. The shape
+    // determines which intrinsic this op will be lowered to.
+    gpu::MMAMatrixType retType =
+        subgroupMmaLoadMatrixOp.res().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> retTypeShape = retType.getShape();
+
+    Type resType;
+    StringRef operandStr = retType.getOperand();
+    if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
+      resType = fragArrayABTy;
+    } else {
+      if (srcMemrefType.getElementType().isF16())
+        resType = fragArrayCDTy;
+      else if (srcMemrefType.getElementType().isF32())
+        resType = fragArrayCDF32Ty;
+      else
+        return failure();
+    }
+
+    // Create the nvvm.wmma.load op according to the operand type and shape.
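+    // Roughly, "AOp" maps to nvvm.wmma.m16n16k16.load.a, "BOp" to
+    // nvvm.wmma.m16n16k16.load.b, and "COp" to the f16 or f32 variant of
+    // nvvm.wmma.m16n16k16.load.c, depending on the memref element type; all
+    // shapes other than 16x16 are rejected below.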
+    SmallVector<Value> loadOpOperands({loadAddressCasted, leadingDim32});
+    if (operandStr.equals("AOp")) {
+      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+        NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp =
+            rewriter.create<NVVM::WMMALoadAM16N16K16Op>(loc, resType,
+                                                        loadOpOperands);
+        rewriter.replaceOp(op, wmmaLoadAOp.getResult());
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+    } else if (operandStr.equals("BOp")) {
+      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+        NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp =
+            rewriter.create<NVVM::WMMALoadBM16N16K16Op>(loc, resType,
+                                                        loadOpOperands);
+        rewriter.replaceOp(op, wmmaLoadBOp.getResult());
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+    } else {
+      if (retTypeShape[0] == 16 && retTypeShape[1] == 16) {
+        if (srcMemrefType.getElementType().isF16()) {
+          NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp =
+              rewriter.create<NVVM::WMMALoadCF16M16N16K16Op>(loc, resType,
+                                                             loadOpOperands);
+          rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+        } else if (srcMemrefType.getElementType().isF32()) {
+          NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp =
+              rewriter.create<NVVM::WMMALoadCF32M16N16K16Op>(loc, resType,
+                                                             loadOpOperands);
+          rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+        }
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+    }
+    return success();
+  }
+};
+
+/// This class implements the conversion of GPU MMA storeOp to wmma.store op
+/// in the NVVM dialect. Besides the nvvm.wmma.store op itself, the conversion
+/// also emits the code necessary to unpack the source fragment into the
+/// individual values expected by the NVVM op.
+struct WmmaStoreOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+public:
+  explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    unsigned indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+    // The corresponding intrinsic expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the store address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "expected indices to the memref to be 32-bit wide.");
+
+    Location loc = op->getLoc();
+
+    // Destination memref of the original op.
+    MemRefType dstMemrefType =
+        subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
+
+    // MemRefDescriptor to extract alignedPtr and offset.
+    MemRefDescriptor promotedDstOp(
+        gpu::SubgroupMmaStoreMatrixOpAdaptor(operands).dstMemref());
+
+    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
+
+    // Emit ops which compute the store offset using `dstOffsetI` and
+    // `dstOffsetJ`. The actual offset into the aligned pointer is
+    // (memrefOffset + (leadDimension * dstOffsetI) + dstOffsetJ).
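+    // As a concrete (hypothetical) example: with leadDimension = 32 and
+    // indices (16, 16), the linearized offset is 32 * 16 + 16 = 528 f16
+    // elements past the aligned pointer, before adding the offset recorded in
+    // the memref descriptor.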
+    SmallVector<Value> indices(subgroupMmaStoreMatrixOp.indices());
+    Value dstOffsetIVal = indices[0];
+    Value dstOffsetJVal = indices[1];
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, i32Ty, leadDimension);
+    Value numElemsLeadDim =
+        rewriter.create<LLVM::MulOp>(loc, i32Ty, leadingDim32, dstOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, numElemsLeadDim,
+                                                    dstOffsetJVal);
+
+    Value promotedDstOpToUse = promotedDstOp.offset(rewriter, loc);
+    Value actualOffset = rewriter.create<LLVM::AddOp>(loc, i32Ty, loadOffset,
+                                                      promotedDstOpToUse);
+    Value storeAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(f16Ty, dstMemrefType.getMemorySpaceAsInt()),
+        promotedDstOp.alignedPtr(rewriter, loc), ArrayRef<Value>{actualOffset});
+
+    // Bitcast the base address pointer of the destination memref, so that
+    // values can be stored in chunks of 32 bits and the semantics match the
+    // intrinsic exposed by the NVPTX backend.
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(i32Ty, dstMemrefType.getMemorySpaceAsInt()),
+        storeAddress);
+
+    SmallVector<Value> storeOpOperands;
+    storeOpOperands.push_back(storeAddressCasted);
+
+    // Get the shape of the MMAMatrix type being stored. The shape determines
+    // which intrinsic this op will be lowered to.
+    gpu::MMAMatrixType srcType =
+        subgroupMmaStoreMatrixOp.src().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> srcTypeShape = srcType.getShape();
+
+    // Unpack the results from the source.
+    if (subgroupMmaStoreMatrixOp.src()
+            .getType()
+            .cast<gpu::MMAMatrixType>()
+            .getElementType() == f16Ty) {
+      for (unsigned i = 0, e = numHalfsInOpFrags[D]; i < e; ++i) {
+        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+            loc, f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i));
+        storeOpOperands.push_back(toUse);
+      }
+      storeOpOperands.push_back(leadingDim32);
+
+      // Create nvvm.wmma.store op.
+      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
+        rewriter.create<NVVM::WMMAStoreF16M16N16K16Op>(loc, storeOpOperands);
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+      rewriter.eraseOp(op);
+      return success();
+    } else if (subgroupMmaStoreMatrixOp.src()
+                   .getType()
+                   .cast<gpu::MMAMatrixType>()
+                   .getElementType() == f32Ty) {
+      for (unsigned i = 0, e = 8; i < e; ++i) {
+        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+            loc, f32Ty, operands[0], rewriter.getI32ArrayAttr(i));
+        storeOpOperands.push_back(toUse);
+      }
+      storeOpOperands.push_back(leadingDim32);
+
+      // Create nvvm.wmma.store op.
+      if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) {
+        rewriter.create<NVVM::WMMAStoreF32M16N16K16Op>(loc, storeOpOperands);
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
+/// This class implements the conversion of GPU MMA computeOp to wmma.mma op
+/// in the NVVM dialect.
+struct WmmaMmaOpToNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>,
+      private CommonLLVMAndBuiltInMLIRTypes {
+  explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
+        CommonLLVMAndBuiltInMLIRTypes(&this->getTypeConverter()->getContext()) {
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaComputeOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+
+    // The wmma.mma intrinsic in llvm requires the operands as individual
+    // values, so the elements of the lowered fragment structs need to be
+    // extracted and then passed on to the intrinsic call. Emit llvm ops to
+    // extract the individual values from the lowered operands.
+    SmallVector<Value> unpackedOps;
+
+    auto unpackOp = [&](CommonLLVMAndBuiltInMLIRTypes::OperandMap op,
+                        Value operand, unsigned numElems, Type elemType) {
+      for (unsigned i = 0; i < numElems; ++i) {
+        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+            loc, elemType, operand, rewriter.getI32ArrayAttr(i));
+        unpackedOps.push_back(toUse);
+      }
+    };
+
+    // Get the shapes of the MMAMatrix types being used. The shapes determine
+    // which intrinsic this op will be lowered to.
+    gpu::MMAMatrixType aType =
+        subgroupMmaComputeOp.opA().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> aTypeShape = aType.getShape();
+    gpu::MMAMatrixType bType =
+        subgroupMmaComputeOp.opB().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> bTypeShape = bType.getShape();
+    gpu::MMAMatrixType cType =
+        subgroupMmaComputeOp.opC().getType().cast<gpu::MMAMatrixType>();
+    ArrayRef<int64_t> cTypeShape = cType.getShape();
+
+    gpu::SubgroupMmaComputeOpAdaptor transformedOperands(operands);
+    if (subgroupMmaComputeOp.opC()
+            .getType()
+            .cast<gpu::MMAMatrixType>()
+            .getElementType() == f16Ty) {
+      unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
+      unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
+      unpackOp(C, transformedOperands.opC(), numHalfsInOpFrags[C], f16x2Ty);
+
+      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
+          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+        // Create nvvm.wmma.mma op.
+        NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp =
+            rewriter.create<NVVM::WMMAMmaF16F16M16N16K16Op>(loc, fragArrayCDTy,
+                                                            unpackedOps);
+
+        rewriter.replaceOp(op, wmmaMmaOp.getResult());
+        return success();
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+    } else if (subgroupMmaComputeOp.opC()
+                   .getType()
+                   .cast<gpu::MMAMatrixType>()
+                   .getElementType() == f32Ty) {
+      unpackOp(A, transformedOperands.opA(), numHalfsInOpFrags[A], f16x2Ty);
+      unpackOp(B, transformedOperands.opB(), numHalfsInOpFrags[B], f16x2Ty);
+      unpackOp(C, transformedOperands.opC(), 8, f32Ty);
+
+      if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 &&
+          bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) {
+        // Create nvvm.wmma.mma op.
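+        // As in the f16 case above, the operands are the eight vector<2xf16>
+        // values of A followed by the eight of B; the accumulator is passed
+        // as eight individual f32 values instead of four vector<2xf16>
+        // values, and the result is a struct of eight f32 values.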
+        NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp =
+            rewriter.create<NVVM::WMMAMmaF32F32M16N16K16Op>(
+                loc, fragArrayCDF32Ty, unpackedOps);
+
+        rewriter.replaceOp(op, wmmaMmaOp.getResult());
+        return success();
+      } else {
+        return rewriter.notifyMatchFailure(op, kInvalidCaseStr);
+      }
+    }
+
+    return failure();
+  }
+};
+
+} // anonymous namespace
+
+namespace mlir {
+void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                             RewritePatternSet &patterns) {
+  patterns.insert<WmmaLoadOpToNVVMLowering>(converter);
+  patterns.insert<WmmaStoreOpToNVVMLowering>(converter);
+  patterns.insert<WmmaMmaOpToNVVMLowering>(converter);
+}
+} // namespace mlir
diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s
+
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_load_op() ->
+  // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> {
+  func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) {
+    %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp">
+    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    return %0 : !gpu.mma_matrix<16x16xf16, "AOp">
+  }
+}

+// -----
+
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_store_op
+  // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () {
+    %sg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
+    %i = constant 16 : index
+    %j = constant 16 : index
+    gpu.subgroup_mma_store_matrix %arg0, %sg[%i, %j] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3>
+    // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32
+    // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}]
+    // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32
+    // CHECK: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32
+    // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32
+    // CHECK: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32
+    // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f16, 3>, ptr<f16, 3>, i32, array<2 x i32>, array<2 x i32>)>
+    // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
+    // CHECK: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
+    // CHECK: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr<i32, 3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32
+    // CHECK: llvm.return
+    return
+  }
+}
+
+// -----
+
+gpu.module @test_module {
+
+  // CHECK-LABEL: func @gpu_wmma_mma_op
+  // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) {
+  func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
+    %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp">
+    // CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A4:.*]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A5:.*]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A6:.*]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A7:.*]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[A8:.*]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B1:.*]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B2:.*]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B3:.*]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B4:.*]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B5:.*]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B6:.*]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B7:.*]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[B8:.*]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[C1:.*]] = llvm.extractvalue %[[C]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
+    // CHECK: llvm.return
+    return
+  }
+}