diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -9,6 +9,7 @@ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include namespace mlir { @@ -22,8 +23,16 @@ namespace gpu { class GPUModuleOp; +class MMAMatrixType; } +/// Return the literal LLVM struct type that the MMAMatrixType `type` lowers to +/// for NVVM WMMA ops; its field count and field type depend on the matrix +/// element type and on the operand ("AOp", "BOp" or "COp"). +/// NOTE(review): the type-converter lambda this replaced returned a null Type +/// for element types other than f16/f32; this may hit llvm_unreachable instead. +LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type); + /// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -1068,8 +1068,8 @@ ``` }]; - let arguments = (ins Arg>:$opA, - Arg>:$opB, + let arguments = (ins Arg>:$opA, + Arg>:$opB, Arg>:$opC); let results = (outs GPU_MMAMatrix:$res); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -134,34 +134,8 @@ // Lowering for MMAMatrixType. converter.addConversion([&](gpu::MMAMatrixType type) -> Type { - // The number of items in structToReturn are dependent on the the dataType - // and the MMA operand that this operation is associated with. - llvm::DenseMap numElemsPerThreadF16, - numElemsPerThreadF32; - numElemsPerThreadF16["AOp"] = 8; - numElemsPerThreadF16["BOp"] = 8; - numElemsPerThreadF16["COp"] = 4; - numElemsPerThreadF32["AOp"] = 8; - numElemsPerThreadF32["BOp"] = 8; - numElemsPerThreadF32["COp"] = 8; - Type structToReturn; - if (type.getElementType().isF16()) { - // Number of f16's in 32-bit.
- unsigned vecSize = 2; - Type vec = VectorType::get(vecSize, FloatType::getF16(&getContext())); - unsigned size = numElemsPerThreadF16[type.getOperand()]; - SmallVector elements(size, vec); - structToReturn = - LLVM::LLVMStructType::getLiteral(&getContext(), elements); - } else if (type.getElementType().isF32()) { - unsigned size = numElemsPerThreadF32[type.getOperand()]; - SmallVector elements(size, FloatType::getF32(&getContext())); - structToReturn = - LLVM::LLVMStructType::getLiteral(&getContext(), elements); - } - return structToReturn; + return convertMMAToLLVMType(type); }); - RewritePatternSet patterns(m.getContext()); RewritePatternSet llvmPatterns(m.getContext()); diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" @@ -59,16 +60,6 @@ llvm_unreachable("Unsupported type"); } -/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. -static LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) { - NVVM::MMAFrag frag = convertOperand(type.getOperand()); - NVVM::MMATypes eltType = getElementType(type); - std::pair typeInfo = - inferMMAType(eltType, frag, type.getContext()); - return LLVM::LLVMStructType::getLiteral( - type.getContext(), SmallVector(typeInfo.second, typeInfo.first)); -} - /// This class implements the conversion of GPU MMA loadOp to wmma.load op /// in the NVVM dialect.
The conversion not only emits the NVVM op but also /// emits code that is necessary to store the data in the destination memref @@ -433,6 +424,17 @@ } // anonymous namespace namespace mlir { + +/// Return the LLVMStructType corresponding to the MMAMatrixType `type`. +LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) { + NVVM::MMAFrag frag = convertOperand(type.getOperand()); + NVVM::MMATypes eltType = getElementType(type); + std::pair typeInfo = + inferMMAType(eltType, frag, type.getContext()); + return LLVM::LLVMStructType::getLiteral( + type.getContext(), SmallVector(typeInfo.second, typeInfo.first)); +} + void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.insert