diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
rename from mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
rename to mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.h
+++ b/mlir/include/mlir/Dialect/NVGPU/Utils/MMAUtils.h
@@ -1,4 +1,4 @@
-//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===//
+//===-- MMAUtils.h - MLIR NVGPU dialect utilities for MMA operations-------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,24 +6,30 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This file provides utilities to assist in the lowering of Vector operations
-// to GPU dialect MMA operations.
+// This file provides utilities to assist in the lowering of other dialects
+// (e.g. Vector) to `nvgpu.mma.*` dialect operations.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
-#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
-
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#ifndef MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
+#define MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
+
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Types.h"
 
 namespace mlir {
+namespace vector {
+enum class IteratorType : uint32_t;
+class ContractionOp;
+} // namespace vector
+
+namespace NVVM {
+enum class MMALayout : uint32_t;
+} // namespace NVVM
+
 namespace nvgpu {
+/// Represents the role of an operand in an MMA instruction:
+/// `result := matmul(A, B) + C`
 enum class MatMulOperandRole : int32_t { A = 0, B, C };
 
 /// Collects information about a warp-level matrix operand represented by a
@@ -33,8 +39,10 @@
   MatMulOperandRole operandRole;
 };
 
-/// Given an op that operates on a VectorType representing a warp-level matrix
-/// operand, the function returns a struct containing relevant type information.
+/// If `op` is a `vector.transfer_write`, return the `WarpMatrixInfo` for the
+/// vector operand. If op is a `vector.transfer_read`, `vector.contraction`, or
+/// `arith.constant`, return the `WarpMatrixInfo` corresponding to the result.
+/// Otherwise, return failure.
 FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op);
 
 /// Returns the number of bits in a single tile row. It is either 128, 256, or
@@ -67,6 +75,8 @@
 getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
                                   const WarpMatrixInfo &fragmentType);
 
+/// Encapsulates the parameters needed to lower a `nvgpu.ldmatrix` operation to
+/// `nvvm.ldmatrix`.
 struct LdMatrixParams {
   VectorType fragmentType;
   bool isAccum;
@@ -75,6 +85,8 @@
   NVVM::MMALayout targetLayout;
 };
 
+/// Given `type` that contains info for a warp-matrix operand and whether or not
+/// the load is a transposed load, return the LdMatrixParams.
 FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
                                             bool transpose);
 
 /// Returns an AffineMap which maps a single dimension representing the laneId
@@ -84,8 +96,10 @@
 getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
                                const LdMatrixParams &params);
 
-// Transform contract into (m, k)x(n, k)x(m, n) form so that it can be converted
-// to MMA matmul.
+/// Transform `vector.contract` into (m,k)x(n,k)x(m,n) form so that it can be
+/// converted to `nvgpu.mma.sync`. This specific form is meant to indicate that
+/// the vector operands are organized such that the reduction dimension is
+/// contiguous.
 struct PrepareContractToGPUMMASync
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
 };
@@ -97,4 +111,4 @@
 } // namespace nvgpu
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
+#endif // MLIR_DIALECT_NVGPU_UTILS_MMAUTILS_H
diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
--- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_conversion_library(MLIRVectorToGPU
   VectorToGPU.cpp
-  NvGpuSupport.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU
@@ -14,6 +13,7 @@
   MLIRLLVMDialect
   MLIRMemRefDialect
   MLIRNVGPUDialect
+  MLIRNVGPUUtils
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorUtils
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -10,16 +10,17 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <type_traits>
-
-#include "NvGpuSupport.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
+#include <type_traits>
+
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
diff --git a/mlir/lib/Dialect/NVGPU/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
--- a/mlir/lib/Dialect/NVGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/NVGPU/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
+add_subdirectory(Utils)
 add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/NVGPU/Utils/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_mlir_dialect_library(MLIRNVGPUUtils
+  MMAUtils.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arithmetic/Utils
+
+  LINK_LIBS PUBLIC
+  MLIRAffineDialect
+  MLIRLLVMDialect
+  MLIRNVGPUDialect
+  MLIRNVVMDialect
+  MLIRVectorDialect
+  MLIRIR
+  )
diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
rename from mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
rename to mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
--- a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -1,37 +1,32 @@
-//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===//
+//===- MMAUtils.cpp - MLIR NVGPU dialect utils for MMA operations----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-//
-// This file provides utilities to assist in the lowering of Vector operations
-// to NvGPU dialect MMA operations.
-//
-//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
 
-#include "NvGpuSupport.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 
-namespace mlir {
-namespace nvgpu {
-namespace {
+using namespace mlir;
+using namespace mlir::nvgpu;
 
 /// There are always 4 threads per [128|256|512] bit row.
-constexpr int64_t kThreadsPerRow = 4;
+static constexpr int64_t kThreadsPerRow = 4;
+static constexpr int64_t kNumRowsPerTile = 8;
 
-constexpr int64_t kNumRowsPerTile = 8;
-
-bool isAccumulatorOrResult(MatMulOperandRole operandType) {
+static bool isAccumulatorOrResult(MatMulOperandRole operandType) {
   return operandType == MatMulOperandRole::C;
 }
 
 /// Returns the number of registers which compose a matrix fragment held by a
 /// single thread.
-int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
+static int64_t inferNumRegistersPerMatrixFragment(const WarpMatrixInfo &type) {
   int64_t lineSize = inferTileWidthInBits(type);
   auto shape = type.vectorType.getShape();
   return (shape[0] / kNumRowsPerTile) *
@@ -41,17 +36,16 @@
 
 /// Returns the number of 8 x [128|256|512] bit tiles that compose the given
 /// operand shape.
-std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
-                                    Type elementType, int64_t lineSizeBits) {
+static std::array<int64_t, 2> getTileShape(ArrayRef<int64_t> operandShape,
+                                           Type elementType,
+                                           int64_t lineSizeBits) {
   // For each 8x128bit square, a thread is responsible for one 32bit register.
   return {operandShape[0] / kNumRowsPerTile,
           (operandShape[1] * elementType.getIntOrFloatBitWidth()) /
               lineSizeBits};
 }
 
-} // namespace
-
-FailureOr<WarpMatrixInfo> getWarpMatrixInfo(Operation *op) {
+FailureOr<WarpMatrixInfo> nvgpu::getWarpMatrixInfo(Operation *op) {
   WarpMatrixInfo info;
 
   // Determine the vector type.
@@ -84,7 +78,7 @@
   return info;
 }
 
-int64_t inferTileWidthInBits(const WarpMatrixInfo &type) {
+int64_t nvgpu::inferTileWidthInBits(const WarpMatrixInfo &type) {
   bool isAcc = isAccumulatorOrResult(type.operandRole);
   Type elType = type.vectorType.getElementType();
   if (isAcc && elType.getIntOrFloatBitWidth() == 32) {
@@ -97,7 +91,7 @@
 }
 
 FailureOr<FragmentElementInfo>
-getMmaSyncRegisterType(const WarpMatrixInfo &type) {
+nvgpu::getMmaSyncRegisterType(const WarpMatrixInfo &type) {
   MLIRContext *ctx = type.vectorType.getContext();
   const bool isAccum = isAccumulatorOrResult(type.operandRole);
 
@@ -170,8 +164,8 @@
 }
 
 FailureOr<AffineMap>
-getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
-                                  const WarpMatrixInfo &fragmentType) {
+nvgpu::getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                         const WarpMatrixInfo &fragmentType) {
   Type elementType = fragmentType.vectorType.getElementType();
   ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
   FailureOr<FragmentElementInfo> regInfo =
@@ -205,8 +199,8 @@
           (logicalValueIdDim % elementsPerRegister)});
 }
 
-FailureOr<LdMatrixParams> getLdMatrixParams(const WarpMatrixInfo &type,
-                                            bool transpose) {
+FailureOr<LdMatrixParams>
+nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
   LdMatrixParams params;
   Type elType = type.vectorType.getElementType();
   params.fragmentType = type.vectorType;
@@ -235,8 +229,8 @@
 }
 
 FailureOr<AffineMap>
-getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
-                               const LdMatrixParams &params) {
+nvgpu::getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                                      const LdMatrixParams &params) {
   // One thread per 128b row.
   const int64_t kNumThreadsPerTile = kNumRowsPerTile;
   const int bitsPerElement = static_cast<int>(
@@ -273,9 +267,8 @@
   return failure();
 }
 
-LogicalResult
-PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op,
-                                             PatternRewriter &rewriter) const {
+LogicalResult nvgpu::PrepareContractToGPUMMASync::matchAndRewrite(
+    vector::ContractionOp op, PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value lhs = op.getLhs();
   Value rhs = op.getRhs();
@@ -330,6 +323,3 @@
                          op.getIteratorTypes());
   return success();
 }
-
-} // namespace nvgpu
-} // namespace mlir
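
Note: with this move, code outside the VectorToGPU conversion can reach these helpers through the public header and the new MLIRNVGPUUtils library (added to LINK_LIBS above). The following is a minimal, hypothetical usage sketch and is not part of the patch; the helper name buildOperandCoordMap and the surrounding boilerplate are assumptions, while the nvgpu::getWarpMatrixInfo and nvgpu::getLaneIdAndValueIdToOperandCoord signatures come from the header in this diff.

// Hypothetical consumer of the relocated utilities (not part of this patch).
// The enclosing target must link MLIRNVGPUUtils, as MLIRVectorToGPU now does.
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

/// Builds the (laneId, logicalValueId) -> operand-coordinate map for `op`,
/// assuming `op` is one of the ops accepted by nvgpu::getWarpMatrixInfo
/// (vector.transfer_read/write, vector.contraction, or arith.constant).
static FailureOr<AffineMap> buildOperandCoordMap(Operation *op,
                                                 OpBuilder &builder) {
  // Query the warp-level fragment info for the op's vector operand/result.
  FailureOr<nvgpu::WarpMatrixInfo> warpInfo = nvgpu::getWarpMatrixInfo(op);
  if (failed(warpInfo))
    return failure();

  // Map (laneId, logicalValueId) to the element coordinate within the tile.
  return nvgpu::getLaneIdAndValueIdToOperandCoord(op->getLoc(), builder,
                                                  *warpInfo);
}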