diff --git a/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/CommonTypes.h @@ -0,0 +1,79 @@ +//===----- CommonTypes.h - Contains LLVM Types common to all Lowerings. ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the common LLVM types that are used by the lowerings of +// GPU MMA Ops to NVVM ops. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_COMMONTYPES_H +#define MLIR_LIB_CONVERSION_GPUTONVVM_COMMONTYPES_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "llvm/IR/DerivedTypes.h" + +namespace mlir { + +/// Contains all the common LLVM types which are used across the lowerings of +/// GPU subgroup ops to NVVM dialect. 
+struct CommonLLVMTypes { +public: + explicit CommonLLVMTypes(MLIRContext *context) { + numHalfsInOpFrags.resize(4); + numHalfsInOpFrags[A] = 8; + numHalfsInOpFrags[B] = 8; + numHalfsInOpFrags[C] = 4; + numHalfsInOpFrags[D] = 4; + int8Type = IntegerType::get(context, 8); + int64Type = IntegerType::get(context, 64); + int32Type = IntegerType::get(context, 32); + int32PtrTy = LLVM::LLVMPointerType::get(int32Type); + f16Ty = FloatType::getF16(context); + f32Ty = FloatType::getF32(context); + f16PtrTy = LLVM::LLVMPointerType::get(f16Ty); + f16x8Ty = VectorType::get(8, f16Ty); + f16x16Ty = VectorType::get(16, f16Ty); + f16x2Ty = VectorType::get(2, f16Ty); + fragArrayABTy = LLVM::LLVMStructType::getLiteral( + context, SmallVector(8, f16x2Ty)); + fragArrayABPtrTy = LLVM::LLVMPointerType::get(fragArrayABTy); + fragArrayCDTy = LLVM::LLVMStructType::getLiteral( + context, SmallVector(4, f16x2Ty)); + fragArrayCDF32Ty = + LLVM::LLVMStructType::getLiteral(context, SmallVector(8, f32Ty)); + } + + Type int8Type; + Type int32Type; + Type int64Type; + Type int32PtrTy; + Type f16Ty; + Type f32Ty; + Type f16PtrTy; + Type f16x2Ty; + Type f16x8Ty; + Type f16x16Ty; + /// Type for the fragment of A and B operands that a single thread holds for + /// fp16 data type. + Type fragArrayABTy; + /// Type for a pointer to the fragment of `AB` operands that a single thread + /// holds for fp16 data type. + Type fragArrayABPtrTy; + /// Type for the fragment of C and D operands that a single thread holds for + /// fp16 data type. + Type fragArrayCDTy; + /// Type for the fragment of C and D operands that a single thread holds for + /// fp32 data type. 
+ Type fragArrayCDF32Ty; + SmallVector numHalfsInOpFrags; + enum OperandMap { A, B, C, D }; +}; + +} // namespace mlir +#endif diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -28,6 +28,8 @@ #include "../GPUCommon/IndexIntrinsicsOpLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" #include "../PassDetail.h" +#include "WmmaLoadStoreToNvvmLowering.h" +#include "WmmaMmaOptoNvvmLowering.h" using namespace mlir; @@ -127,6 +129,37 @@ return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); }); + // Lowering for MMAMatrixType. + converter.addConversion([&](gpu::MMAMatrixType type) -> Type { + // The number of items in structToReturn is dependent on the dataType + // and the MMA operand that this operation is associated with. + llvm::DenseMap numElemsPerThreadF16, + numElemsPerThreadF32; + numElemsPerThreadF16["AOp"] = 8; + numElemsPerThreadF16["BOp"] = 8; + numElemsPerThreadF16["COp"] = 4; + numElemsPerThreadF16["DOp"] = 4; + numElemsPerThreadF32["AOp"] = 8; + numElemsPerThreadF32["BOp"] = 8; + numElemsPerThreadF32["COp"] = 8; + numElemsPerThreadF32["DOp"] = 8; + Type structToReturn; + if (type.getElementType().isF16()) { + unsigned vecSize = 2 /*number of f16's in 32-bit*/; + Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext())); + unsigned size = numElemsPerThreadF16[type.getOperand()]; + SmallVector elements(size, vec); + structToReturn = + LLVM::LLVMStructType::getLiteral(&getContext(), elements); + } else if (type.getElementType().isF32()) { + unsigned size = numElemsPerThreadF32[type.getOperand()]; + SmallVector elements(size, FloatType::getF32(&getContext())); + structToReturn = + LLVM::LLVMStructType::getLiteral(&getContext(), elements); + } + return structToReturn; + }); + RewritePatternSet patterns(m.getContext()); 
RewritePatternSet llvmPatterns(m.getContext()); @@ -182,6 +215,9 @@ Identifier::get(NVVM::NVVMDialect::getKernelFuncAttrName(), &converter.getContext())); + patterns.insert(converter); + patterns.insert(converter); + patterns.insert(converter); patterns.add>(converter, "__nv_fabsf", "__nv_fabs"); patterns.add>(converter, "__nv_atanf", diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaLoadStoreToNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadStoreToNvvmLowering.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaLoadStoreToNvvmLowering.h @@ -0,0 +1,323 @@ +//===- WmmaLoadStoreOpToNVVMLowering.h - LD/ST to NVVM lowering -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains patterns to lower the GPU subgroup MMA_loadOp and +// MMA_storeop to NVVM Dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_WMMALOADSTORETONVVMLOWERING_H +#define MLIR_LIB_CONVERSION_GPUTONVVM_WMMALOADSTORETONVVMLOWERING_H + +#include "CommonTypes.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" + +namespace mlir { + +static inline LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter) { + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) + return rewriter.notifyMatchFailure( + op, "Cannot convert if operands aren't of LLVM type."); + + return success(); +} + +/// This class implements the conversion of GPU MMA loadOp to wmma.load op +/// in the NVVM dialect. 
The conversion not only emits the NVVM op but also +/// emits the address computation (linearized offset, GEP and bitcast) that the +/// NVVM op takes as its pointer operand. +struct WmmaLoadOpToNVVMLowering + : public ConvertOpToLLVMPattern { +public: + MLIRContext *context = &this->getTypeConverter()->getContext(); + + explicit WmmaLoadOpToNVVMLowering(LLVMTypeConverter &typeConverter) + : ConvertOpToLLVMPattern(typeConverter), + llvmTypes(context) {} + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaLoadMatrixOp subgroupMmaLoadMatrixOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Operation *op = subgroupMmaLoadMatrixOp.getOperation(); + if (failed(areAllLLVMTypes(op, operands, rewriter))) + return failure(); + + unsigned indexTypeBitwidth = this->getTypeConverter()->getIndexTypeBitwidth(); + + // The corresponding intrinsic expects leadDimension to be a 32-bit + // integer, so all the calculations of linearizing the load address + // must also follow this restriction. + if (indexTypeBitwidth != 32) + return rewriter.notifyMatchFailure( + op, "Expected indices to the meref to be 32-bit wide."); + + // Source memref of the original op. + MemRefType srcMemrefType = + subgroupMmaLoadMatrixOp.srcMemref().getType().cast(); + Location loc = op->getLoc(); + + auto beginInx = subgroupMmaLoadMatrixOp.indices().getBeginOperandIndex(); + auto leadDimension = subgroupMmaLoadMatrixOp.leadDimensionAttr(); + // auto operand = subgroupMmaLoadMatrixOp.operandAttr(); + + // Emit information for the memref operands. + auto promotedSrcOp = this->getTypeConverter()->promoteOperands( + loc, op->getOperand(0), operands[0], rewriter); + + // Emit ops which compute the load offset using `srcOffsetI`, + // `srcOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr + + // ((leadDimension * srcOffsetI) + srcOffsetJ)). The memrefs here are + // assumed to be normalized and hence the simple conversion works. 
+ Value srcOffsetIVal = subgroupMmaLoadMatrixOp->getOpOperand(beginInx).get(); + Value srcOffsetJVal = + subgroupMmaLoadMatrixOp->getOpOperand(beginInx + 1).get(); + Value leadingDim32 = rewriter.create( + loc, llvmTypes.int32Type, leadDimension); + Value numElemsLeadDim = rewriter.create( + loc, llvmTypes.int32Type, leadingDim32, srcOffsetIVal); + Value loadOffset = rewriter.create( + loc, llvmTypes.int32Type, numElemsLeadDim, srcOffsetJVal); + + Value promotedSrcOpToUse; + promotedSrcOpToUse = promotedSrcOp[2]; + Value actualOffset = rewriter.create( + loc, llvmTypes.int32Type, loadOffset, promotedSrcOpToUse); + Value loadAddress = rewriter.create( + loc, + // GEP result type: same pointer type as the base, whatever its element. + promotedSrcOp[1].getType(), + promotedSrcOp[1], ArrayRef{actualOffset}); + + // Bitcast the address computed into the source memref, so that + // values can be loaded in chunks of 32-bits and semantics match with the + // intrinsic exposed by NVPTX backend. + Value loadAddressCasted = rewriter.create( + loc, + LLVM::LLVMPointerType::get(llvmTypes.int32Type, + srcMemrefType.getMemorySpaceAsInt()), + loadAddress); + + // Get the shape of the MMAMatrix type being returned. The shape will + // choose which intrinsic this op will be lowered to. + gpu::MMAMatrixType retType = + subgroupMmaLoadMatrixOp.res().getType().cast(); + ArrayRef retTypeShape = retType.getShape(); + + Type resType; + // StringRef operandStr = operand.cast().getValue(); + StringRef operandStr = retType.getOperand(); + if (operandStr.equals("AOp") || operandStr.equals("BOp")) { + resType = llvmTypes.fragArrayABTy; + } else { + if (srcMemrefType.getElementType().isF16()) + resType = llvmTypes.fragArrayCDTy; + else if (srcMemrefType.getElementType().isF32()) + resType = llvmTypes.fragArrayCDF32Ty; + else + return failure(); + } + + // Create nvvm.mma_load op according to the operand types. 
+ SmallVector loadOpOperands({loadAddressCasted, leadingDim32}); + if (operandStr.equals("AOp")) { + if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { + NVVM::WMMALoadAM16N16K16Op wmmaLoadAOp = + rewriter.create(loc, resType, + loadOpOperands); + rewriter.replaceOp(op, wmmaLoadAOp.getResult()); + } else { + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + } + } else if (operandStr.equals("BOp")) { + if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { + NVVM::WMMALoadBM16N16K16Op wmmaLoadBOp = + rewriter.create(loc, resType, + loadOpOperands); + rewriter.replaceOp(op, wmmaLoadBOp.getResult()); + } else { + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + } + } else { + if (retTypeShape[0] == 16 && retTypeShape[1] == 16) { + if (srcMemrefType.getElementType().isF16()) { + NVVM::WMMALoadCF16M16N16K16Op wmmaLoadCOp = + rewriter.create(loc, resType, + loadOpOperands); + rewriter.replaceOp(op, wmmaLoadCOp.getResult()); + } else if (srcMemrefType.getElementType().isF32()) { + NVVM::WMMALoadCF32M16N16K16Op wmmaLoadCOp = + rewriter.create(loc, resType, + loadOpOperands); + rewriter.replaceOp(op, wmmaLoadCOp.getResult()); + } + } else { + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + } + } + return success(); + } + +private: + /// Contains definitions of all the LLVM types which are used for lowering + /// this GPU subgroupMmaLoadMatrixOp. + CommonLLVMTypes llvmTypes; +}; + +/// This class implements the conversion of GPU MMA storeOp to wmma.store op +/// in the NVVM dialect. The conversion not only emits the NVVM op but also +/// emits code that is necessary to unpack the data in the source and convert +/// the data in the format that is needed by the NVVM op. 
+struct WmmaStoreOpToNVVMLowering + : public ConvertOpToLLVMPattern { +public: + MLIRContext *context = &this->getTypeConverter()->getContext(); + + explicit WmmaStoreOpToNVVMLowering(LLVMTypeConverter &typeConverter) + : ConvertOpToLLVMPattern(typeConverter), + llvmTypes(context) {} + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaStoreMatrixOp subgroupMmaStoreMatrixOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Operation *op = subgroupMmaStoreMatrixOp.getOperation(); + if (failed(areAllLLVMTypes(op, operands, rewriter))) + return failure(); + + unsigned indexTypeBitwidth = this->getTypeConverter()->getIndexTypeBitwidth(); + // The corresponding intrinsic expects leadDimension to be a 32-bit + // integer, so all the calculations of linearizing the store address + // must also follow this restriction. + if (indexTypeBitwidth != 32) + return rewriter.notifyMatchFailure( + op, "Expected indices to the meref to be 32-bit wide."); + + Location loc = op->getLoc(); + + // Destination memref of the original op. + MemRefType dstMemrefType = + subgroupMmaStoreMatrixOp.dstMemref().getType().cast(); + + auto promotedDstOp = this->getTypeConverter()->promoteOperands( + loc, op->getOperand(1), operands[1], rewriter); + + auto leadDimension = subgroupMmaStoreMatrixOp.leadDimensionAttr(); + unsigned beginInx = + subgroupMmaStoreMatrixOp.indices().getBeginOperandIndex(); + + // Emit ops which compute the store offset using `dstOffsetI`, + // `dstOffsetJ`. The actualOffset is (memrefOffset + (alignedPtr + + // ((leadDimension * dstOffsetI) + dstOffsetJ)). 
+ Value dstOffsetIVal = subgroupMmaStoreMatrixOp.getOperand(beginInx); + Value dstOffsetJVal = subgroupMmaStoreMatrixOp.getOperand(beginInx + 1); + Value leadingDim32 = rewriter.create( + loc, llvmTypes.int32Type, leadDimension); + Value numElemsLeadDim = rewriter.create( + loc, llvmTypes.int32Type, leadingDim32, dstOffsetIVal); + Value storeOffset = rewriter.create( + loc, llvmTypes.int32Type, numElemsLeadDim, dstOffsetJVal); + + Value promotedDstOpToUse; + promotedDstOpToUse = promotedDstOp[2]; + Value actualOffset = rewriter.create( + loc, llvmTypes.int32Type, storeOffset, promotedDstOpToUse); + Value storeAddress = rewriter.create( + loc, + // GEP result type: same pointer type as the base, whatever its element. + promotedDstOp[1].getType(), + promotedDstOp[1], ArrayRef{actualOffset}); + + // Bitcast the address computed into the destination memref, so that + // values can be stored in chunks of 32-bits and semantics match with the + // intrinsic exposed by NVPTX backend. + Value storeAddressCasted = rewriter.create( + loc, + LLVM::LLVMPointerType::get(llvmTypes.int32Type, + dstMemrefType.getMemorySpaceAsInt()), + storeAddress); + + SmallVector storeOpOperands; + storeOpOperands.push_back(storeAddressCasted); + + // Get the shape of the MMAMatrix type being stored. The shape will + // choose which intrinsic this op will be lowered to. + gpu::MMAMatrixType srcType = + subgroupMmaStoreMatrixOp.src().getType().cast(); + ArrayRef srcTypeShape = srcType.getShape(); + + // Unpack the results from the source. + if (subgroupMmaStoreMatrixOp.src() + .getType() + .cast() + .getElementType() == llvmTypes.f16Ty) { + for (unsigned i = 0, e = llvmTypes.numHalfsInOpFrags[llvmTypes.D]; i < e; + ++i) { + Value toUse = rewriter.create( + loc, llvmTypes.f16x2Ty, operands[0], rewriter.getI32ArrayAttr(i)); + storeOpOperands.push_back(toUse); + } + storeOpOperands.push_back(leadingDim32); + + // Create nvvm.mma_store op. 
+ if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) + rewriter.create(loc, storeOpOperands); + else + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + + rewriter.eraseOp(op); + return success(); + } + if (subgroupMmaStoreMatrixOp.src() + .getType() + .cast() + .getElementType() == llvmTypes.f32Ty) { + for (unsigned i = 0, e = 8; i < e; ++i) { + Value toUse = rewriter.create( + loc, llvmTypes.f32Ty, operands[0], rewriter.getI32ArrayAttr(i)); + storeOpOperands.push_back(toUse); + } + storeOpOperands.push_back(leadingDim32); + + // Create nvvm.mma_store op. + if (srcTypeShape[0] == 16 && srcTypeShape[1] == 16) + rewriter.create(loc, storeOpOperands); + else + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + + rewriter.eraseOp(op); + return success(); + } + + return failure(); + } + +private: + /// Definitions of all the LLVM types which are used for lowering this GPU + /// SubgroupMmaStoreMatrixOp. + CommonLLVMTypes llvmTypes; +}; +} // namespace mlir + +#endif diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaMmaOptoNvvmLowering.h @@ -0,0 +1,146 @@ +//===--- WmmaMmaOpToNVVMLowering.h - MmaOp to NVVM Op lowering -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains patterns to lower the GPU subgroup MMA_computeOp to the +// NVVM Dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_CONVERSION_GPUTONVVM_WMMAMMAOPTONVVMLOWERING_H +#define MLIR_LIB_CONVERSION_GPUTONVVM_WMMAMMAOPTONVVMLOWERING_H + +#include "CommonTypes.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" + +namespace mlir { +/// This class implements the conversion of GPU MMA computeOp to wmma.mma op +/// in the NVVM dialect. +struct WmmaMmaOpToNVVMLowering + : public ConvertOpToLLVMPattern { +public: + MLIRContext *context = &this->getTypeConverter()->getContext(); + + explicit WmmaMmaOpToNVVMLowering(LLVMTypeConverter &typeConverter) + : ConvertOpToLLVMPattern(typeConverter), + llvmTypes(context) {} + + static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter) { + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) + return rewriter.notifyMatchFailure( + op, "Cannot convert if operands aren't of LLVM type."); + + return success(); + } + + LogicalResult + matchAndRewrite(gpu::SubgroupMmaComputeOp subgroupMmaComputeOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + Operation *op = subgroupMmaComputeOp.getOperation(); + if (failed(areAllLLVMTypes(op, operands, rewriter))) + return failure(); + + Location loc = op->getLoc(); + + // The wmma.mma intrinsic in llvm requires the operands as individual + // values. Individual elements of the fragments need to be extracted and + // then passed on to the intrinsic call. Emit llvm ops to extract individual + // values from the lowered fragment structs. 
+ SmallVector unpackedOps; + + auto unpackOp = [&](CommonLLVMTypes::OperandMap op, Value operand, + unsigned numElems, Type elemType) { + for (unsigned i = 0; i < numElems; ++i) { + Value toUse = rewriter.create( + loc, elemType, operand, rewriter.getI32ArrayAttr(i)); + unpackedOps.push_back(toUse); + } + }; + + // Get the shapes of the MMAMatrix type being used. The shapes will + // choose which intrinsic this op will be lowered to. + gpu::MMAMatrixType aType = + subgroupMmaComputeOp.opA().getType().cast(); + ArrayRef aTypeShape = aType.getShape(); + gpu::MMAMatrixType bType = + subgroupMmaComputeOp.opB().getType().cast(); + ArrayRef bTypeShape = bType.getShape(); + gpu::MMAMatrixType cType = + subgroupMmaComputeOp.opC().getType().cast(); + ArrayRef cTypeShape = cType.getShape(); + + if (subgroupMmaComputeOp.opC() + .getType() + .cast() + .getElementType() == llvmTypes.f16Ty) { + unpackOp(llvmTypes.A, operands[0], + llvmTypes.numHalfsInOpFrags[llvmTypes.A], llvmTypes.f16x2Ty); + unpackOp(llvmTypes.B, operands[1], + llvmTypes.numHalfsInOpFrags[llvmTypes.B], llvmTypes.f16x2Ty); + unpackOp(llvmTypes.C, operands[2], + llvmTypes.numHalfsInOpFrags[llvmTypes.C], llvmTypes.f16x2Ty); + + if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && + bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { + // Create nvvm.wmma.mma op. + NVVM::WMMAMmaF16F16M16N16K16Op wmmaMmaOp = + rewriter.create( + loc, llvmTypes.fragArrayCDTy, unpackedOps); + + rewriter.replaceOp(op, wmmaMmaOp.getResult()); + return success(); + } else { + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. 
+ return failure(); + } + } + + if (subgroupMmaComputeOp.opC() + .getType() + .cast() + .getElementType() == llvmTypes.f32Ty) { + unpackOp(llvmTypes.A, operands[0], + llvmTypes.numHalfsInOpFrags[llvmTypes.A], llvmTypes.f16x2Ty); + unpackOp(llvmTypes.B, operands[1], + llvmTypes.numHalfsInOpFrags[llvmTypes.B], llvmTypes.f16x2Ty); + unpackOp(llvmTypes.C, operands[2], 8, llvmTypes.f32Ty); + + if (aTypeShape[0] == 16 && aTypeShape[1] == 16 && bTypeShape[0] == 16 && + bTypeShape[1] == 16 && cTypeShape[0] == 16 && cTypeShape[1] == 16) { + // Create nvvm.wmma.mma op. + NVVM::WMMAMmaF32F32M16N16K16Op wmmaMmaOp = + rewriter.create( + loc, llvmTypes.fragArrayCDF32Ty, unpackedOps); + + rewriter.replaceOp(op, wmmaMmaOp.getResult()); + return success(); + } else { + // Only M16N16K16 version implemented. Add cases as more versions are + // implemented. + return failure(); + } + } + + return failure(); + } + +private: + /// Definitions of all the LLVM types which are used for lowering + /// this GPU subgroupMmaComputeOp. 
+ CommonLLVMTypes llvmTypes; +}; +} // namespace mlir + +#endif diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt --convert-gpu-to-nvvm="index-bitwidth=32" --split-input-file %s | FileCheck %s + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_load_op() -> + // CHECK-SAME: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> { + //func @gpu_wmma_load_op() -> (!gpu.mmafragment<8, vector<2xf16>>) { + func @gpu_wmma_load_op() -> (!gpu.mma_matrix<16x16xf16, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + %j = constant 16 : index + //%0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {operand = "AOp", leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mmafragment<8, vector<2xf16>> + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> + //CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 + //CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32 + //CHECK-NEXT: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr + //CHECK-NEXT: %[[FRAG:.*]] = 
nvvm.wmma.m16n16k16.load.a.f16.row.stride %[[CADDRESS]], %[[LDM]] : (!llvm.ptr, i32) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: llvm.return %[[FRAG]] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //return %0 : !gpu.mmafragment<8, vector<2xf16>> + return %0 : !gpu.mma_matrix<16x16xf16, "AOp"> + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_store_op + // CHECK-SAME: (%[[D:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) { + func @gpu_wmma_store_op(%arg0 : !gpu.mma_matrix<16x16xf16, "DOp">) -> () { + %sg = memref.alloca(){alignment = 32} : memref<32x32xf16, 3> + %i = constant 16 : index + %j = constant 16 : index + gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index} : !gpu.mma_matrix<16x16xf16, "DOp">, memref<32x32xf16, 3> + //CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i32 + //CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK-NEXT: %[[OFFSET:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + //CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + //CHECK-NEXT: %[[LI:.*]] = llvm.mul %[[LDM]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i32 + //CHECK-NEXT: %[[LIJO:.*]] = llvm.add %[[LIJ]], %[[OFFSET]] : i32 + //CHECK-NEXT: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJO]]] : (!llvm.ptr, i32) -> !llvm.ptr + //CHECK-NEXT: %[[CADDRESS:.*]] = llvm.bitcast %[[ADDRESS]] : !llvm.ptr to !llvm.ptr + //CHECK-NEXT: %[[EL1:.*]] = llvm.extractvalue %[[D]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL2:.*]] = llvm.extractvalue %[[D]][1 : i32] : 
!llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL3:.*]] = llvm.extractvalue %[[D]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[EL4:.*]] = llvm.extractvalue %[[D]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: nvvm.wmma.m16n16k16.store.d.f16.row.stride %[[CADDRESS]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]], %[[LDM]] : !llvm.ptr, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, i32 + //CHECK-NEXT: llvm.return + return + } +} + +// ----- + +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_mma_op + // CHECK-SAME: (%[[A:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[B:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, %[[C:.*]]: !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>) { + func @gpu_wmma_mma_op(%A : !gpu.mma_matrix<16x16xf16, "AOp">, %B : !gpu.mma_matrix<16x16xf16, "BOp">, %C : !gpu.mma_matrix<16x16xf16, "COp">) -> () { + %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "DOp"> + //CHECK: %[[A1:.*]] = llvm.extractvalue %[[A]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A2:.*]] = llvm.extractvalue %[[A]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A3:.*]] = llvm.extractvalue %[[A]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, 
vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A4:.*]] = llvm.extractvalue %[[A]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A5:.*]] = llvm.extractvalue %[[A]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A6:.*]] = llvm.extractvalue %[[A]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A7:.*]] = llvm.extractvalue %[[A]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[A8:.*]] = llvm.extractvalue %[[A]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B1:.*]] = llvm.extractvalue %[[B]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B2:.*]] = llvm.extractvalue %[[B]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B3:.*]] = llvm.extractvalue %[[B]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B4:.*]] = llvm.extractvalue %[[B]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B5:.*]] = llvm.extractvalue %[[B]][4 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, 
vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B6:.*]] = llvm.extractvalue %[[B]][5 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B7:.*]] = llvm.extractvalue %[[B]][6 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[B8:.*]] = llvm.extractvalue %[[B]][7 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C1:.*]] = llvm.extractvalue %[[C]][0 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C2:.*]] = llvm.extractvalue %[[C]][1 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C3:.*]] = llvm.extractvalue %[[C]][2 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %[[C4:.*]] = llvm.extractvalue %[[C]][3 : i32] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: %{{.*}} = nvvm.wmma.m16n16k16.mma.row.row.f16.f16 %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]], %[[A7]], %[[A8]], %[[B1]], %[[B2]], %[[B3]], %[[B4]], %[[B5]], %[[B6]], %[[B7]], %[[B8]], %[[C1]], %[[C2]], %[[C3]], %[[C4]] : vector<2xf16> -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> + //CHECK-NEXT: llvm.return + return + } +}