diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -803,8 +803,13 @@ "dialect"; let constructor = "mlir::createConvertVectorToGPUPass()"; let dependentDialects = [ - "memref::MemRefDialect", - "gpu::GPUDialect" + "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect", + "vector::VectorDialect", "nvgpu::NVGPUDialect" + ]; + + let options = [ + Option<"useNvGpu", "use-nvgpu", "bool", /*default=*/"false", + "convert to NvGPU ops instead of GPU dialect ops"> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToGPU/NvGpuSupport.h b/mlir/include/mlir/Conversion/VectorToGPU/NvGpuSupport.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToGPU/NvGpuSupport.h @@ -0,0 +1,92 @@ +//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to GPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H +#define MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace nvgpu { + +/// Helps to calculate the offsets within the tile for any NVVM/PTX MMA operand +/// that has a base tile size of 8 elements x [128|256|512] bits +namespace NvvmMmaOperandBaseTileOperand8x128 { + +/// Returns the number of bits in a single tile row. It is either 128, 256, or +/// 512 bits depending on the data type and whether the operand is an +/// accumulator. +int64_t inferTileWidthInBits(Type elementType, bool isAcc); + +/// Specifies information about the registers which compose a matrix fragment +/// according to the PTX documentation. +struct FragmentElementInfo { + Type registerLLVMType; + int64_t elementsPerRegister; + int64_t registerWidthBits; + int64_t numRegistersPerFragment; +}; + +/// Returns a FragmentElementInfo struct describing the register types for the +/// given matrix fragment type. +FailureOr getRegisterType(gpu::MMAMatrixType type); + +/// Returns an AffineMap which maps a two dimensions representing (laneId, +/// logicalValueId) and returns two results representing offsets within a +/// matrix operand. The offsets point to the values the thread is responsible +/// for (AKA the matrix fragment values) during a warp-collective matrix +/// operation. 
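+/// As an illustration (following the PTX fragment layouts): for a 16-bit
+/// "AOp" fragment, four consecutive lanes share each 8-element row of a base
+/// tile, and every lane owns two contiguous elements per 32-bit register.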
+/// For a visual reference of this LaneId -> (row, col) mapping, please see
+/// NVIDIA's PTX documentation:
+/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma
+FailureOr<AffineMap>
+getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder,
+                                  gpu::MMAMatrixType fragmentType);
+
+struct LdMatrixParams {
+  gpu::MMAMatrixType fragmentType;
+  int64_t numTiles;
+  IteratorType contiguousDimType;
+  NVVM::MMALayout targetLayout;
+};
+
+FailureOr<LdMatrixParams> getLdMatrixParams(gpu::MMAMatrixType fragType,
+                                            bool transpose);
+/// Returns an AffineMap which maps a single dimension representing the laneId
+/// to two results representing offsets within the matrix operand that should
+/// be the pointer locations a thread should pass to the ldmatrix instruction.
+FailureOr<AffineMap>
+getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder,
+                               const LdMatrixParams &params);
+
+} // namespace NvvmMmaOperandBaseTileOperand8x128
+
+// Transform a contract op into (m, k) x (n, k) x (m, n) form so that it can
+// be converted to an MMA matmul.
+struct PrepareContractToGPUMMASync
+    : public OpRewritePattern<vector::ContractionOp> {
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace nvgpu
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_VECTORTOGPU_NVGPUSUPPORT_H
diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
--- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
+++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h
@@ -17,16 +17,25 @@
 class RewritePatternSet;
 
 /// Patterns to transform vector ops into a canonical form to convert to MMA
-/// matrix operations.
-void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns);
+/// matrix operations. If `useNvGpu` is true, the patterns populated will
+/// prepare for conversion to `nvgpu` mma operations rather than the `gpu`
+/// dialect WMMA operations.
+void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
+                                        bool useNvGpu = false);
 
 /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will
 /// convert slice of operations that can be legally converted to MMA operations.
 /// The rest of the vector operations are left untouched.
 void convertVectorToMMAOps(Operation *rootOp);
 
+/// Convert vector ops nested under `rootOp` to vector and GPU operations
+/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice
+/// of operations that can be legally lowered on this path while the rest of
+/// the vector operations are left untouched.
+LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp);
+
 /// Convert from vector to GPU ops.
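+/// When `useNvGpu` is true, the pass additionally runs the
+/// `nvgpu.mma.sync`-compatible conversion (`convertVectorToNVVMCompatibleMMASync`)
+/// on the slice of supported operations.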
-std::unique_ptr createConvertVectorToGPUPass(); +std::unique_ptr createConvertVectorToGPUPass(bool useNvGpu = false); } // namespace mlir diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -51,6 +51,10 @@ class LLVMDialect; } // namespace LLVM +namespace nvgpu { +class NVGPUDialect; +} + namespace NVVM { class NVVMDialect; } // namespace NVVM diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_conversion_library(MLIRVectorToGPU VectorToGPU.cpp + NvGpuSupport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU diff --git a/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/NvGpuSupport.cpp @@ -0,0 +1,285 @@ +//===- NvGpuSupport.cpp - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to NvGPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToGPU/NvGpuSupport.h" +#include "mlir/Dialect/GPU/GPUDialect.h" + +namespace mlir { +namespace nvgpu { +namespace NvvmMmaOperandBaseTileOperand8x128 { + +namespace { + +/// There are always 4 threads per [128|256|512] bit row. +constexpr int64_t kThreadsPerRow = 4; + +constexpr int64_t kNumRowsPerTile = 8; + +/// Returns the number of registers which compose a matrix fragment held by a +/// single thread. +int64_t inferNumRegistersPerMatrixFragment(gpu::MMAMatrixType type) { + int64_t lineSize = inferTileWidthInBits(type.getElementType(), + /*isAcc=*/type.getOperand() == "COp"); + auto shape = type.getShape(); + return (shape[0] / kNumRowsPerTile) * + (shape[1] * type.getElementType().getIntOrFloatBitWidth()) / lineSize; +} + +/// Returns the number of 8 x [128|256|512] bit tiles that compose the given +/// operand shape. +std::array getTileShape(ArrayRef operandShape, + Type elementType, int64_t lineSizeBits) { + // For each 8x128bit square, a thread is responsible for one 32bit register. + return {operandShape[0] / kNumRowsPerTile, + (operandShape[1] * elementType.getIntOrFloatBitWidth()) / + lineSizeBits}; +} + +} // namespace + +int64_t inferTileWidthInBits(Type elementType, bool isAcc) { + if (isAcc && elementType.getIntOrFloatBitWidth() == 32) { + return 256; + } + if (elementType.getIntOrFloatBitWidth() == 64) { + return isAcc ? 
512 : 256; + } + return 128; +} + +FailureOr getRegisterType(gpu::MMAMatrixType type) { + MLIRContext *ctx = type.getContext(); + bool isAccum = type.getOperand() == "COp"; + if (type.getElementType().isF16()) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 acc + Type f64Ty = Float64Type::get(ctx); + if (type.getElementType().isF64() && isAccum) { + return FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 operand + if (type.getElementType().isF64() && !isAccum) { + return FragmentElementInfo{f64Ty, 1, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // int8 operand + if (type.getElementType().isInteger(8)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + // Integer 32bit acc operands + if (type.getElementType().isInteger(32)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // Floating point 32bit acc operands + if (type.getElementType().isF32() && isAccum) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(Float32Type::get(ctx), 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + return failure(); +} + +static AffineMap getRegisterIndexToTileOffsetMap(OpBuilder &base, + Type elementType, + ArrayRef operandShape, + bool isAccumulator, + int64_t elementsPerRegister, + AffineExpr logicalValueId) { + const int64_t lineSize = inferTileWidthInBits(elementType, isAccumulator); + const int64_t elementsPerLine = + lineSize / elementType.getIntOrFloatBitWidth(); + const std::array num8x128bTiles = + getTileShape(operandShape, elementType, lineSize); + AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister); + return AffineMap::get( + 2, 0, + {(registerIdx % num8x128bTiles[0]) * 8, + (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine}, + base.getContext()); +} + +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + gpu::MMAMatrixType fragmentType) { + Type elementType = fragmentType.getElementType(); + ArrayRef operandShape = fragmentType.getShape(); + bool isAccumulator = fragmentType.getOperand() == "COp"; + FailureOr + regInfo = getRegisterType(fragmentType); + if (failed(regInfo)) + return failure(); + + const int64_t elementBitWidth = + fragmentType.getElementType().getIntOrFloatBitWidth(); + const int64_t elementsPerRegister = + regInfo->registerWidthBits / elementBitWidth; + + AffineExpr laneId, logicalValueIdDim; + bindDims(builder.getContext(), laneId, logicalValueIdDim); + + // Determine what register logicalValueId corresponds to. Use that as a + // linear index into the coordinate mapping `index -> (tile row, tile col)`. 
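+  // The final map below composes those tile offsets with the within-tile
+  // layout: groups of `kThreadsPerRow` lanes cover one row of the base tile,
+  // and each lane owns `elementsPerRegister` contiguous elements.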
+ AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap( + builder, elementType, operandShape, isAccumulator, elementsPerRegister, + logicalValueIdDim); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(2, 0, dimExprs, builder.getContext()); + }; + + auto tileRow = registerIndexToTileCoord.getResult(0); + auto tileCol = registerIndexToTileCoord.getResult(1); + return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow), + tileCol + (laneId % kThreadsPerRow) * elementsPerRegister + + (logicalValueIdDim % elementsPerRegister)}); +} + +FailureOr +getLdMatrixParams(gpu::MMAMatrixType fragType, bool transpose) { + LdMatrixParams params; + params.fragmentType = fragType; + if (fragType.getOperand() == "AOp" || fragType.getOperand() == "COp") { + params.targetLayout = NVVM::MMALayout::row; + } else { + params.targetLayout = NVVM::MMALayout::col; + } + ArrayRef shape = fragType.getShape(); + params.contiguousDimType = + transpose ? IteratorType::Parallel : IteratorType::Reduction; + + if (params.targetLayout == NVVM::MMALayout::row) { + params.numTiles = + (shape[0] / kNumRowsPerTile) * + ((shape[1] * fragType.getElementType().getIntOrFloatBitWidth()) / 128); + } else { + params.numTiles = + (shape[1] / kNumRowsPerTile) * + ((shape[0] * fragType.getElementType().getIntOrFloatBitWidth()) / 128); + } + + return params; +} + +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms) { + // One thread per 128b row. + const int64_t kNumThreadsPerTile = kNumRowsPerTile; + const int bitsPerElement = static_cast( + params.fragmentType.getElementType().getIntOrFloatBitWidth()); + const int kElementsPer128b = (128 / bitsPerElement); + ArrayRef operandShape = params.fragmentType.getShape(); + AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(1, 0, dimExprs, builder.getContext()); + }; + + // This case corresponds to row-major A|C or col-major B operands. + if (params.contiguousDimType == IteratorType::Reduction) { + AffineExpr row = d0 % (operandShape[0]); + AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b); + return makeMap({row, col}); + } + + // This case Corresponds to col-major A|C or row-major B operands. The + // operandShape given is already pre-transposed (e.g. 8x16 = KxN). + if (params.contiguousDimType == IteratorType::Parallel) { + const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128; + // Threads are assigned in groups of 8 first across columns, then to + // rows. This is transpose of what `ldmatrix` expects, but when + // `ldmatrix` gets the `.trans` qualifier, final the effect will be to + // transpose just the blocks. + auto groupIdx = d0.floorDiv(kNumThreadsPerTile); + auto tileCol = (groupIdx % num8x128bCols); + auto tileRow = groupIdx.floorDiv(num8x128bCols); + return makeMap({tileCol * kElementsPer128b, + tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)}); + } + return failure(); +} + +} // namespace NvvmMmaOperandBaseTileOperand8x128 + +LogicalResult +PrepareContractToGPUMMASync::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + Location loc = op.getLoc(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + Value res = op.getAcc(); + + // Set up the parallel/reduction structure in right form. 
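+  // The target canonical form is A:(m, k), B:(n, k), C:(m, n); every other
+  // supported combination of indexing maps is rewritten into it below by
+  // transposing and/or swapping the operands.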
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m;
+  AffineExpr n;
+  AffineExpr k;
+  bindDims(rewriter.getContext(), m, n, k);
+  static constexpr std::array<int64_t, 2> perm = {1, 0};
+  auto iteratorTypes = op.getIteratorTypes().getValue();
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (iteratorTypes.size() != 3)
+    return failure();
+  if (!(isParallelIterator(iteratorTypes[0]) &&
+        isParallelIterator(iteratorTypes[1]) &&
+        isReductionIterator(iteratorTypes[2])))
+    return failure();
+
+  // The canonical form is "TNT" = A row-major, B col-major, C row-major.
+  const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}});
+  if (maps == canonicalForm) {
+    return failure();
+  }
+  if (maps == infer({{m, k}, {k, n}, {m, n}})) {
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
+    std::swap(rhs, lhs);
+    rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
+  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
+    std::swap(lhs, rhs);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
+    std::swap(lhs, rhs);
+  } else {
+    return failure();
+  }
+  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+      op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm),
+      op.getIteratorTypes());
+  return success();
+}
+
+} // namespace nvgpu
+} // namespace mlir
\ No newline at end of file
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -12,6 +12,7 @@
 
 #include <type_traits>
 
+#include "mlir/Conversion/VectorToGPU/NvGpuSupport.h"
 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
 
 #include "../PassDetail.h"
@@ -19,6 +20,7 @@
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -27,11 +29,39 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 
+/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
+/// AffineMap representing offsets to apply to the indices, this function fills
+/// `indices` with the original indices plus the offsets. The offsets are
+/// applied by taking into account the permutation map of the transfer op. If
+/// the `offsetMap` has dimension placeholders, those should be provided in
+/// `dimValues`.
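+/// Concretely, each result of `offsetMap` is added to the original index of
+/// the dimension it corresponds to in the permutation map, using
+/// `makeComposedAffineApply`.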
+template +static void getXferIndices(OpBuilder &b, TransferOpType xferOp, + AffineMap offsetMap, ArrayRef dimValues, + SmallVector &indices) { + indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); + Location loc = xferOp.getLoc(); + unsigned offsetsIdx = 0; + for (auto expr : xferOp.getPermutationMap().getResults()) { + if (auto dim = expr.template dyn_cast()) { + Value prevIdx = indices[dim.getPosition()]; + SmallVector dims(dimValues.begin(), dimValues.end()); + dims.push_back(prevIdx); + AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims()); + indices[dim.getPosition()] = makeComposedAffineApply( + b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); + continue; + } + } +} + // Return true if the contract op can be convert to MMA matmul. -static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { +static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, + bool useNvGpu) { if (llvm::size(contract.getMasks()) != 0) return false; @@ -47,7 +77,10 @@ // The contract needs to represent a matmul to be able to convert to // MMAMatrix matmul. - if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + if (!useNvGpu && + contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + return false; + if (useNvGpu && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}})) return false; return true; @@ -61,7 +94,7 @@ if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. - if(memrefType.getRank() < 2) + if (memrefType.getRank() < 2) return 0; int64_t offset = 0; SmallVector strides; @@ -75,7 +108,8 @@ } // Return true if the transfer op can be converted to a MMA matrix load. -static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { +static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp, + bool useNvGpu) { if (readOp.getMask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; @@ -87,9 +121,14 @@ AffineExpr zero = b.getAffineConstantExpr(0); auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, readOp.getContext()); - // TODO: Support transpose once it is added to GPU dialect ops. - // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). - return !(!map.isMinorIdentity() && map != broadcastInnerDim); + + if (!useNvGpu) { + // TODO: Support transpose once it is added to GPU dialect ops. + // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). + return map.isMinorIdentity() || map == broadcastInnerDim; + } + + return true; } // Return true if the transfer op can be converted to a MMA matrix store. 
@@ -147,15 +186,15 @@ return convertElementwiseOpToMMA(op).hasValue(); } -static bool supportsMMaMatrixType(Operation *op) { +static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) { if (isa(op)) return true; if (auto transferRead = dyn_cast(op)) - return transferReadSupportsMMAMatrixType(transferRead); + return transferReadSupportsMMAMatrixType(transferRead, useNvGpu); if (auto transferWrite = dyn_cast(op)) return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) - return contractSupportsMMAMatrixType(contract); + return contractSupportsMMAMatrixType(contract, useNvGpu); if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) @@ -203,7 +242,8 @@ // Analyze slice of operations based on convert op to figure out if the whole // slice can be converted to MMA operations. -static SetVector getOpToConvert(mlir::Operation *op) { +static SetVector getOpToConvert(mlir::Operation *op, + bool useNvGpu) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type t) { return t.isa(); }); @@ -221,8 +261,9 @@ // If any instruction cannot use MMA matrix type drop the whole // chain. MMA matrix are stored in an opaque type so they cannot be used // by all operations. - if (llvm::any_of(dependentOps, - [](Operation *op) { return !supportsMMaMatrixType(op); })) + if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) { + return !supportsMMaMatrixType(op, useNvGpu); + })) return; opToConvert.insert(dependentOps.begin(), dependentOps.end()); }); @@ -351,7 +392,7 @@ static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap &valueMapping) { assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); - assert(transferReadSupportsMMAMatrixType(op)); + assert(transferReadSupportsMMAMatrixType(op, /*useNvGpu=*/false)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); AffineMap map = op.getPermutationMap(); @@ -386,6 +427,245 @@ op.erase(); } +/// Returns the vector type which represents a matrix fragment. +static VectorType getMmaSyncVectorOperandType( + gpu::MMAMatrixType fragType, + const nvgpu::NvvmMmaOperandBaseTileOperand8x128::FragmentElementInfo + ®Info) { + SmallVector shape{regInfo.numRegistersPerFragment, + regInfo.elementsPerRegister}; + Type elType = regInfo.registerLLVMType; + if (auto vecType = elType.dyn_cast()) + elType = vecType.getElementType(); + return VectorType::get(shape, elType); +} + +/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. 
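+/// On the `mma.sync` path the splat is materialized directly as a splat of
+/// the per-thread fragment vector type (see `getMmaSyncVectorOperandType`)
+/// rather than as a SubgroupMmaConstantMatrix op.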
+static LogicalResult +convertConstantOpMmaSync(arith::ConstantOp op, + llvm::DenseMap &valueMapping) { + assert(constantSupportsMMAMatrixType(op)); + OpBuilder b(op); + Attribute splat = + op.getValue().cast().getSplatValue(); + auto scalarConstant = + b.create(op.getLoc(), splat.getType(), splat); + auto originalVecType = op.getResult().getType().cast(); + gpu::MMAMatrixType gpuMatType = gpu::MMAMatrixType::get( + originalVecType.getShape(), originalVecType.getElementType(), + inferFragType(op)); + FailureOr + regInfo = nvgpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType( + gpuMatType); + if (failed(regInfo)) + return failure(); + VectorType vectorType = getMmaSyncVectorOperandType(gpuMatType, *regInfo); + Value result = + b.create(op.getLoc(), scalarConstant, vectorType); + valueMapping[op.getResult()] = result; + return success(); +} + +static LogicalResult +creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder, + gpu::MMAMatrixType fragType, + llvm::DenseMap &valueMapping) { + Location loc = op->getLoc(); + + FailureOr + regInfo = + nvgpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType(fragType); + if (failed(regInfo)) + return failure(); + + auto params = nvgpu::NvvmMmaOperandBaseTileOperand8x128::getLdMatrixParams( + fragType, + /*transpose=*/!op.getPermutationMap().isMinorIdentity()); + if (failed(params)) + return failure(); + + // Adjust the load offset. + auto laneId = builder.create(loc); + FailureOr offsets = + nvgpu::NvvmMmaOperandBaseTileOperand8x128::getLaneIdToLdMatrixMatrixCoord( + loc, builder, *params); + if (failed(offsets)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(fragType, *regInfo); + + SmallVector indices; + getXferIndices(builder, op, *offsets, {laneId}, + indices); + nvgpu::LdMatrixOp newOp = builder.create( + loc, vectorType, op.getSource(), indices, + !op.getPermutationMap().isMinorIdentity(), params->numTiles); + valueMapping[op] = newOp->getResult(0); + return success(); +} + +static LogicalResult +createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder, + gpu::MMAMatrixType fragmentType, + llvm::DenseMap &valueMapping) { + Location loc = op.getLoc(); + FailureOr + regInfo = nvgpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType( + fragmentType); + if (failed(regInfo)) { + op->emitError() << "Failed to deduce register fragment type during " + "conversion to distributed non-ldmatrix compatible load"; + return failure(); + } + + NVVM::MMALayout targetLayout = fragmentType.getOperand() == "BOp" + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; + + Value laneId = builder.create(loc); + SmallVector elements; + + // This is the individual element type. + Type loadedElType = regInfo->registerLLVMType; + VectorType vectorType = getMmaSyncVectorOperandType(fragmentType, *regInfo); + + Value fill = builder.create( + op.getLoc(), fragmentType.getElementType(), + builder.getZeroAttr(fragmentType.getElementType())); + Value result = builder.create(op.getLoc(), fill, vectorType); + + bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); + + // Vectorized loads. 
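+  // When no transpose is required and the target layout is row-major, each
+  // lane can fetch one register's worth of contiguous elements with a single
+  // vector.load at the offsets given by the laneId/valueId coordinate map.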
+ if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) { + if (!loadedElType.isa()) { + loadedElType = VectorType::get({1}, loadedElType); + } + + for (int i = 0; i < vectorType.getShape()[0]; i++) { + FailureOr coords = nvgpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), builder, fragmentType); + if (failed(coords)) + return failure(); + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister)); + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + + Value el = builder.create(loc, loadedElType, + op.getSource(), newIndices); + result = builder.create(loc, el, result, + builder.getI64ArrayAttr(i)); + } + } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) { + if (auto vecType = loadedElType.dyn_cast()) { + loadedElType = vecType.getElementType(); + } + // Load each element individually. + for (int i = 0; i < vectorType.getShape()[0]; i++) { + for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; + innerIdx++) { + + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); + FailureOr coords = + nvgpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), builder, + fragmentType); + if (failed(coords)) + return failure(); + + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + Value el = builder.create(op.getLoc(), loadedElType, + op.getSource(), newIndices); + result = builder.create( + op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx})); + } + } + } else { + return failure(); + } + + valueMapping[op.getResult()] = result; + return success(); +} + +static LogicalResult +convertTransferReadToLoads(vector::TransferReadOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + StringRef fragType(inferFragType(op)); + gpu::MMAMatrixType type = + gpu::MMAMatrixType::get(op.getVectorType().getShape(), + op.getVectorType().getElementType(), fragType); + + bool isLdMatrixCompatible = true; + if (op.getSource().getType().cast().getMemorySpaceAsInt() != 3) { + isLdMatrixCompatible = false; + } + if (nvgpu::NvvmMmaOperandBaseTileOperand8x128::inferTileWidthInBits( + type.getElementType(), /*isAcc=*/fragType == "COp") != 128) { + isLdMatrixCompatible = false; + } + if (!op.getPermutationMap().isMinorIdentity() && + (type.getOperand() == "BOp") && + type.getElementType().getIntOrFloatBitWidth() < 16) { + isLdMatrixCompatible = false; + } + if ((type.getOperand() == "COp") && + type.getElementType().getIntOrFloatBitWidth() < 16) { + isLdMatrixCompatible = false; + } + + if (!isLdMatrixCompatible) + return createNonLdMatrixLoads(op, b, type, valueMapping); + + return creatLdMatrixCompatibleLoads(op, b, type, valueMapping); +} + +static LogicalResult +convertTransferWriteToStores(vector::TransferWriteOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Location loc = op->getLoc(); + Value matrix = valueMapping.find(op.getVector())->second; + + gpu::MMAMatrixType matType = + gpu::MMAMatrixType::get(op.getVectorType().getShape(), + op.getVectorType().getElementType(), "COp"); + FailureOr + regInfo = + nvgpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType(matType); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(matType, *regInfo); + Value laneId = b.create(loc); + 
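+  // Each lane writes its fragment back to memory: the i-th register is
+  // extracted as a small vector and stored at the (row, col) offsets produced
+  // by the laneId/valueId coordinate map.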
+ for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { + Value logicalValueId = b.create( + loc, b.getIndexType(), + b.getIndexAttr(i * regInfo->elementsPerRegister)); + FailureOr coords = nvgpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), b, matType); + if (failed(coords)) + return failure(); + + Value el = b.create(loc, matrix, ArrayRef{i}); + SmallVector newIndices; + getXferIndices( + b, op, *coords, {laneId, logicalValueId}, newIndices); + b.create(loc, el, op.getSource(), newIndices); + } + op->erase(); + return success(); +} + static void convertContractOp(vector::ContractionOp op, llvm::DenseMap &valueMapping) { OpBuilder b(op); @@ -397,6 +677,22 @@ valueMapping[op.getResult()] = matmul; } +static LogicalResult +convertContractOpToMmaSync(vector::ContractionOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Value opA = valueMapping.find(op.getLhs())->second; + Value opB = valueMapping.find(op.getRhs())->second; + Value opC = valueMapping.find(op.getAcc())->second; + int64_t m = op.getLhs().getType().cast().getShape()[0]; + int64_t n = op.getRhs().getType().cast().getShape()[0]; + int64_t k = op.getLhs().getType().cast().getShape()[1]; + Value matmul = b.create( + op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k})); + valueMapping[op.getResult()] = matmul; + return success(); +} + /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { @@ -509,13 +805,20 @@ valueMapping[op->getResult(0)] = newOp; } -void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); +void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useNvGpu) { + if (!useNvGpu) { + patterns.add( + patterns.getContext()); + return; + } + patterns + .add( + patterns.getContext()); } void mlir::convertVectorToMMAOps(Operation *rootOp) { - SetVector ops = getOpToConvert(rootOp); + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/false); llvm::DenseMap valueMapping; for (Operation *op : ops) { if (auto transferRead = dyn_cast(op)) { @@ -538,21 +841,71 @@ } } +LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(Operation *rootOp) { + SetVector ops = getOpToConvert(rootOp, /*useNvGpu=*/true); + llvm::DenseMap valueMapping; + for (Operation *op : ops) { + if (llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transferReadOp) { + return convertTransferReadToLoads(transferReadOp, valueMapping); + }) + .Case([&](vector::TransferWriteOp transferWriteOp) { + return convertTransferWriteToStores(transferWriteOp, + valueMapping); + }) + .Case([&](vector::ContractionOp contractionOp) { + return convertContractOpToMmaSync(contractionOp, valueMapping); + }) + .Case([&](scf::ForOp forOp) { + convertForOp(forOp, valueMapping); + return success(); + }) + .Case([&](scf::YieldOp yieldOp) { + convertYieldOp(yieldOp, valueMapping); + return success(); + }) + .Case([&](arith::ConstantOp constOp) { + return convertConstantOpMmaSync(constOp, valueMapping); + }) + .Default([&](Operation *op) { + op->emitError() << "unhandled vector to mma type: " << *op; + return failure(); + }) + .failed()) { + op->emitError() << "Failed to convert op " << *op; + return failure(); + } + } + return success(); +} + namespace { struct ConvertVectorToGPUPass : public ConvertVectorToGPUBase { + + explicit ConvertVectorToGPUPass(bool useNvGpu_) { + useNvGpu.setValue(useNvGpu_); + } + void 
runOnOperation() override { RewritePatternSet patterns(&getContext()); - populatePrepareVectorToMMAPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + + if (useNvGpu.getValue()) { + if (failed(convertVectorToNVVMCompatibleMMASync(getOperation()))) + return signalPassFailure(); + } - convertVectorToMMAOps(getOperation()); + (void)convertVectorToMMAOps(getOperation()); } }; } // namespace -std::unique_ptr mlir::createConvertVectorToGPUPass() { - return std::make_unique(); +std::unique_ptr mlir::createConvertVectorToGPUPass(bool useNvGpu) { + return std::make_unique(useNvGpu); } diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir @@ -0,0 +1,286 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-nvgpu=true})" | FileCheck %s + +//######################################################### +// INT8 row-row-row +//######################################################### + +// CHECK-DAG: [[rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)> + +// CHECK-DAG: [[rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)> +// CHECK-DAG: [[colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> +// CHECK-DAG: [[rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 40)> +// CHECK-DAG: [[rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)> +// CHECK-DAG: [[rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)> +// CHECK-DAG: [[rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)> +// CHECK-DAG: [[rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)> +// CHECK-DAG: [[rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)> +// CHECK-DAG: [[rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)> + +// CHECK-DAG: [[rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)> +// CHECK-DAG: [[rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)> + + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: func @m16n8k32_int8_row_row_row +func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { + %cst_0 = arith.constant dense<0> : vector<32x8xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0 : i8 + %cst0 = arith.constant 0 : i32 + + // Verify that the operand A is distributed to loads correctly. 
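+  // A resides in shared memory (address space 3) and its i8 rows give a 128b
+  // tile width, so the load is ldmatrix-compatible.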
+ + // CHECK: [[row:%.+]] = affine.apply [[rowA0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colA0_map]]()[{{%.+}}] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB0_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB1_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB2_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB3_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB4_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB5_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB6_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB7_map]]()[{{%.+}}] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8> + // CHECK-NOT: memref.load %arg1 + + // Verify that the operand C is distributed to loads correctly. 
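+  // The i32 accumulator has a 256b tile width and lives in global memory, so
+  // it is loaded with per-lane vector.load ops rather than ldmatrix.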
+ // CHECK: [[row:%.+]] = affine.apply [[rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK-NOT: vector.load %arg2{{.*}} + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8>, vector<8x32xi8> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32> + + // CHECK: [[row:%.+]] = affine.apply [[rowC0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + // CHECK: [[row:%.+]] = affine.apply [[rowC8_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32> + return +} + +// ----- + +//######################################################### +// f64 row-row-row +//######################################################### +// CHECK-DAG: [[rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)> +// CHECK-DAG: [[colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)> + +// CHECK-DAG: [[rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)> +// CHECK-DAG: [[colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> + +// CHECK-DAG: [[rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40) + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: func @m8n8k4_f64_row_row_row +func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) { + %cst_0 = arith.constant dense<0.0> : vector<4x8xf64> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0.0 : f64 + %cst0 = arith.constant 0.0 : f64 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colA0_map]] + // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. 
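+  // For f64, each lane holds a single B element per register, so the
+  // transposed load is a single scalar memref.load per lane.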
+ + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowb0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colb0_map]] + // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colC0_map]] + // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + + %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64> + %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64> + %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64> + // CHECK: [[d:%.+]] = nvgpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowC0_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colC0_map]] + // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64> + vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64> + return +} + +// ----- + +//######################################################### +// FP16 row-row-row +//######################################################### + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-DAG: [[rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK-DAG: [[rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> +// CHECK-DAG: [[colB_map:#.+]] = affine_map<() -> (3)> + +// CHECK: func @m16n8k16_fp16_row_row_row +func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true} + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> 
+ return +} + +// ----- + +// CHECK-DAG: [[Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> +// CHECK-DAG: [[Bcol_map:#.+]] = affine_map<() -> (3)> +// CHECK-DAG: [[Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)> + +#map0 = affine_map<(d0, d1, d2) -> (d2, d1)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: func @batch_m16n8k16_fp16_row_row_row +func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + + // CHECK-DAG: [[row:%.+]] = affine.apply [[Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[Acol_map]] + // CHECK: nvgpu.ldmatrix %arg0[%c0, [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16> + %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[Brow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[Bcol_map]] + // CHECK: nvgpu.ldmatrix %arg1[%c0, [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16> + + // CHECK-DAG: [[row:%.+]] = affine.apply [[Arow_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[Acol_map]] + // CHECK: nvgpu.ldmatrix %arg2[%c0, [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16> + %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3> + return +} + +// ----- + +//######################################################### +// FP16 row-col-row +//######################################################### + +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: [[rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK: [[colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)> + +// CHECK: [[rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)> +// CHECK: [[colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)> + +// CHECK: func @m16n8k16_fp16_row_col_row +func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) { + %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %cst = arith.constant 0.000000e+00 : f16 + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]] + // CHECK: nvgpu.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32 + // CHECK-SAME: 
transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colB_map]] + // CHECK: nvgpu.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + + // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]] + // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]] + // CHECK: nvgpu.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32 + // CHECK-SAME: transpose = false + %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16> + %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16> + %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16> + %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16> + vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3> + return +} \ No newline at end of file