diff --git a/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h @@ -0,0 +1,66 @@ +//===----- CommonTypes.h - Contains LLVM Types common to all Lowerings. ---===// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the common LLVM types that are used by the lowerings of +// GPU MMA Ops to NVVM ops. +// +//===----------------------------------------------------------------------===// +#ifndef COMMON_TYPES_INCLUDED +#define COMMON_TYPES_INCLUDED + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "llvm/IR/DerivedTypes.h" + +namespace mlir { + +/// Contains all the common LLVM types which are used across the lowerings of +/// GPU subgroup ops to NVVM dialect. +struct CommonLLVMTypes { +public: + CommonLLVMTypes(MLIRContext *context) { + numHalfsInOpFrags.resize(4); + numHalfsInOpFrags[A] = 8; + numHalfsInOpFrags[B] = 8; + numHalfsInOpFrags[C] = 4; + numHalfsInOpFrags[D] = 4; + llvmInt8Type = LLVM::LLVMType::getInt8Ty(context); + llvmInt64Type = LLVM::LLVMType::getInt64Ty(context); + llvmInt32Type = LLVM::LLVMType::getInt32Ty(context); + llvmInt32PtrTy = LLVM::LLVMPointerType::get(llvmInt32Type); + llvmF16Ty = LLVM::LLVMType::getHalfTy(context); + llvmF16PtrTy = LLVM::LLVMPointerType::get(llvmF16Ty); + llvmF16x8Ty = LLVM::LLVMType::getVectorTy(llvmF16Ty, 8); + llvmF16x16Ty = LLVM::LLVMType::getVectorTy(llvmF16Ty, 16); + llvmF16x2Ty = LLVM::LLVMType::getVectorTy(llvmF16Ty, 2); + fragArrayABTy = LLVM::LLVMType::getStructTy( + context, SmallVector(8, llvmF16x2Ty)); + fragArrayABPtrTy = LLVM::LLVMPointerType::get(fragArrayABTy); + fragArrayCDTy = LLVM::LLVMType::getStructTy( + context, SmallVector(4, llvmF16x2Ty)); + fragArrayCDPtrTy = LLVM::LLVMPointerType::get(fragArrayCDTy); + }; + + LLVM::LLVMType llvmInt8Type; + LLVM::LLVMType llvmInt32Type; + LLVM::LLVMType llvmInt64Type; + LLVM::LLVMType llvmInt32PtrTy; + LLVM::LLVMType llvmF16Ty; + LLVM::LLVMType llvmF16PtrTy; + LLVM::LLVMType llvmF16x2Ty; + LLVM::LLVMType llvmF16x8Ty; + LLVM::LLVMType llvmF16x16Ty; + LLVM::LLVMType fragArrayABTy; + LLVM::LLVMType fragArrayABPtrTy; + LLVM::LLVMType fragArrayCDTy; + LLVM::LLVMType fragArrayCDPtrTy; + SmallVector numHalfsInOpFrags; + enum operandMap { A, B, C, D }; +}; + +} // namespace mlir +#endif diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -26,6 +26,9 @@ #include "../GPUCommon/IndexIntrinsicsOpLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" #include "../PassDetail.h" +#include "WmmaLoadOptoNvvmLowering.h" +#include "WmmaMmaOptoNvvmLowering.h" +#include "WmmaStoreOptoNvvmLowering.h" using namespace mlir; @@ -171,6 +174,9 @@ // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). 
   GPUFuncOpLowering<0>>(converter);
+  patterns.insert<WmmaLoadOptoNVVMLowering>(converter);
+  patterns.insert<WmmaMmaOptoNVVMLowering>(converter);
+  patterns.insert<WmmaStoreOptoNVVMLowering>(converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf", "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<AtanOp>>(converter, "__nv_atanf",
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadOptoNvvmLowering.h
@@ -0,0 +1,181 @@
+//===-- WmmaLoadOptoNVVMLowering.h - GPU MMA loadOp to NVVM Op lowering ---===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA_loadOp to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA loadOp to the
+/// wmma.load op in the NVVM dialect. The conversion not only emits the NVVM
+/// op but also emits code that is necessary to store the data in the
+/// destination memref after it has been loaded.
+struct WmmaLoadOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaLoadOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaLoadMatrixOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return value.getType().isa<LLVM::LLVMType>();
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaLoadMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    // Source memref of the original op.
+    MemRefType srcMemrefType =
+        subgroupMmaLoadMatrixOp.srcMemref().getType().cast<MemRefType>();
+    Location loc = op->getLoc();
+
+    auto beginInx = subgroupMmaLoadMatrixOp.indices().getBeginOperandIndex();
+    auto ldm = subgroupMmaLoadMatrixOp.ldmAttr();
+    auto operand = subgroupMmaLoadMatrixOp.operandAttr();
+
+    // Emit information for the memref operands.
+    auto promotedSrcOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(0), operands[0], rewriter);
+
+    auto promotedDstOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(1), operands[1], rewriter);
+
+    // Emit ops which compute the load offset using `srcOffsetI` and
+    // `srcOffsetJ`. The actual offset is memrefOffset + (ldm * srcOffsetI) +
+    // srcOffsetJ, applied to the aligned base pointer. The memrefs here are
+    // assumed to be normalized and hence the simple conversion works.
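+    // For example, with ldm = 32, indices (16, 16), and a memref offset of
+    // zero, the computed element offset is 32 * 16 + 16 = 528 halves past the
+    // aligned base pointer.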
+    Value srcOffsetIVal = subgroupMmaLoadMatrixOp->getOpOperand(beginInx).get();
+    Value srcOffsetJVal =
+        subgroupMmaLoadMatrixOp->getOpOperand(beginInx + 1).get();
+    Value leadingDim64 =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmTypes.llvmInt64Type, ldm);
+    Value numElemsLeadDim = rewriter.create<LLVM::MulOp>(
+        loc, llvmTypes.llvmInt64Type, leadingDim64, srcOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.llvmInt64Type, numElemsLeadDim, srcOffsetJVal);
+    // Cast the memref offset to i64 to make the calculation below independent
+    // of the index bitwidth supplied.
+    Value promotedSrcOpToUse;
+    int64_t indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+    if (indexTypeBitwidth < 64)
+      promotedSrcOpToUse = rewriter.create<LLVM::SExtOp>(
+          loc, llvmTypes.llvmInt64Type, promotedSrcOp[2]);
+    else if (indexTypeBitwidth > 64)
+      promotedSrcOpToUse = rewriter.create<LLVM::TruncOp>(
+          loc, llvmTypes.llvmInt64Type, promotedSrcOp[2]);
+    else
+      promotedSrcOpToUse = promotedSrcOp[2];
+    Value actualOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.llvmInt64Type, loadOffset, promotedSrcOpToUse);
+    Value loadAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmF16Ty,
+                                   srcMemrefType.getMemorySpace()),
+        promotedSrcOp[1], ArrayRef<Value>{actualOffset});
+
+    // Bitcast the pointer from half* to i32* so that it matches the semantics
+    // of the intrinsic exposed by the NVPTX backend.
+    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                   srcMemrefType.getMemorySpace()),
+        loadAddress);
+
+    // Result types for the wmmaLoadOp.
+    LLVM::LLVMType resType, dstMemrefElemType;
+    unsigned numElemsInResFrag;
+    if (operand.cast<StringAttr>().getValue().equals("AOp") ||
+        operand.cast<StringAttr>().getValue().equals("BOp")) {
+      resType = llvmTypes.fragArrayABTy;
+      dstMemrefElemType = llvmTypes.llvmF16x16Ty;
+      numElemsInResFrag = llvmTypes.numHalfsInOpFrags[llvmTypes.A];
+    } else {
+      resType = llvmTypes.fragArrayCDTy;
+      dstMemrefElemType = llvmTypes.llvmF16x8Ty;
+      numElemsInResFrag = llvmTypes.numHalfsInOpFrags[llvmTypes.C];
+    }
+
+    // For NVPTX intrinsic compatibility, create an i32 constant op for ldm.
+    // This might result in loss of data. leadingDim is in number of elements,
+    // as required by the NVPTX intrinsic.
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmTypes.llvmInt32Type, ldm);
+
+    // Create the nvvm.mma_load op according to the operand.
+    ValueRange loadOpOperands({loadAddressCasted, leadingDim32});
+
+    NVVM::WMMALoadOp wmmaLoadOp = rewriter.create<NVVM::WMMALoadOp>(
+        loc, resType, loadOpOperands, operand.cast<StringAttr>());
+
+    // Get the store address for this fragment in the destination memref.
+    Value dstStoreAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(dstMemrefElemType,
+                                   /*NVVM private memory space*/ 0),
+        promotedDstOp[1], subgroupMmaLoadMatrixOp.dstIndex());
+
+    // Bitcast the base address pointer of the destination memref, so that
+    // values can be stored in chunks of 32 bits, as they were returned by the
+    // wmmaLoadOp.
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                   /*NVVM private memory space*/ 0),
+        dstStoreAddress);
+
+    // Move the data into the memref that was passed as an argument to the
+    // original op. The result of the wmmaLoadOp is an !llvm.struct, and its
+    // elements are moved into the memref one by one. The number of elements
+    // in the memref and the number of elements in the struct should be the
+    // same.
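+    // For "AOp"/"BOp" the fragment is 8 x <2 x half> (16 halves), matching a
+    // memref of vector<16xf16>; for "COp"/"DOp" it is 4 x <2 x half>
+    // (8 halves), matching a memref of vector<8xf16>.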
+    for (unsigned i = 0, e = numElemsInResFrag; i < e; ++i) {
+      Value toStore = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmTypes.llvmF16x2Ty, wmmaLoadOp,
+          rewriter.getIndexArrayAttr(i));
+      Value toStoreI32 = rewriter.create<LLVM::BitcastOp>(
+          loc, llvmTypes.llvmInt32Type, toStore);
+      Value storeAddress = rewriter.create<LLVM::GEPOp>(
+          loc,
+          LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                     /*NVVM private memory space*/ 0),
+          storeAddressCasted,
+          ArrayRef<Value>{rewriter.create<LLVM::ConstantOp>(
+              loc, llvmTypes.llvmInt32Type, rewriter.getUI32IntegerAttr(i))});
+      rewriter.create<LLVM::StoreOp>(loc, toStoreI32, storeAddress);
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU subgroupMmaLoadMatrixOp.
+  CommonLLVMTypes llvmTypes;
+};
+} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h
@@ -0,0 +1,193 @@
+//===--- WmmaMmaOptoNVVMLowering.h - GPU MMA mmaOp to NVVM Op lowering ----===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA_computeOp to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA computeOp to the
+/// wmma.mma op in the NVVM dialect. The conversion not only emits the NVVM op
+/// but also emits code that is necessary to unpack the data from the source
+/// memrefs to pass it to the NVVM op, and then to pack the results again to
+/// store them into the destination memref.
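+/// The A and B operands are expected as 8 <2 x half> values each and the C
+/// operand as 4 <2 x half> values, which together form the 20 data operands
+/// of the nvvm.wmma.m16n16k16.mma op; the 4-element result fragment is then
+/// written back to the D memref.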
+struct WmmaMmaOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaMmaOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaComputeOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return value.getType().isa<LLVM::LLVMType>();
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaComputeOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+
+    SmallVector<MemRefType, 4> opTypes;
+    SmallVector<Type, 4> elemTypes;
+    SmallVector<LLVM::LLVMType, 4> llvmElemTypes;
+    SmallVector<Value, 4> opIndices;
+
+    auto populateOpInfo = [&]() {
+      opTypes.push_back(
+          subgroupMmaComputeOp.opA().getType().cast<MemRefType>());
+      opTypes.push_back(
+          subgroupMmaComputeOp.opB().getType().cast<MemRefType>());
+      opTypes.push_back(
+          subgroupMmaComputeOp.opC().getType().cast<MemRefType>());
+      opTypes.push_back(
+          subgroupMmaComputeOp.opD().getType().cast<MemRefType>());
+
+      for (MemRefType opType : opTypes) {
+        elemTypes.push_back(opType.getElementType());
+      }
+
+      opIndices.push_back(subgroupMmaComputeOp.AIndex());
+      opIndices.push_back(subgroupMmaComputeOp.BIndex());
+      opIndices.push_back(subgroupMmaComputeOp.CIndex());
+      opIndices.push_back(subgroupMmaComputeOp.DIndex());
+
+      llvmElemTypes.push_back(llvmTypes.llvmF16x16Ty);
+      llvmElemTypes.push_back(llvmTypes.llvmF16x16Ty);
+      llvmElemTypes.push_back(llvmTypes.llvmF16x8Ty);
+      llvmElemTypes.push_back(llvmTypes.llvmF16x8Ty);
+    };
+
+    // Gather type and shape info for the memrefs.
+    populateOpInfo();
+
+    // Promote the operands of this op. This emits !llvm.extractvalue for each
+    // of the operand memrefs and makes it easy to use these values in
+    // subsequent instructions.
+    SmallVector<SmallVector<Value, 4>, 4> promotedOps;
+    promotedOps.resize(4);
+
+    auto promoteOps = [&](CommonLLVMTypes::operandMap operand) {
+      promotedOps[operand] = this->getTypeConverter()->promoteOperands(
+          loc, op->getOperand(operand), operands[operand], rewriter);
+    };
+
+    promoteOps(llvmTypes.A);
+    promoteOps(llvmTypes.B);
+    promoteOps(llvmTypes.C);
+    promoteOps(llvmTypes.D);
+
+    // The wmma.mma intrinsic in LLVM requires the operands as individual
+    // values. So individual elements from the memrefs need to be extracted
+    // and then passed on to the intrinsic call. Emit LLVM ops to extract the
+    // individual values from the lowered memrefs.
+    SmallVector<Value, 20> unpackedOps;
+
+    auto unpackOp = [&](CommonLLVMTypes::operandMap op) {
+      // Get the load address for this fragment in the source memref.
+      Value loadAddress = rewriter.create<LLVM::GEPOp>(
+          loc,
+          LLVM::LLVMPointerType::get(llvmElemTypes[op],
+                                     /*NVVM private memory space*/ 0),
+          promotedOps[op][1], opIndices[op]);
+
+      // Cast the address from the vector-of-half pointer to i32* to load the
+      // elements in chunks of 32 bits.
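+      // Each 32-bit load below therefore yields one <2 x half> value; AOp and
+      // BOp contribute 8 such values each and COp contributes 4, giving the
+      // 20 data operands of the nvvm.wmma.mma op.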
+      Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+          loc,
+          LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                     /*NVVM private memory space*/ 0),
+          loadAddress);
+
+      for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[op]; i < e; ++i) {
+        Value loadAddress = rewriter.create<LLVM::GEPOp>(
+            loc,
+            LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                       /*NVVM private memory space*/ 0),
+            loadAddressCasted,
+            ArrayRef<Value>{rewriter.create<LLVM::ConstantOp>(
+                loc, llvmTypes.llvmInt32Type, rewriter.getUI32IntegerAttr(i))});
+        Value toStore = rewriter.create<LLVM::LoadOp>(loc, loadAddress);
+        unpackedOps.push_back(rewriter.create<LLVM::BitcastOp>(
+            loc, llvmTypes.llvmF16x2Ty, toStore));
+      }
+    };
+
+    unpackOp(llvmTypes.A);
+    unpackOp(llvmTypes.B);
+    unpackOp(llvmTypes.C);
+
+    // Operand holder for the wmma.mma op.
+    ValueRange wmmaMmaOpOperands(unpackedOps);
+
+    // Create the nvvm.wmma.mma op.
+    NVVM::WMMAMmaOp wmmaMmaOp = rewriter.create<NVVM::WMMAMmaOp>(
+        loc, llvmTypes.fragArrayCDTy, wmmaMmaOpOperands);
+
+    // Get the store address for this fragment in the destination memref.
+    Value dstStoreAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmElemTypes[llvmTypes.D],
+                                   /*NVVM private memory space*/ 0),
+        promotedOps[llvmTypes.D][1], subgroupMmaComputeOp.DIndex());
+
+    // Bitcast the base address pointer of the destination memref, so that
+    // values can be stored in chunks of 32 bits, as they were returned by the
+    // wmmaMmaOp.
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                   /*NVVM private memory space*/ 0),
+        dstStoreAddress);
+
+    // Store the results in memref D.
+    for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[llvmTypes.D]; i < e;
+         ++i) {
+      Value toStore = rewriter.create<LLVM::ExtractValueOp>(
+          loc, llvmTypes.llvmF16x2Ty, wmmaMmaOp, rewriter.getIndexArrayAttr(i));
+      Value toStoreI32 = rewriter.create<LLVM::BitcastOp>(
+          loc, llvmTypes.llvmInt32Type, toStore);
+      Value storeAddress = rewriter.create<LLVM::GEPOp>(
+          loc,
+          LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                     /*NVVM private memory space*/ 0),
+          storeAddressCasted,
+          ArrayRef<Value>{rewriter.create<LLVM::ConstantOp>(
+              loc, llvmTypes.llvmInt32Type, rewriter.getUI32IntegerAttr(i))});
+      rewriter.create<LLVM::StoreOp>(loc, toStoreI32, storeAddress);
+    }
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU subgroupMmaComputeOp.
+  CommonLLVMTypes llvmTypes;
+};
+} // namespace mlir
diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/GPUToNVVM/WmmaStoreOptoNvvmLowering.h
@@ -0,0 +1,165 @@
+//==-- WmmaStoreOptoNVVMLowering.h - GPU MMA storeOp to NVVM Op lowering --==//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains patterns to lower the GPU subgroup MMA_store op to the
+// NVVM dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "CommonTypes.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+
+namespace mlir {
+/// This class implements the conversion of the GPU MMA storeOp to the
+/// wmma.store op in the NVVM dialect.
+/// The conversion not only emits the NVVM op but also emits code that is
+/// necessary to unpack the data in the source memref and convert the data
+/// into the format that is needed by the NVVM op.
+struct WmmaStoreOptoNVVMLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp> {
+public:
+  MLIRContext *context = &this->getTypeConverter()->getContext();
+
+  explicit WmmaStoreOptoNVVMLowering(LLVMTypeConverter &typeConverter)
+      : ConvertOpToLLVMPattern<gpu::SubgroupMmaStoreMatrixOp>(typeConverter),
+        llvmTypes(context) {}
+
+  static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
+                                       ConversionPatternRewriter &rewriter) {
+    if (!llvm::all_of(operands, [](Value value) {
+          return value.getType().isa<LLVM::LLVMType>();
+        }))
+      return rewriter.notifyMatchFailure(
+          op, "Cannot convert if operands aren't of LLVM type.");
+
+    return success();
+  }
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp,
+                  ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Operation *op = subgroupMmaStoreMatrixOp.getOperation();
+    if (failed(areAllLLVMTypes(op, operands, rewriter)))
+      return failure();
+
+    Location loc = op->getLoc();
+
+    // Destination memref of the original op.
+    MemRefType dstMemrefType =
+        subgroupMmaStoreMatrixOp.dstMemref().getType().cast<MemRefType>();
+
+    // Promote the operands of this op. This emits !llvm.extractvalue for each
+    // of the operand memrefs and makes it easy to use these values in
+    // subsequent instructions.
+    auto promotedSrcOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(0), operands[0], rewriter);
+
+    auto promotedDstOp = this->getTypeConverter()->promoteOperands(
+        loc, op->getOperand(1), operands[1], rewriter);
+
+    auto ldm = subgroupMmaStoreMatrixOp.ldmAttr();
+    unsigned beginInx =
+        subgroupMmaStoreMatrixOp.indices().getBeginOperandIndex();
+
+    // Emit ops which compute the store offset using `dstOffsetI` and
+    // `dstOffsetJ`. The actual offset is memrefOffset + (ldm * dstOffsetI) +
+    // dstOffsetJ, applied to the aligned base pointer.
+    Value dstOffsetIVal = subgroupMmaStoreMatrixOp.getOperand(beginInx);
+    Value dstOffsetJVal = subgroupMmaStoreMatrixOp.getOperand(beginInx + 1);
+    Value leadingDim64 =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmTypes.llvmInt64Type, ldm);
+    Value numElemsLeadDim = rewriter.create<LLVM::MulOp>(
+        loc, llvmTypes.llvmInt64Type, leadingDim64, dstOffsetIVal);
+    Value loadOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.llvmInt64Type, numElemsLeadDim, dstOffsetJVal);
+    // Cast the memref offset to i64 to make the calculation below independent
+    // of the index bitwidth supplied.
+    Value promotedDstOpToUse;
+    int64_t indexTypeBitwidth =
+        this->getTypeConverter()->getIndexTypeBitwidth();
+    if (indexTypeBitwidth < 64)
+      promotedDstOpToUse = rewriter.create<LLVM::SExtOp>(
+          loc, llvmTypes.llvmInt64Type, promotedDstOp[2]);
+    else if (indexTypeBitwidth > 64)
+      promotedDstOpToUse = rewriter.create<LLVM::TruncOp>(
+          loc, llvmTypes.llvmInt64Type, promotedDstOp[2]);
+    else
+      promotedDstOpToUse = promotedDstOp[2];
+    Value actualOffset = rewriter.create<LLVM::AddOp>(
+        loc, llvmTypes.llvmInt64Type, loadOffset, promotedDstOpToUse);
+    Value storeAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmF16Ty,
+                                   dstMemrefType.getMemorySpace()),
+        promotedDstOp[1], ArrayRef<Value>{actualOffset});
+
+    // Bitcast the base address pointer of the destination memref, so that
+    // values can be stored in chunks of 32 bits.
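+    // The nvvm.wmma.m16n16k16.store op emitted below consumes this pointer,
+    // the four <2 x half> values holding the D fragment, and the leading
+    // dimension, in that order.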
+    Value storeAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                   dstMemrefType.getMemorySpace()),
+        storeAddress);
+
+    SmallVector<Value, 6> storeOpOperands;
+    storeOpOperands.push_back(storeAddressCasted);
+
+    // Get the load address for this fragment in the source memref.
+    Value srcLoadAddress = rewriter.create<LLVM::GEPOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmF16x8Ty,
+                                   /*NVVM private memory space*/ 0),
+        promotedSrcOp[1], subgroupMmaStoreMatrixOp.srcIndex());
+
+    // Bitcast the base address pointer of the source memref, so that values
+    // can be loaded in chunks of 32 bits, as they were packed by the op that
+    // produced the fragment.
+    Value loadAddressCasted = rewriter.create<LLVM::BitcastOp>(
+        loc,
+        LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                   /*NVVM private memory space*/ 0),
+        srcLoadAddress);
+
+    // Unpack the results from the source memref.
+    for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[llvmTypes.D]; i < e;
+         ++i) {
+      Value loadAddress = rewriter.create<LLVM::GEPOp>(
+          loc,
+          LLVM::LLVMPointerType::get(llvmTypes.llvmInt32Type,
+                                     /*NVVM private memory space*/ 0),
+          loadAddressCasted,
+          ArrayRef<Value>{rewriter.create<LLVM::ConstantOp>(
+              loc, llvmTypes.llvmInt32Type, rewriter.getUI32IntegerAttr(i))});
+      Value toStore = rewriter.create<LLVM::LoadOp>(loc, loadAddress);
+      storeOpOperands.push_back(rewriter.create<LLVM::BitcastOp>(
+          loc, llvmTypes.llvmF16x2Ty, toStore));
+    }
+
+    // For NVPTX intrinsic compatibility, create an i32 constant op for ldm.
+    // This might result in loss of data. leadingDim is in number of elements,
+    // as required by the NVPTX intrinsic.
+    Value leadingDim32 =
+        rewriter.create<LLVM::ConstantOp>(loc, llvmTypes.llvmInt32Type, ldm);
+    storeOpOperands.push_back(leadingDim32);
+
+    // Create the nvvm.mma_store op.
+    ValueRange unpackedValueRange(storeOpOperands);
+    rewriter.create<NVVM::WMMAStoreOp>(loc, storeOpOperands);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Contains definitions of all the LLVM types which are used for lowering
+  /// this GPU SubgroupMmaStoreMatrixOp.
+ CommonLLVMTypes llvmTypes; +}; +} // namespace mlir diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -0,0 +1,400 @@ +// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_load_op() + func @gpu_wmma_load_op() -> () { + %wg = alloca() {alignment = 32} : memref<32x32xf16, 3> + %A = alloca() : memref<1xvector<16xf16>, 5> + %i = constant 16 : i64 + %j = constant 16 : i64 + %c0 = constant 0 : i64 + gpu.subgroup_mma_load_matrix %wg[%i, %j], %A[%c0] {operand = "AOp", ldm = 32 : i64} : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5> + + // CHECK: %[[OFF:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : i64) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(32 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(32 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1024 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.null : !llvm.ptr + // CHECK-NEXT: %{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK-NEXT: %{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr to !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.alloca {{.*}} x !llvm.half {alignment = 32 : i64} : (!llvm.i64) -> !llvm.ptr + // CHECK-NEXT: %{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.mlir.null : !llvm.ptr, 5> + // CHECK-NEXT: %{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + // CHECK-NEXT: %{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + // CHECK-NEXT: %{{.*}} = llvm.alloca {{.*}} x !llvm.vec<16 x half> : (!llvm.i64) -> !llvm.ptr, 5> + // CHECK-NEXT: %{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK-NEXT: %{{.*}} = 
llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %[[BASE:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %[[OFFSETT:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %[[STOREADDR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %[[LDM:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64 + // CHECK-NEXT: %[[ILDM:.*]] = llvm.mul %[[LDM]], %[[OFF]] : !llvm.i64 + // CHECK-NEXT: %[[IJLDM:.*]] = llvm.add %[[ILDM]], %[[OFF]] : !llvm.i64 + // CHECK-NEXT: %[[IJOLDM:.*]] = llvm.add %[[IJLDM]], %[[OFFSETT]] : !llvm.i64 + // CHECK-NEXT: %[[NBASE:.*]] = llvm.getelementptr %[[BASE]][%[[IJOLDM]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK-NEXT: %[[CBASE:.*]] = llvm.bitcast %[[NBASE]] : !llvm.ptr to !llvm.ptr + // CHECK-NEXT: %[[STRIDE:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i32 + // CHECK-NEXT: %[[FRAG:.*]] = nvvm.wmma.m16n16k16.load %[[CBASE]], %[[STRIDE]] {operand = "AOp"} : !llvm.ptr, !llvm.i32 -> !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[STOREADDRR:.*]] = llvm.getelementptr %[[STOREADDR]][%{{.*}}] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + // CHECK-NEXT: %[[CSTOREADDR:.*]] = llvm.bitcast %[[STOREADDRR]] : !llvm.ptr> to !llvm.ptr + // CHECK-NEXT: %[[ST0:.*]] = llvm.extractvalue %[[FRAG]][0 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST0:.*]] = llvm.bitcast %[[ST0]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF0:.*]] = llvm.mlir.constant(0 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR0:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF0]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST0]], %[[FADDR0]] : !llvm.ptr + // CHECK-NEXT: %[[ST1:.*]] = llvm.extractvalue %[[FRAG]][1 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, 
vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST1:.*]] = llvm.bitcast %[[ST1]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF1:.*]] = llvm.mlir.constant(1 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR1:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF1]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST1]], %[[FADDR1]] : !llvm.ptr + // CHECK-NEXT: %[[ST2:.*]] = llvm.extractvalue %[[FRAG]][2 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST2:.*]] = llvm.bitcast %[[ST2]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF2:.*]] = llvm.mlir.constant(2 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR2:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF2]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST2]], %[[FADDR2]] : !llvm.ptr + // CHECK-NEXT: %[[ST3:.*]] = llvm.extractvalue %[[FRAG]][3 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST3:.*]] = llvm.bitcast %[[ST3]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF3:.*]] = llvm.mlir.constant(3 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR3:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF3]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST3]], %[[FADDR3]] : !llvm.ptr + // CHECK-NEXT: %[[ST4:.*]] = llvm.extractvalue %[[FRAG]][4 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST4:.*]] = llvm.bitcast %[[ST4]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF4:.*]] = llvm.mlir.constant(4 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR4:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF4]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST4]], %[[FADDR4]] : !llvm.ptr + // CHECK-NEXT: %[[ST5:.*]] = llvm.extractvalue %[[FRAG]][5 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST5:.*]] = llvm.bitcast %[[ST5]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF5:.*]] = llvm.mlir.constant(5 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR5:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF5]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST5]], %[[FADDR5]] : !llvm.ptr + // CHECK-NEXT: %[[ST6:.*]] = llvm.extractvalue %[[FRAG]][6 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST6:.*]] = llvm.bitcast %[[ST6]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF6:.*]] = llvm.mlir.constant(6 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[FADDR6:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF6]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST6]], %[[FADDR6]] : !llvm.ptr + // CHECK-NEXT: %[[ST7:.*]] = llvm.extractvalue %[[FRAG]][7 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + // CHECK-NEXT: %[[CST7:.*]] = llvm.bitcast %[[ST7]] : !llvm.vec<2 x half> to !llvm.i32 + // CHECK-NEXT: %[[OFF7:.*]] = llvm.mlir.constant(7 : ui32) : !llvm.i32 + 
// CHECK-NEXT: %[[FADDR7:.*]] = llvm.getelementptr %[[CSTOREADDR]][%[[OFF7]]] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: llvm.store %[[CST7]], %[[FADDR7]] : !llvm.ptr + // CHECK-NEXT: llvm.return + return + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_store_op() + func @gpu_wmma_store_op() -> () { + %sg = alloca(){alignment = 32} : memref<32x32xf16, 3> + %D = alloca() : memref<1xvector<8xf16>, 5> + %i = constant 16 : i64 + %j = constant 16 : i64 + %c0 = constant 0 : i64 + gpu.subgroup_mma_store_matrix %D[%c0], %sg[%i,%j] {ldm = 32 : i64} : memref<1xvector<8xf16>, 5>, memref<32x32xf16, 3> + + // CHECK: %[[OFF:.*]] = llvm.mlir.constant(16 : i64) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : i64) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(32 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(32 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1024 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.null : !llvm.ptr + // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK-NEXT: {{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr to !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.alloca {{.*}} x !llvm.half {alignment = 32 : i64} : (!llvm.i64) -> !llvm.ptr + // CHECK-NEXT: {{.*}} = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.mlir.null : !llvm.ptr, 5> + // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + // CHECK-NEXT: {{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.alloca {{.*}} x !llvm.vec<8 x half> : (!llvm.i64) -> !llvm.ptr, 5> + // CHECK-NEXT: {{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : 
!llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: %[[SRCADDRR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %[[BASEADDR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %[[OFFSETTT:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: {{.*}} = llvm.extractvalue {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-NEXT: %[[LDM:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i64 + // CHECK-NEXT: %[[OILDM:.*]] = llvm.mul %[[LDM]], %[[OFF]] : !llvm.i64 + // CHECK-NEXT: %[[OIJLDM:.*]] = llvm.add %[[OILDM]], %[[OFF]] : !llvm.i64 + // CHECK-NEXT: %[[TOFFSET:.*]] = llvm.add %[[OIJLDM]], %[[OFFSETTT]] : !llvm.i64 + // CHECK-NEXT: %[[LADDR:.*]] = llvm.getelementptr %[[BASEADDR]][%[[TOFFSET]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + // CHECK-NEXT: %[[CADDR:.*]] = llvm.bitcast %[[LADDR]] : !llvm.ptr to !llvm.ptr + // CHECK-NEXT: %[[BASE:.*]] = llvm.getelementptr %[[SRCADDRR]][%1] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + // CHECK-NEXT: %[[SRCADDR:.*]] = llvm.bitcast %[[BASE]] : !llvm.ptr> to !llvm.ptr + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(0 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[ADDR0:.*]] = llvm.getelementptr %[[SRCADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: %[[EL0I32:.*]] = llvm.load %[[ADDR0]] : !llvm.ptr + // CHECK-NEXT: %[[EL0:.*]] = llvm.bitcast %[[EL0I32]] : !llvm.i32 to !llvm.vec<2 x half> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(1 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[ADDR1:.*]] = llvm.getelementptr %[[SRCADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: %[[EL1I32:.*]] = llvm.load %[[ADDR1]] : !llvm.ptr + // CHECK-NEXT: %[[EL1:.*]] = llvm.bitcast %[[EL1I32]] : !llvm.i32 to !llvm.vec<2 x half> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(2 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[ADDR2:.*]] = llvm.getelementptr %[[SRCADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + // CHECK-NEXT: %[[EL2I32:.*]] = llvm.load %[[ADDR2]] : !llvm.ptr + // CHECK-NEXT: %[[EL2:.*]] = llvm.bitcast %[[EL2I32]] : !llvm.i32 to !llvm.vec<2 x half> + // CHECK-NEXT: {{.*}} = llvm.mlir.constant(3 : ui32) : !llvm.i32 + // CHECK-NEXT: %[[ADDR3:.*]] = llvm.getelementptr %[[SRCADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + 
// CHECK-NEXT: %[[EL3I32:.*]] = llvm.load %[[ADDR3]] : !llvm.ptr + // CHECK-NEXT: %[[EL3:.*]] = llvm.bitcast %[[EL3I32]] : !llvm.i32 to !llvm.vec<2 x half> + // CHECK-NEXT: %[[STRIDE:.*]] = llvm.mlir.constant(32 : i64) : !llvm.i32 + // CHECK-NEXT: nvvm.wmma.m16n16k16.store %[[CADDR]], %[[EL0]], %[[EL1]], %[[EL2]], %[[EL3]], %[[STRIDE]] : !llvm.ptr, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.i32 + return + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_mma_op() + func @gpu_wmma_mma_op() -> () { + %A = alloca() : memref<1xvector<16xf16>, 5> + %B = alloca() : memref<1xvector<16xf16>, 5> + %C = alloca() : memref<1xvector<8xf16>, 5> + %D = alloca() : memref<1xvector<8xf16>, 5> + %c0 = constant 0 : i64 + + gpu.subgroup_mma_compute %A[%c0], %B[%c0], %C[%c0], %D[%c0] : memref<1xvector<16xf16>, 5>, memref<1xvector<16xf16>, 5>, memref<1xvector<8xf16>, 5>, memref<1xvector<8xf16>, 5> + + //CHECK-NEXT:%[[c0:.*]] = llvm.mlir.constant(0 : i64) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.null : !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.alloca {{.*}} x !llvm.vec<16 x half> : (!llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.null : !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.alloca {{.*}} x !llvm.vec<16 x half> : (!llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, 
array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.null : !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.alloca {{.*}} x !llvm.vec<8 x half> : (!llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.mlir.null : !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm.ptr, 5>, !llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.ptrtoint {{.*}} : !llvm.ptr, 5> to !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.alloca {{.*}} x !llvm.vec<8 x half> : (!llvm.i64) -> !llvm.ptr, 5> + //CHECK-NEXT:{{.*}} = llvm.mlir.undef : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64 + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.insertvalue {{.*}}, {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:%[[AADDRR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:%[[BADDRR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue 
{{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:%[[CADDRR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:%[[DADDRR:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[2] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[3, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[4, 0] : !llvm.struct<(ptr>, ptr>, i64, array<1 x i64>, array<1 x i64>)> + //CHECK-NEXT:%[[ABASE:.*]] = llvm.getelementptr %[[AADDRR]][%[[c0]]] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + //CHECK-NEXT:%[[AADDR:.*]] = llvm.bitcast %[[ABASE]] : !llvm.ptr> to !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A0I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A0:.*]] = llvm.bitcast %[[A0I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A1I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A1:.*]] = llvm.bitcast %[[A1I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(2 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A2I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A2:.*]] = llvm.bitcast %[[A2I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(3 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A3I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A3:.*]] = llvm.bitcast %[[A3I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(4 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A4I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A4:.*]] = llvm.bitcast %[[A4I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(5 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A5I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A5:.*]] = llvm.bitcast %[[A5I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(6 : ui32) : !llvm.i32 + 
//CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A6I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A6:.*]] = llvm.bitcast %[[A6I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(7 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[AADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[A7I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[A7:.*]] = llvm.bitcast %[[A7I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:%[[BBASE:.*]] = llvm.getelementptr %[[BADDRR]][%[[c0]]] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + //CHECK-NEXT:%[[BADDR:.*]] = llvm.bitcast %[[BBASE]] : !llvm.ptr> to !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B0I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B0:.*]] = llvm.bitcast %[[B0I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B1I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B1:.*]] = llvm.bitcast %[[B1I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(2 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B2I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B2:.*]] = llvm.bitcast %[[B2I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(3 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B3I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B3:.*]] = llvm.bitcast %[[B3I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(4 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B4I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B4:.*]] = llvm.bitcast %[[B4I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(5 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B5I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B5:.*]] = llvm.bitcast %[[B5I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(6 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B6I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B6:.*]] = llvm.bitcast %[[B6I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(7 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[BADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[B7I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[B7:.*]] = llvm.bitcast %[[B7I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:%[[CBASE:.*]] = llvm.getelementptr %[[CADDRR]][%[[c0]]] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + //CHECK-NEXT:%[[CADDR:.*]] = llvm.bitcast %[[CBASE]] : !llvm.ptr> to !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr 
%[[CADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[C0I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[C0:.*]] = llvm.bitcast %[[C0I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[CADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[C1I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[C1:.*]] = llvm.bitcast %[[C1I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(2 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[CADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[C2I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[C2:.*]] = llvm.bitcast %[[C2I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(3 : ui32) : !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.getelementptr %[[CADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:%[[C3I32:.*]] = llvm.load {{.*}} : !llvm.ptr + //CHECK-NEXT:%[[C3:.*]] = llvm.bitcast %[[C3I32]] : !llvm.i32 to !llvm.vec<2 x half> + //CHECK-NEXT:{{.*}} = nvvm.wmma.m16n16k16.mma %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[B0]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[C0]], %[[C1]], %[[C2]], %[[C3]] : !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half>, !llvm.vec<2 x half> -> !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + //CHECK-NEXT:%[[DBASE:.*]] = llvm.getelementptr %[[DADDRR]][%[[c0]]] : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> + //CHECK-NEXT:%[[DADDR:.*]] = llvm.bitcast %[[DBASE]] : !llvm.ptr> to !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[0 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + //CHECK-NEXT:%[[D0:.*]] = llvm.bitcast {{.*}} : !llvm.vec<2 x half> to !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(0 : ui32) : !llvm.i32 + //CHECK-NEXT:%[[DADDR0:.*]] = llvm.getelementptr %[[DADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT:llvm.store %[[D0]], %[[DADDR0]] : !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[1 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + //CHECK-NEXT:%[[D1:.*]] = llvm.bitcast {{.*}} : !llvm.vec<2 x half> to !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(1 : ui32) : !llvm.i32 + //CHECK-NEXT:%[[DADDR1:.*]] = llvm.getelementptr %[[DADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT: llvm.store %[[D1]], %[[DADDR1]] : !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[2 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + //CHECK-NEXT:%[[D2:.*]] = llvm.bitcast {{.*}} : !llvm.vec<2 x half> to !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(2 : ui32) : !llvm.i32 + //CHECK-NEXT:%[[DADDR2:.*]] = llvm.getelementptr %[[DADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT: llvm.store %[[D2]], %[[DADDR2]] : !llvm.ptr + //CHECK-NEXT:{{.*}} = llvm.extractvalue {{.*}}[3 : index] : !llvm.struct<(vec<2 x half>, vec<2 x half>, vec<2 x half>, vec<2 x half>)> + 
//CHECK-NEXT:%[[D3:.*]] = llvm.bitcast {{.*}} : !llvm.vec<2 x half> to !llvm.i32 + //CHECK-NEXT:{{.*}} = llvm.mlir.constant(3 : ui32) : !llvm.i32 + //CHECK-NEXT:%[[DADDR3:.*]] = llvm.getelementptr %[[DADDR]][{{.*}}] : (!llvm.ptr, !llvm.i32) -> !llvm.ptr + //CHECK-NEXT: llvm.store %[[D3]], %[[DADDR3]] : !llvm.ptr + //CHECK-NEXT:{{.*}} llvm.return + + return + } +}
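For reference, a minimal sketch of how the three ops are intended to compose: the A, B, and C fragments are loaded from workgroup memory into private memrefs, combined by the compute op, and the D fragment is stored back. This is illustrative only; it assumes a "COp" load with a vector<8xf16> destination follows the same form as the "AOp"/"BOp" loads exercised in the tests above.

gpu.module @wmma_gemm_sketch {
  func @wmma_gemm() -> () {
    %wg = alloca() {alignment = 32} : memref<32x32xf16, 3>
    %A = alloca() : memref<1xvector<16xf16>, 5>
    %B = alloca() : memref<1xvector<16xf16>, 5>
    %C = alloca() : memref<1xvector<8xf16>, 5>
    %D = alloca() : memref<1xvector<8xf16>, 5>
    %i = constant 0 : i64
    %j = constant 0 : i64
    %c0 = constant 0 : i64
    // Load the three input fragments of the 16x16x16 tile.
    gpu.subgroup_mma_load_matrix %wg[%i, %j], %A[%c0] {operand = "AOp", ldm = 32 : i64} : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>
    gpu.subgroup_mma_load_matrix %wg[%i, %j], %B[%c0] {operand = "BOp", ldm = 32 : i64} : memref<32x32xf16, 3>, memref<1xvector<16xf16>, 5>
    gpu.subgroup_mma_load_matrix %wg[%i, %j], %C[%c0] {operand = "COp", ldm = 32 : i64} : memref<32x32xf16, 3>, memref<1xvector<8xf16>, 5>
    // D = A * B + C on the loaded fragments.
    gpu.subgroup_mma_compute %A[%c0], %B[%c0], %C[%c0], %D[%c0] : memref<1xvector<16xf16>, 5>, memref<1xvector<16xf16>, 5>, memref<1xvector<8xf16>, 5>, memref<1xvector<8xf16>, 5>
    // Store the result fragment back to workgroup memory.
    gpu.subgroup_mma_store_matrix %D[%c0], %wg[%i, %j] {ldm = 32 : i64} : memref<1xvector<8xf16>, 5>, memref<32x32xf16, 3>
    return
  }
}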