diff --git a/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h
@@ -0,0 +1,68 @@
+//===----- CommonTypes.h - LLVM types common to all lowerings ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the common LLVM types that are used by the lowerings of
+// GPU MMA ops to NVVM ops.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_COMMONTYPES_H
+#define MLIR_LIB_CONVERSION_GPUTONVVM_COMMONTYPES_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Builders.h"
+#include "llvm/IR/DerivedTypes.h"
+
+namespace mlir {
+
+/// Contains all the common LLVM types which are used across the lowerings of
+/// GPU subgroup ops to the NVVM dialect.
+struct CommonLLVMTypes {
+public:
+  CommonLLVMTypes(MLIRContext *context) {
+    numHalfsInOpFrags.resize(4);
+    numHalfsInOpFrags[A] = 8;
+    numHalfsInOpFrags[B] = 8;
+    numHalfsInOpFrags[C] = 4;
+    numHalfsInOpFrags[D] = 4;
+    int8Type = IntegerType::get(context, 8);
+    int64Type = IntegerType::get(context, 64);
+    int32Type = IntegerType::get(context, 32);
+    int32PtrTy = LLVM::LLVMPointerType::get(int32Type);
+    f16Ty = FloatType::getF16(context);
+    f16PtrTy = LLVM::LLVMPointerType::get(f16Ty);
+    f16x8Ty = VectorType::get(8, f16Ty);
+    f16x16Ty = VectorType::get(16, f16Ty);
+    f16x2Ty = VectorType::get(2, f16Ty);
+    fragArrayABTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(8, f16x2Ty));
+    fragArrayABPtrTy = LLVM::LLVMPointerType::get(fragArrayABTy);
+    fragArrayCDTy = LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(4, f16x2Ty));
+    fragArrayCDPtrTy = LLVM::LLVMPointerType::get(fragArrayCDTy);
+  }
+
+  Type int8Type;
+  Type int32Type;
+  Type int64Type;
+  Type int32PtrTy;
+  Type f16Ty;
+  Type f16PtrTy;
+  Type f16x2Ty;
+  Type f16x8Ty;
+  Type f16x16Ty;
+  Type fragArrayABTy;
+  Type fragArrayABPtrTy;
+  Type fragArrayCDTy;
+  Type fragArrayCDPtrTy;
+  SmallVector<unsigned, 4> numHalfsInOpFrags;
+  enum OperandMap { A, B, C, D };
+};
+
+} // namespace mlir
+#endif
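Note: the two fragment aggregates built above are the types that the NVVM wmma intrinsics produce and consume. Rendered in LLVM dialect syntax they are (a sketch for reference; the element counts follow numHalfsInOpFrags, i.e. eight vector<2xf16> halves for the A/B operands and four for C/D):

    // fragArrayABTy
    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>,
                  vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    // fragArrayCDTy
    !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>

The MMAFragmentType conversion added to the type converter below produces structs of exactly the same shape, so the lowered GPU ops and the NVVM ops agree on the fragment representation.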
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -26,6 +26,9 @@
 #include "../GPUCommon/IndexIntrinsicsOpLowering.h"
 #include "../GPUCommon/OpToFuncCallLowering.h"
 #include "../PassDetail.h"
+#include "WmmaLoadOptoNvvmLowering.h"
+#include "WmmaMmaOptoNvvmLowering.h"
+#include "WmmaStoreOptoNvvmLowering.h"
 
 using namespace mlir;
 
@@ -122,6 +125,17 @@
     return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
   });
 
+  converter.addConversion([&](gpu::MMAFragmentType type) -> Type {
+    VectorType vecTy = type.getElementType().cast<VectorType>();
+    unsigned vecSize = vecTy.getDimSize(0);
+    Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext()));
+    unsigned size = type.getSize();
+    SmallVector<Type> elements(size, vec);
+    auto structType =
+        LLVM::LLVMStructType::getLiteral(&getContext(), elements);
+    return structType;
+  });
+
   OwningRewritePatternList patterns, llvmPatterns;
 
   // Apply in-dialect lowering first. In-dialect lowering will replace ops
@@ -171,6 +185,9 @@
       // attributions since NVVM models it as `alloca`s in the default
      // memory space and does not support `alloca`s with addrspace(5).
      GPUFuncOpLowering<0>>(converter);
+  patterns.insert<WmmaLoadOptoNVVMLowering>(converter);
+  patterns.insert<WmmaMmaOptoNVVMLowering>(converter);
+  patterns.insert<WmmaStoreOptoNVVMLowering>(converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
                                                 "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<AtanOp>>(converter, "__nv_atanf",
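The three patterns registered above target the GPU subgroup MMA ops. For orientation, a minimal input in the 16x16x16 f16 configuration that the test below exercises looks roughly like this (a sketch; SSA names are illustrative, and %c would be an accumulator fragment loaded analogously to %a and %b):

    %a = gpu.subgroup_mma_load_matrix %src[%i, %j] {operand = "AOp", leadDimension = 32 : i32}
           : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>>
    %b = gpu.subgroup_mma_load_matrix %src[%i, %j] {operand = "BOp", leadDimension = 32 : i32}
           : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>>
    %d = gpu.subgroup_mma_compute %a, %b, %c
           : !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<8, vector<2xf16>>,
             !gpu.mmafragment<4, vector<2xf16>> -> !gpu.mmafragment<4, vector<2xf16>>
    gpu.subgroup_mma_store_matrix %d, %dst[%i, %j] {leadDimension = 32 : i32}
           : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16, 3>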
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h
@@ -0,0 +1,151 @@
+//===-- WmmaLoadOptoNVVMLowering.h - GPU MMA loadOp to NVVM Op lowering ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA load op to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_WMMALOADOPTONVVMLOWERING_H
+#define MLIR_LIB_CONVERSION_GPUTONVVM_WMMALOADOPTONVVMLOWERING_H
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA loadOp to the
+/// wmma.load op in the NVVM dialect. The conversion emits the NVVM op along
+/// with the address computation that is needed to index into the source
+/// memref.
+struct WmmaLoadOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaLoadOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return LLVM::isCompatibleType(value.getType());
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    int8_t indexTypeBitwidth = this->getTypeConverter()->getIndexTypeBitwidth();
+
+    // The corresponding intrinsic expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the load address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "Expected indices to the memref to be 32-bit wide.");
+
+    // Source memref of the original op.
+    MemRefType srcMemrefType =
+        subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
+    Location loc = op->getLoc();
+
+    auto beginInx = subgroupMmaLoadMatrixOp.indices().getBeginOperandIndex();
+    auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr();
+    auto operand = subgroupMmaLoadMatrixOp.operandAttr();
+
+    // Unpack the memref descriptor of the source operand.
+    auto promotedSrcOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(0), operands[0], rewriter);
+
+    // Emit ops which compute the load offset using `srcOffsetI`,
+    // `srcOffsetJ`. The actual offset is (memrefOffset +
+    // (leadDimension * srcOffsetI) + srcOffsetJ), which is then added to the
+    // aligned pointer. The memrefs here are assumed to be normalized and
+    // hence the simple conversion works.
+    Value srcOffsetIVal = subgroupMmaLoadMatrixOp->getOpOperand(beginInx).get();
+    Value srcOffsetJVal =
+        subgroupMmaLoadMatrixOp->getOpOperand(beginInx + 1).get();
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmTypes.int32Type, leadDimension);
+    Value numElemsLeadDim = rewriter.create<LLVM::MulOp>(
+        loc, llvmTypes.int32Type, leadingDim32, srcOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.int32Type, numElemsLeadDim, srcOffsetJVal);
+
+    // Add the offset stored in the memref descriptor to get the actual
+    // offset into the aligned pointer.
+    Value promotedSrcOpToUse = promotedSrcOp[2];
+    Value actualOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.int32Type, loadOffset, promotedSrcOpToUse);
+    Value loadAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.f16Ty,
+                                   srcMemrefType.getMemorySpace()),
+        promotedSrcOp[1], ArrayRef<Value>{actualOffset});
+
+    // Bitcast the pointer from *half to *i32 so that it matches the semantics
+    // of the intrinsic exposed by the NVPTX backend.
+    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.int32Type,
+                                   srcMemrefType.getMemorySpace()),
+        loadAddress);
+
+    Type resType;
+    unsigned numElemsInResFrag;
+    StringRef operandStr = operand.cast<StringAttr>().getValue();
+
+    if (operandStr.equals("AOp") || operandStr.equals("BOp")) {
+      resType = llvmTypes.fragArrayABTy;
+      numElemsInResFrag = llvmTypes.numHalfsInOpFrags[llvmTypes.A];
+    } else {
+      resType = llvmTypes.fragArrayCDTy;
+      numElemsInResFrag = llvmTypes.numHalfsInOpFrags[llvmTypes.C];
+    }
+
+    SmallVector<Value, 2> loadOpOperands({loadAddressCasted, leadingDim32});
+
+    // Create the nvvm.wmma.load op matching the requested operand.
+    if (operandStr.equals("AOp")) {
+      NVVM::WMMALoadAOp wmmaLoadAOp =
+          rewriter.create<NVVM::WMMALoadAOp>(loc, resType, loadOpOperands);
+      rewriter.replaceOp(op, wmmaLoadAOp.getResult());
+    } else if (operandStr.equals("BOp")) {
+      NVVM::WMMALoadBOp wmmaLoadBOp =
+          rewriter.create<NVVM::WMMALoadBOp>(loc, resType, loadOpOperands);
+      rewriter.replaceOp(op, wmmaLoadBOp.getResult());
+    } else {
+      NVVM::WMMALoadCOp wmmaLoadCOp =
+          rewriter.create<NVVM::WMMALoadCOp>(loc, resType, loadOpOperands);
+      rewriter.replaceOp(op, wmmaLoadCOp.getResult());
+    }
+
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU subgroupMmaLoadMatrixOp.
+  CommonLLVMTypes llvmTypes;
+};
+} // namespace mlir
+
+#endif
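In summary, for an f16 A-operand load from a memref in workgroup memory the pattern above emits roughly the following (a sketch; SSA names are illustrative, and the typed-pointer syntax assumes the source memref's address space, 3 here, since the code passes srcMemrefType.getMemorySpace() to LLVMPointerType::get):

    %ldm  = llvm.mlir.constant(32 : i32) : i32
    %li   = llvm.mul %ldm, %i : i32
    %lij  = llvm.add %li, %j : i32
    %lijo = llvm.add %lij, %offset : i32   // offset from the memref descriptor
    %gep  = llvm.getelementptr %aligned[%lijo] : (!llvm.ptr<f16, 3>, i32) -> !llvm.ptr<f16, 3>
    %cast = llvm.bitcast %gep : !llvm.ptr<f16, 3> to !llvm.ptr<i32, 3>
    %frag = nvvm.wmma.m16n16k16.load.a.f16.row.stride %cast, %ldm
              : (!llvm.ptr<i32, 3>, i32) -> !llvm.struct<(vector<2xf16>, ..., vector<2xf16>)>

The struct result (eight vector<2xf16> values for AOp/BOp, four for the C operand) directly replaces the result of gpu.subgroup_mma_load_matrix.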
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h
@@ -0,0 +1,100 @@
+//===--- WmmaMmaOptoNVVMLowering.h - GPU MMA mmaOp to NVVM Op lowering ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA compute op to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_WMMAMMAOPTONVVMLOWERING_H
+#define MLIR_LIB_CONVERSION_GPUTONVVM_WMMAMMAOPTONVVMLOWERING_H
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA computeOp to the
+/// wmma.mma op in the NVVM dialect. The conversion emits the NVVM op together
+/// with the code needed to unpack the operand fragments into the individual
+/// values expected by the NVVM op; the packed result of the NVVM op then
+/// replaces the original op.
+struct WmmaMmaOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaMmaOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return LLVM::isCompatibleType(value.getType());
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaComputeOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+
+    // The wmma.mma intrinsic in LLVM expects its operands as individual
+    // values, so the elements of the operand fragments need to be extracted
+    // and then passed on to the intrinsic call. Emit LLVM ops to extract the
+    // individual values from the lowered operands.
+    SmallVector<Value> unpackedOps;
+
+    auto unpackOp = [&](CommonLLVMTypes::OperandMap op, Value operand) {
+      for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[op]; i < e; ++i) {
+        Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+            loc, llvmTypes.f16x2Ty, operand, rewriter.getI64ArrayAttr(i));
+        unpackedOps.push_back(toUse);
+      }
+    };
+
+    unpackOp(llvmTypes.A, operands[0]);
+    unpackOp(llvmTypes.B, operands[1]);
+    unpackOp(llvmTypes.C, operands[2]);
+
+    // Operand holder for the wmma.mma op.
+    ValueRange wmmaMmaOpOperands(unpackedOps);
+
+    // Create the nvvm.wmma.mma op.
+    NVVM::WMMAMmaOp wmmaMmaOp = rewriter.create<NVVM::WMMAMmaOp>(
+        loc, llvmTypes.fragArrayCDTy, wmmaMmaOpOperands);
+
+    rewriter.replaceOp(op, wmmaMmaOp.getResult());
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU subgroupMmaComputeOp.
+  CommonLLVMTypes llvmTypes;
+};
+} // namespace mlir
+
+#endif
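The net effect of this pattern: each operand struct is unpacked with llvm.extractvalue (eight elements each for A and B, four for C) and the twenty vector<2xf16> fragments are passed to a single NVVM op whose packed result replaces the gpu.subgroup_mma_compute result (a sketch; SSA names are illustrative, elided elements follow the same shape):

    %a0 = llvm.extractvalue %A[0] : !llvm.struct<(vector<2xf16>, ..., vector<2xf16>)>
    ...
    %c3 = llvm.extractvalue %C[3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
    %d  = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %a0, ..., %a7, %b0, ..., %b7, %c0, ..., %c3
            : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>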
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h
@@ -0,0 +1,137 @@
+//==-- WmmaStoreOptoNVVMLowering.h - GPU MMA storeOp to NVVM Op lowering --==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA store op to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_WMMASTOREOPTONVVMLOWERING_H
+#define MLIR_LIB_CONVERSION_GPUTONVVM_WMMASTOREOPTONVVMLOWERING_H
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA storeOp to the
+/// wmma.store op in the NVVM dialect. The conversion emits the NVVM op along
+/// with the code that is necessary to unpack the source fragment into the
+/// individual values expected by the NVVM op.
+struct WmmaStoreOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaStoreOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return LLVM::isCompatibleType(value.getType());
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    int8_t indexTypeBitwidth = this->getTypeConverter()->getIndexTypeBitwidth();
+    // The corresponding intrinsic expects leadDimension to be a 32-bit
+    // integer, so all the calculations of linearizing the store address
+    // must also follow this restriction.
+    if (indexTypeBitwidth != 32)
+      return rewriter.notifyMatchFailure(
+          op, "Expected indices to the memref to be 32-bit wide.");
+
+    Location loc = op->getLoc();
+
+    // Destination memref of the original op.
+    MemRefType dstMemrefType =
+        subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
+
+    auto promotedDstOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(1), operands[1], rewriter);
+
+    auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr();
+    unsigned beginInx =
+        subgroupMmaStoreMatrixOp.indices().getBeginOperandIndex();
+
+    // Emit ops which compute the store offset using `dstOffsetI`,
+    // `dstOffsetJ`. The actual offset is (memrefOffset +
+    // (leadDimension * dstOffsetI) + dstOffsetJ), which is then added to the
+    // aligned pointer.
+    Value dstOffsetIVal = subgroupMmaStoreMatrixOp.getOperand(beginInx);
+    Value dstOffsetJVal = subgroupMmaStoreMatrixOp.getOperand(beginInx + 1);
+    Value leadingDim32 = rewriter.create<LLVM::ConstantOp>(
+        loc, llvmTypes.int32Type, leadDimension);
+    Value numElemsLeadDim = rewriter.create<LLVM::MulOp>(
+        loc, llvmTypes.int32Type, leadingDim32, dstOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.int32Type, numElemsLeadDim, dstOffsetJVal);
+
+    // Add the offset stored in the memref descriptor to get the actual
+    // offset into the aligned pointer.
+    Value promotedDstOpToUse = promotedDstOp[2];
+    Value actualOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.int32Type, loadOffset, promotedDstOpToUse);
+    Value storeAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.f16Ty,
+                                   dstMemrefType.getMemorySpace()),
+        promotedDstOp[1], ArrayRef<Value>{actualOffset});
+
+    // Bitcast the base address pointer of the destination memref, so that
+    // values can be stored in chunks of 32 bits.
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.int32Type,
+                                   dstMemrefType.getMemorySpace()),
+        storeAddress);
+
+    SmallVector<Value> storeOpOperands;
+    storeOpOperands.push_back(storeAddressCasted);
+
+    // Unpack the values from the source fragment.
+    for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[llvmTypes.D]; i < e;
+         ++i) {
+      Value toUse = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmTypes.f16x2Ty, operands[0], rewriter.getI64ArrayAttr(i));
+      storeOpOperands.push_back(toUse);
+    }
+
+    storeOpOperands.push_back(leadingDim32);
+
+    // Create the nvvm.wmma.store op.
+    rewriter.create<NVVM::WMMAStoreOp>(loc, storeOpOperands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU SubgroupMmaStoreMatrixOp.
+ CommonLLVMTypes llvmTypes; +}; +} // namespace mlir + +#endif diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -0,0 +1,135 @@ +// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_load_op() -> + // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> { + func @gpu_wmma_load_op() -> (!gpu.mmafragment<8, vector<2xf16>>) { + %wg = alloca() {alignment = 32} : memref<32x32xf16, 3> + %A = alloca() : memref<1xvector<16xf16>, 5> + %i = constant 16 : index + %j = constant 16 : index + %bias = constant dense<42.> : vector<2xf16> + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {operand = "AOp", leadDimension = 32 : i32} : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>> + //CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1024 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.mlir.null : !llvm.ptr + //CHECK-NEXT: %{{.*}} = llvm.getelementptr %{{.*}}[%{{.*}}] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %{{.*}} = llvm.ptrtoint %{{.*}} : !llvm.ptr to i32 + //CHECK-NEXT: %{{.*}} = llvm.alloca %{{.*}} x f16 {alignment = 32 : i64} : (i32) -> !llvm.ptr + //CHECK-NEXT: %{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : index) : i32 + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %{{.*}} = llvm.extractvalue %{{.*}}[4, 1] : 
!llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[LDM:.*]] = llvm.mlir.constant(32 : i32) : i32 + //CHECK-NEXT: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32 + //CHECK-NEXT: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr + //CHECK-NEXT: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + return %0 : !gpu.mmafragment<8, vector<2xf16>> + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_store_op + // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) { + func @gpu_wmma_store_op(%arg0 : !gpu.mmafragment<4, vector<2xf16>>) -> () { + %sg = alloca(){alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + %j = constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : i32} : !gpu.mmafragment<4, vector<2xf16>>, memref<32x32xf16, 3> + //CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 + //CHECK-NEXT: %1 = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %2 = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %3 = llvm.mlir.constant(1 : index) : i32 + //CHECK-NEXT: %4 = llvm.mlir.constant(1024 : index) : i32 + //CHECK-NEXT: %5 = llvm.mlir.null : !llvm.ptr + //CHECK-NEXT: %6 = llvm.getelementptr %5[%4] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %7 = llvm.ptrtoint %6 : !llvm.ptr to i32 + //CHECK-NEXT: %8 = llvm.alloca %7 x f16 {alignment = 32 : i64} : (i32) -> !llvm.ptr + //CHECK-NEXT: %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %12 = llvm.mlir.constant(0 : index) : i32 + //CHECK-NEXT: %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %14 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %15 = llvm.insertvalue %2, %14[3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %16 = llvm.insertvalue %2, %15[4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %17 = llvm.insertvalue %3, %16[4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %18 = llvm.extractvalue %17[0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[BASE:.*]] = llvm.extractvalue %17[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[OFFSET:.*]] = llvm.extractvalue %17[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %21 = llvm.extractvalue %17[3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %22 = llvm.extractvalue 
%17[3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %23 = llvm.extractvalue %17[4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %24 = llvm.extractvalue %17[4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[LDM:.*]] = llvm.mlir.constant(32 : i32) : i32 + //CHECK-NEXT: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32 + //CHECK-NEXT: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr + //CHECK-NEXT: %[[EL1:.*]] = llvm.extractvalue %[[D]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32 + //CHECK-NEXT: llvm.return + return + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_mma_op + // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) { + func @gpu_wmma_mma_op(%A : !gpu.mmafragment<8, vector<2xf16>>, %B : !gpu.mmafragment<8, vector<2xf16>>, %C : !gpu.mmafragment<4, vector<2xf16>>) -> () { + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<8, vector<2xf16>>, !gpu.mmafragment<4, vector<2xf16>> -> !gpu.mmafragment<4, vector<2xf16>> + //CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A2:.*]] = llvm.extractvalue %[[A]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A3:.*]] = llvm.extractvalue %[[A]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A4:.*]] = llvm.extractvalue %[[A]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A5:.*]] = llvm.extractvalue %[[A]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A6:.*]] = llvm.extractvalue %[[A]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A7:.*]] = 
llvm.extractvalue %[[A]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A8:.*]] = llvm.extractvalue %[[A]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B1:.*]] = llvm.extractvalue %[[B]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B2:.*]] = llvm.extractvalue %[[B]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B3:.*]] = llvm.extractvalue %[[B]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B4:.*]] = llvm.extractvalue %[[B]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B5:.*]] = llvm.extractvalue %[[B]][4] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B6:.*]] = llvm.extractvalue %[[B]][5] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B7:.*]] = llvm.extractvalue %[[B]][6] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B8:.*]] = llvm.extractvalue %[[B]][7] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C1:.*]] = llvm.extractvalue %[[C]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C2:.*]] = llvm.extractvalue %[[C]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C3:.*]] = llvm.extractvalue %[[C]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C4:.*]] = llvm.extractvalue %[[C]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %20 = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5%]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: llvm.return + return + } +}