diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -764,8 +764,12 @@ "dialect"; let constructor = "mlir::createConvertVectorToGPUPass()"; let dependentDialects = [ - "memref::MemRefDialect", - "gpu::GPUDialect" + "memref::MemRefDialect", "gpu::GPUDialect", "AffineDialect", + "vector::VectorDialect" + ]; + + let options = [ + Option<"useWmma", "use-wmma", "bool", /*default=*/"true", ""> ]; } diff --git a/mlir/include/mlir/Conversion/VectorToGPU/NvvmMMASupport.h b/mlir/include/mlir/Conversion/VectorToGPU/NvvmMMASupport.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToGPU/NvvmMMASupport.h @@ -0,0 +1,81 @@ +//===- NvvmMMASupport.h - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to GPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOGPU_MMASUPPORT_H +#define MLIR_CONVERSION_VECTORTOGPU_MMASUPPORT_H + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace gpu { + +/// Helps to calculate the offsets within the tile for any NVVM/PTX MMA operand +/// that has a base tile size of 8 elements x [128|256|512] bits +namespace NvvmMmaOperandBaseTileOperand8x128 { + +/// Returns the number of bits in a single tile row. It is either 128, 256, or +/// 512 bits depending on the data type and whether the operand is an +/// accumulator. +int64_t inferTileWidthInBits(Type elementType, bool isAcc); + +/// Specifies information about the registers which compose a matrix fragment +/// according to the PTX documentation. +struct FragmentElementInfo { + Type registerLLVMType; + int64_t elementsPerRegister; + int64_t registerWidthBits; + int64_t numRegistersPerFragment; +}; + +/// Returns a FragmentElementInfo struct describing the register types for the +/// given matrix fragment type. +FailureOr getRegisterType(MMAMatrixType type); + +/// Returns an AffineMap which maps a two dimensions representing (laneId, +/// logicalValueId) and returns two results representing offsets within a +/// matrix operand. The offsets point to the values the thread is responsible +/// for (AKA the matrix fragment values) during a warp-collective matrix +/// operation. 
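+/// For example, for a 16x16 "AOp" fragment of f16 elements the resulting map
+/// is (laneId, valueId) -> (row, col) with
+///   row = ((valueId floordiv 2) mod 2) * 8 + laneId floordiv 4
+///   col = (valueId floordiv 4) * 8 + (laneId mod 4) * 2 + valueId mod 2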
For a visual reference of this LaneId -> (row, col) mapping, +/// please see NVIDIA's PTX documentation: +/// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + MMAMatrixType fragmentType); + +struct LdMatrixParams { + MMAMatrixType fragmentType; + int64_t numTiles; + IteratorType contiguousDimType; + NVVM::MMALayout targetLayout; +}; + +FailureOr getLdMatrixParams(MMAMatrixType fragType, + bool transpose); +/// Returns an AffineMap which maps a single dimension representing the laneId +/// to two results representing offsets within the matrix operand that should +/// be the pointer locations a thread should pass to the ldmatrix instruction. +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms); + +} // namespace NvvmMmaOperandBaseTileOperand8x128 +} // namespace gpu +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOGPU_MMASUPPORT_H diff --git a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h --- a/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h +++ b/mlir/include/mlir/Conversion/VectorToGPU/VectorToGPU.h @@ -18,15 +18,22 @@ /// Patterns to transform vector ops into a canonical form to convert to MMA /// matrix operations. -void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns); +void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useWmma); /// Convert vector ops to MMA matrix operations nested under `rootOp`. This will /// convert slice of operations that can be legally converted to MMA operations. /// The rest of the vector operations are left untouched. void convertVectorToMMAOps(Operation *rootOp); +/// Convert vector ops ops nested under `rootOp` to vector and GPU operaitons +/// compatible with the `nvvm.mma.sync` lowering path. This will convert a slice +/// of operations that can be legally lowered on this path while the rest of +/// the vector operations are left untouched. +LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp); + /// Convert from vector to GPU ops. -std::unique_ptr createConvertVectorToGPUPass(); +std::unique_ptr createConvertVectorToGPUPass(bool useWmma = true); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/GPU/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/GPUDialect.h --- a/mlir/include/mlir/Dialect/GPU/GPUDialect.h +++ b/mlir/include/mlir/Dialect/GPU/GPUDialect.h @@ -165,6 +165,13 @@ /// C += A*B. This function returns which operand in the given equation is /// held by this type. String returned can be one of"AOp", "BOp" and "COp". StringRef getOperand() const; + + /// Returns whether this operand represents an accumulator or result type. + bool isAccOrResult() const { return getOperand() == "COp"; } + + int64_t getElementTypeBitWidth() const { + return getElementType().getIntOrFloatBitWidth(); + } }; // Adds a `gpu.async.token` to the front of the argument list. diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -96,6 +96,19 @@ }]; } +def GPU_LaneIdOp : GPU_Op<"lane_id", [NoSideEffect]> { + let description = [{ + Returns the lane id within the subgroup (warp/wave). 
+ + Example: + ```mlir + %laneId = gpu.lane_id : index + ``` + }]; + let results = (outs Index:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + def GPU_SubgroupIdOp : GPU_Op<"subgroup_id", [NoSideEffect]>, Arguments<(ins)>, Results<(outs Index:$result)> { let description = [{ @@ -1354,4 +1367,58 @@ }]; } +def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix", + [MemoryEffects<[MemRead]>]> { + let description = [{ + The `gpu.mma.ldmatrix` op represents loading a matrix fragment from + memory. The load source and result type must be compatible with lowering + to the `nvvm.ldmatrix` instruction. This op is meant to represent + the distributed version of a `vector.transfer_read` as an intermediate + step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`. + + Example: + + ```mlir + gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16> + ``` + }]; + + let arguments = (ins Arg:$srcMemref, + Variadic:$indices, BoolAttr:$transpose, + I32Attr:$numTiles); + let results = (outs AnyVector:$res); + let assemblyFormat = [{ + $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res) + }]; +} + +def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> { + let description = [{ + The `gpu.mma.sync` op represents the distributed form of a collective + matrix-multiply-and-accumulate (mma) operation that is compatible with + `nvvm.mma.sync`. The operands and results are fragments of the full matrix + operands. The full shape of the distributed mma operation is given by the + `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`. + + This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and + is an intermediate point between lowering from `vector.contract` to + `nvvm.mma.sync`. + + Example: + + ```mlir + gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + ``` + }]; + let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC, + I64ArrayAttr:$mmaShape); + + let results = (outs AnyVector:$res); + + let assemblyFormat = [{ + `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict + `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res) + }]; +} + #endif // GPU_OPS diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -62,6 +62,34 @@ } }; +template +struct GPULaneIdIntrinsicOpLowering : ConvertOpToLLVMPattern { +private: + unsigned indexBitwidth; + +public: + explicit GPULaneIdIntrinsicOpLowering(LLVMTypeConverter &typeConverter) + : ConvertOpToLLVMPattern(typeConverter), + indexBitwidth(typeConverter.getIndexTypeBitwidth()) {} + // Convert the kernel arguments to an LLVM type, preserve the rest. 
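+  // Lowers the lane-id op to the target intrinsic op (which returns i32) and
+  // then sign-extends or truncates the result to the index bitwidth expected
+  // by the type converter.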
+ LogicalResult + matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + MLIRContext *context = rewriter.getContext(); + Value newOp = rewriter.create(loc, rewriter.getI32Type()); + if (indexBitwidth > 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } else if (indexBitwidth < 32) { + newOp = rewriter.create( + loc, IntegerType::get(context, indexBitwidth), newOp); + } + rewriter.replaceOp(op, {newOp}); + return success(); + } +}; + } // namespace mlir #endif // MLIR_CONVERSION_GPUCOMMON_INDEXINTRINSICSOPLOWERING_H_ diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -18,17 +18,28 @@ #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/VectorToGPU/NvvmMMASupport.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BlockAndValueMapping.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/FormatVariadic.h" @@ -208,6 +219,270 @@ } }; +struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = getContext(); + Location loc = op->getLoc(); + + // The result type of ldmatrix always be a struct of 32bit integer + // registers. The result type of the GPU operation is always a vector of + // shape (NumRegisters, VectorRegister) where VectorRegister is the vector + // type of the result and always 32 bits long. We bitcast the result of the + // NVVM::LdMatrix to this vector type. 
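+    // For example, a `gpu.mma.ldmatrix` producing vector<4x2xf16> lowers to
+    // an `nvvm.ldmatrix` returning !llvm.struct<(i32, i32, i32, i32)>; each
+    // i32 is then bitcast to vector<2xf16> and inserted into the final result.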
+ auto vectorResultType = op->getResultTypes()[0].dyn_cast(); + if (!vectorResultType) + return failure(); + Type innerVectorType = LLVM::getFixedVectorType( + vectorResultType.getElementType(), vectorResultType.getShape()[1]); + + LLVM::LLVMStructType ldMatrixResultType = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(vectorResultType.getShape()[0], + rewriter.getI32Type())); + + auto srcMemrefType = op.srcMemref().getType().cast(); + Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(), + adaptor.indices(), rewriter); + Value ldMatrixResult = rewriter.create( + loc, ldMatrixResultType, srcPtr, + /*num=*/op.numTiles(), + /*layout=*/op.transpose() ? NVVM::MMALayout::col + : NVVM::MMALayout::row); + + Type finalResultType = typeConverter->convertType(vectorResultType); + Value result = rewriter.create(loc, finalResultType); + for (int i = 0; i < vectorResultType.getShape()[0]; i++) { + Value i32Register = rewriter.create( + loc, ldMatrixResultType.getBody()[i], ldMatrixResult, + rewriter.getI64ArrayAttr(i)); + Value casted = + rewriter.create(loc, innerVectorType, i32Register); + result = rewriter.create( + loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i)); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + +private: + /// Checks if all the operands of the op being lowered are of LLVM Types. The + /// types are expected to be converted by the `LLVMTypeConverter` before the + /// op is actually lowered. If the type of an operands is not already + /// converted it hints a missing typeConversion and failure is returned in + /// that case. + static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, + ConversionPatternRewriter &rewriter) { + if (!llvm::all_of(operands, [](Value value) { + return LLVM::isCompatibleType(value.getType()); + })) { + return rewriter.notifyMatchFailure( + op, "cannot convert if operands aren't of LLVM type."); + } + + return success(); + } + + /// Returns the type for the intrinsic given the vectorResultType of the + /// `gpu.mma.sync` operation. + static Type inferIntrinsicResultType(Type vectorResultType) { + MLIRContext *ctx = vectorResultType.getContext(); + auto a = vectorResultType.cast(); + auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); + auto i32Ty = IntegerType::get(ctx, 32); + auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + Type f64Ty = Float64Type::get(ctx); + Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + if (a.getElementType() == f16x2Ty) { + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(a.getNumElements(), f16x2Ty)); + } + if (a.getElementType() == i32x2Ty) { + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(static_cast(a.getNumElements()) * 2, + i32Ty)); + } + if (a.getElementType() == f64x2Ty) { + return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); + } + return vectorResultType; + } + + /// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is + /// always an LLVM struct) into a fragment that is compatible with the vector + /// type of this operation. This involves extracting elements from the struct + /// and inserting them into an LLVM array. These extra data-movement + /// operations should be canonicalized away by the LLVM backend. 
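+  /// For example, an f16 result returned as
+  /// !llvm.struct<(vector<2xf16>, vector<2xf16>)> is repacked into
+  /// !llvm.array<2 x vector<2xf16>>, while i32/f64 results come back as
+  /// individual scalars that are first packed into 2-element vectors.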
+ static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, + Type resultType, Value intrinsicResult, + RewriterBase &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + auto structType = intrinsicResultType.dyn_cast(); + auto arrayType = resultType.dyn_cast(); + Type i32Ty = rewriter.getI32Type(); + Type f64Ty = rewriter.getF64Type(); + Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); + Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + + auto makeConst = [&](int32_t index) -> Value { + return rewriter.create( + loc, IntegerType::get(ctx, 32), rewriter.getI32IntegerAttr(index)); + }; + + if (arrayType) { + SmallVector elements; + + if (arrayType.getElementType() == f16x2Ty) { + for (unsigned i = 0; i < structType.getBody().size(); i++) { + elements.push_back(rewriter.create( + loc, structType.getBody()[i], intrinsicResult, + rewriter.getI64ArrayAttr(i))); + } + } + + // The intrinsic returns i32 and f64 values as individual scalars. We need + // to extract them from the struct and pack them into vectors. + if (arrayType.getElementType() == i32x2Ty || + arrayType.getElementType() == f64x2Ty) { + Value vec = + rewriter.create(loc, arrayType.getElementType()); + for (unsigned i = 0; i < structType.getBody().size() / 2; i++) { + Value x1 = rewriter.create( + loc, structType.getBody()[i * 2], intrinsicResult, + rewriter.getI64ArrayAttr(i * 2)); + Value x2 = rewriter.create( + loc, structType.getBody()[i * 2 + 1], intrinsicResult, + rewriter.getI64ArrayAttr(i * 2 + 1)); + vec = rewriter.create(loc, vec.getType(), vec, + x1, makeConst(0)); + vec = rewriter.create(loc, vec.getType(), vec, + x2, makeConst(1)); + } + elements.push_back(vec); + } + + // Create the final vectorized result. + Value result = rewriter.create(loc, arrayType); + for (auto el : llvm::enumerate(elements)) { + result = rewriter.create( + loc, arrayType, result, el.value(), + rewriter.getI64ArrayAttr(el.index())); + } + return result; + } + + return intrinsicResult; + } + + static SmallVector + unpackOperandVector(RewriterBase &rewriter, Location loc, Value operand) { + SmallVector result; + Type i32Ty = rewriter.getI32Type(); + Type f64Ty = rewriter.getF64Type(); + Type i8Ty = rewriter.getI8Type(); + auto arrayTy = operand.getType().cast(); + for (size_t i = 0, e = arrayTy.getNumElements(); i < e; ++i) { + Value toUse = rewriter.create( + loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i)); + + // For 4xi8 vectors, the intrinsic expects these to be provided as i32 + // scalar types. + Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); + if (arrayTy.getElementType() == i8x4Ty) { + result.push_back(rewriter.create( + loc, rewriter.getI32Type(), toUse)); + continue; + } + + // For some element types (i32, f64), we need to unpack the inner + // vector/array type as well because the intrinsic expects individual + // scalars to be provided. 
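+      // For example, an accumulator passed as !llvm.array<2 x vector<2xi32>>
+      // is unpacked into four separate i32 values.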
+ VectorType innerArrayTy = arrayTy.getElementType().dyn_cast(); + if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || + innerArrayTy.getElementType() == f64Ty)) { + for (int idx = 0; idx < innerArrayTy.getNumElements(); idx++) { + result.push_back(rewriter.create( + loc, toUse, + rewriter.create( + loc, rewriter.getI64Type(), + rewriter.getI64IntegerAttr(idx)))); + } + continue; + } + result.push_back(toUse); + } + return result; + } + +public: + LogicalResult + matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + + if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) { + return failure(); + } + + // Get the shapes of the MMAMatrix type being used. The shapes will + // choose which intrinsic this op will be lowered to. + auto aType = op.matrixA().getType().cast(); + + int64_t m = op.mmaShape()[0].cast().getInt(); + int64_t n = op.mmaShape()[1].cast().getInt(); + int64_t k = op.mmaShape()[2].cast().getInt(); + std::array gemmShape{m, n, k}; + + auto matA = unpackOperandVector(rewriter, loc, adaptor.matrixA()); + auto matB = unpackOperandVector(rewriter, loc, adaptor.matrixB()); + auto matC = unpackOperandVector(rewriter, loc, adaptor.matrixC()); + + NVVM::MMATypes ptxTypeA; + NVVM::MMATypes ptxTypeB; + Optional overflow(llvm::None); + if (aType.getElementType().isInteger(8)) { + ptxTypeA = NVVM::MMATypes::s8; + ptxTypeB = NVVM::MMATypes::s8; + overflow = NVVM::MMAIntOverflow::satfinite; + + } else if (aType.getElementType().isF16()) { + ptxTypeA = NVVM::MMATypes::f16; + ptxTypeB = NVVM::MMATypes::f16; + } else if (aType.getElementType().isF64()) { + ptxTypeA = NVVM::MMATypes::f64; + ptxTypeB = NVVM::MMATypes::f64; + } else { + return op->emitError("could not deduce operand PTX types"); + } + + Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); + Type intrinsicResTy = inferIntrinsicResultType( + typeConverter->convertType(op->getResultTypes()[0])); + Value intrinsicResult = rewriter.create( + op.getLoc(), intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/llvm::None, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array{ptxTypeA, ptxTypeB}, + /*multiplicandLayouts=*/ + std::array{NVVM::MMALayout::row, + NVVM::MMALayout::col}); + rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, + desiredRetTy, intrinsicResult, + rewriter)); + return success(); + } +}; + /// Import the GPU Ops to NVVM Patterns. 
#include "GPUToNVVM.cpp.inc" @@ -276,6 +551,16 @@ } }; +struct GPULaneIdOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(gpu::LaneIdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + return success(); + } +}; + } // namespace void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { @@ -303,7 +588,9 @@ NVVM::BlockIdYOp, NVVM::BlockIdZOp>, GPUIndexIntrinsicOpLowering, - GPUShuffleOpLowering, GPUReturnOpLowering>(converter); + GPULaneIdIntrinsicOpLowering, + GPUShuffleOpLowering, GPUReturnOpLowering, MmaSyncOptoNVVM, + MmaLdMatrixOpToNVVM>(converter); // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default diff --git a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToGPU/CMakeLists.txt @@ -1,18 +1,17 @@ -add_mlir_conversion_library(MLIRVectorToGPU +add_mlir_conversion_library( + MLIRVectorToGPU VectorToGPU.cpp - + NvvmMMASupport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToGPU - LINK_COMPONENTS Core - - LINK_LIBS PUBLIC + LINK_LIBS + PUBLIC MLIRArithmetic MLIRGPUOps MLIRLLVMIR MLIRMemRef MLIRTransforms MLIRVector - MLIRVectorUtils - ) + MLIRVectorUtils) diff --git a/mlir/lib/Conversion/VectorToGPU/NvvmMMASupport.cpp b/mlir/lib/Conversion/VectorToGPU/NvvmMMASupport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToGPU/NvvmMMASupport.cpp @@ -0,0 +1,213 @@ +//===- NvvmMMASupport.cpp - MLIR Vector to GPU lowering support --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utilities to assist in the lowering of Vector operations +// to GPU dialect MMA operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToGPU/NvvmMMASupport.h" + +namespace mlir { +namespace gpu { +namespace NvvmMmaOperandBaseTileOperand8x128 { + +namespace { + +/// There are always 4 threads per [128|256|512] bit row. +constexpr int64_t kThreadsPerRow = 4; + +constexpr int64_t kNumRowsPerTile = 8; + +/// Returns the number of registers which compose a matrix fragment held by a +/// single thread. +int64_t inferNumRegistersPerMatrixFragment(MMAMatrixType type) { + int64_t lineSize = + inferTileWidthInBits(type.getElementType(), type.isAccOrResult()); + auto shape = type.getShape(); + return (shape[0] / kNumRowsPerTile) * + (shape[1] * type.getElementTypeBitWidth()) / lineSize; +} + +/// Returns the number of 8 x [128|256|512] bit tiles that compose the given +/// operand shape. +std::array getTileShape(ArrayRef operandShape, + Type elementType, int64_t lineSizeBits) { + // For each 8x128bit square, a thread is responsible for one 32bit register. 
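+  // For example, a 16x16 f16 operand with 128-bit lines decomposes into a
+  // 2x2 grid of tiles, each covering 8 rows x 8 f16 elements.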
+ return {operandShape[0] / kNumRowsPerTile, + (operandShape[1] * elementType.getIntOrFloatBitWidth()) / + lineSizeBits}; +} + +} // namespace + +int64_t inferTileWidthInBits(Type elementType, bool isAcc) { + if (isAcc && elementType.getIntOrFloatBitWidth() == 32) { + return 256; + } + if (elementType.getIntOrFloatBitWidth() == 64) { + return isAcc ? 512 : 256; + } + return 128; +} + +FailureOr getRegisterType(MMAMatrixType type) { + MLIRContext *ctx = type.getContext(); + if (type.getElementType().isF16()) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(Float16Type::get(ctx), 2), 2, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 acc + Type f64Ty = Float64Type::get(ctx); + if (type.getElementType().isF64() && type.isAccOrResult()) { + return FragmentElementInfo{LLVM::getFixedVectorType(f64Ty, 2), 2, 128, + inferNumRegistersPerMatrixFragment(type)}; + } + + // f64 operand + if (type.getElementType().isF64() && !type.isAccOrResult()) { + return FragmentElementInfo{f64Ty, 1, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + + // int8 operand + if (type.getElementType().isInteger(8)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 8), 4), 4, 32, + inferNumRegistersPerMatrixFragment(type)}; + } + // 32bit acc operands + if (type.getElementType().isInteger(32)) { + return FragmentElementInfo{ + LLVM::getFixedVectorType(IntegerType::get(ctx, 32), 2), 2, 64, + inferNumRegistersPerMatrixFragment(type)}; + } + return failure(); +} + +static AffineMap getRegisterIndexToTileOffsetMap(OpBuilder &base, + Type elementType, + ArrayRef operandShape, + bool isAccumulator, + int64_t elementsPerRegister, + AffineExpr logicalValueId) { + const int64_t lineSize = inferTileWidthInBits(elementType, isAccumulator); + const int64_t elementsPerLine = + lineSize / elementType.getIntOrFloatBitWidth(); + const std::array num8x128bTiles = + getTileShape(operandShape, elementType, lineSize); + AffineExpr registerIdx = logicalValueId.floorDiv(elementsPerRegister); + return AffineMap::get( + 2, 0, + {(registerIdx % num8x128bTiles[0]) * 8, + (registerIdx.floorDiv(num8x128bTiles[0])) * elementsPerLine}, + base.getContext()); +} + +FailureOr +getLaneIdAndValueIdToOperandCoord(Location loc, OpBuilder &builder, + MMAMatrixType fragmentType) { + Type elementType = fragmentType.getElementType(); + ArrayRef operandShape = fragmentType.getShape(); + bool isAccumulator = fragmentType.isAccOrResult(); + FailureOr regInfo = getRegisterType(fragmentType); + if (failed(regInfo)) + return failure(); + + const int64_t elementBitWidth = fragmentType.getElementTypeBitWidth(); + const int64_t elementsPerRegister = + regInfo->registerWidthBits / elementBitWidth; + + AffineExpr laneId, logicalValueIdDim; + bindDims(builder.getContext(), laneId, logicalValueIdDim); + + // Determine what register logicalValueId corresponds to. Use that as a + // linear index into the coordinate mapping `index -> (tile row, tile col)`. 
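+  // The final map below then adds the per-lane offset within the tile:
+  // laneId floordiv 4 to the row, and (laneId mod 4) * elementsPerRegister
+  // plus (valueId mod elementsPerRegister) to the column.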
+ AffineMap registerIndexToTileCoord = getRegisterIndexToTileOffsetMap( + builder, elementType, operandShape, isAccumulator, elementsPerRegister, + logicalValueIdDim); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(2, 0, dimExprs, builder.getContext()); + }; + + auto tileRow = registerIndexToTileCoord.getResult(0); + auto tileCol = registerIndexToTileCoord.getResult(1); + return makeMap({tileRow + laneId.floorDiv(kThreadsPerRow), + tileCol + (laneId % kThreadsPerRow) * elementsPerRegister + + (logicalValueIdDim % elementsPerRegister)}); +} + +FailureOr getLdMatrixParams(MMAMatrixType fragType, + bool transpose) { + LdMatrixParams params; + params.fragmentType = fragType; + if (fragType.getOperand() == "AOp" || fragType.getOperand() == "COp") { + params.targetLayout = NVVM::MMALayout::row; + } else { + params.targetLayout = NVVM::MMALayout::col; + } + ArrayRef shape = fragType.getShape(); + params.contiguousDimType = + transpose ? IteratorType::Parallel : IteratorType::Reduction; + + if (params.targetLayout == NVVM::MMALayout::row) { + params.numTiles = (shape[0] / kNumRowsPerTile) * + ((shape[1] * fragType.getElementTypeBitWidth()) / 128); + } else { + params.numTiles = (shape[1] / kNumRowsPerTile) * + ((shape[0] * fragType.getElementTypeBitWidth()) / 128); + } + + return params; +} + +FailureOr +getLaneIdToLdMatrixMatrixCoord(Location loc, OpBuilder &builder, + const LdMatrixParams ¶ms) { + // One thread per 128b row. + const int64_t kNumThreadsPerTile = kNumRowsPerTile; + const int bitsPerElement = + static_cast(params.fragmentType.getElementTypeBitWidth()); + const int kElementsPer128b = (128 / bitsPerElement); + ArrayRef operandShape = params.fragmentType.getShape(); + AffineExpr d0 = getAffineDimExpr(0, builder.getContext()); + + auto makeMap = [&](ArrayRef dimExprs) -> AffineMap { + return AffineMap::get(1, 0, dimExprs, builder.getContext()); + }; + + // This case corresponds to row-major A|C or col-major B operands. + if (params.contiguousDimType == IteratorType::Reduction) { + AffineExpr row = d0 % (operandShape[0]); + AffineExpr col = d0.floorDiv(operandShape[0]) * (kElementsPer128b); + return makeMap({row, col}); + } + + // This case Corresponds to col-major A|C or row-major B operands. The + // operandShape given is already pre-transposed (e.g. 8x16 = KxN). + if (params.contiguousDimType == IteratorType::Parallel) { + const int64_t num8x128bCols = (operandShape[0] * bitsPerElement) / 128; + // Threads are assigned in groups of 8 first across columns, then to + // rows. This is transpose of what `ldmatrix` expects, but when + // `ldmatrix` gets the `.trans` qualifier, final the effect will be to + // transpose just the blocks. 
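+    // For example, with num8x128bCols == 2, lanes 0-7 address rows of the
+    // first 8x128b block column, lanes 8-15 the second, and lanes 16-31
+    // continue the pattern on the next block row.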
+ auto groupIdx = d0.floorDiv(kNumThreadsPerTile); + auto tileCol = (groupIdx % num8x128bCols); + auto tileRow = groupIdx.floorDiv(num8x128bCols); + return makeMap({tileCol * kElementsPer128b, + tileRow * kNumRowsPerTile + (d0 % kNumRowsPerTile)}); + } + return failure(); +} + +} // namespace NvvmMmaOperandBaseTileOperand8x128 +} // namespace gpu +} // namespace mlir \ No newline at end of file diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -12,26 +12,61 @@ #include +#include "mlir/Conversion/VectorToGPU/NvvmMMASupport.h" #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "../PassDetail.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/TypeSwitch.h" -using namespace mlir; +namespace mlir { +namespace { + +/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an +/// AffineMap representing offsets to apply to indices, the function fills +/// `indices` with the original indices plus the offsets. The offsets are +/// applied by taking into account the permutation map of the transfer op. If +/// the `offsetMap` has dimension placeholders, those should be provided in +/// `dimValues`. +template +void getXferIndices(OpBuilder &b, TransferOpType xferOp, AffineMap offsetMap, + ArrayRef dimValues, SmallVector &indices) { + indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end()); + Location loc = xferOp.getLoc(); + unsigned offsetsIdx = 0; + for (auto expr : xferOp.getPermutationMap().getResults()) { + if (auto dim = expr.template dyn_cast()) { + Value prevIdx = indices[dim.getPosition()]; + SmallVector dims(dimValues.begin(), dimValues.end()); + dims.push_back(prevIdx); + AffineExpr d0 = b.getAffineDimExpr(offsetMap.getNumDims()); + indices[dim.getPosition()] = makeComposedAffineApply( + b, loc, d0 + offsetMap.getResult(offsetsIdx++), dims); + continue; + } + } +} +} // namespace // Return true if the contract op can be convert to MMA matmul. -static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) { +static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, + bool useWmma) { if (llvm::size(contract.getMasks()) != 0) return false; @@ -47,7 +82,9 @@ // The contract needs to represent a matmul to be able to convert to // MMAMatrix matmul. - if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + if (useWmma && contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}})) + return false; + if (!useWmma && contract.getIndexingMaps() != infer({{m, k}, {n, k}, {m, n}})) return false; return true; @@ -61,7 +98,7 @@ if (!memrefType) return false; // If the memref is 0 or 1D the horizontal stride is 0. 
- if(memrefType.getRank() < 2) + if (memrefType.getRank() < 2) return 0; int64_t offset = 0; SmallVector strides; @@ -75,7 +112,8 @@ } // Return true if the transfer op can be converted to a MMA matrix load. -static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) { +static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp, + bool useWmma) { if (readOp.getMask() || readOp.hasOutOfBoundsDim() || readOp.getVectorType().getRank() != 2) return false; @@ -87,9 +125,14 @@ AffineExpr zero = b.getAffineConstantExpr(0); auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, readOp.getContext()); - // TODO: Support transpose once it is added to GPU dialect ops. - // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). - return !(!map.isMinorIdentity() && map != broadcastInnerDim); + + if (useWmma) { + // TODO: Support transpose once it is added to GPU dialect ops. + // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1). + return map.isMinorIdentity() || map == broadcastInnerDim; + } + + return true; } // Return true if the transfer op can be converted to a MMA matrix store. @@ -147,15 +190,15 @@ return convertElementwiseOpToMMA(op).hasValue(); } -static bool supportsMMaMatrixType(Operation *op) { +static bool supportsMMaMatrixType(Operation *op, bool useWmma) { if (isa(op)) return true; if (auto transferRead = dyn_cast(op)) - return transferReadSupportsMMAMatrixType(transferRead); + return transferReadSupportsMMAMatrixType(transferRead, useWmma); if (auto transferWrite = dyn_cast(op)) return transferWriteSupportsMMAMatrixType(transferWrite); if (auto contract = dyn_cast(op)) - return contractSupportsMMAMatrixType(contract); + return contractSupportsMMAMatrixType(contract, useWmma); if (auto constant = dyn_cast(op)) return constantSupportsMMAMatrixType(constant); if (auto broadcast = dyn_cast(op)) @@ -203,7 +246,8 @@ // Analyze slice of operations based on convert op to figure out if the whole // slice can be converted to MMA operations. -static SetVector getOpToConvert(mlir::Operation *op) { +static SetVector getOpToConvert(mlir::Operation *op, + bool useWmma) { auto hasVectorDest = [](Operation *op) { return llvm::any_of(op->getResultTypes(), [](Type t) { return t.isa(); }); @@ -221,8 +265,9 @@ // If any instruction cannot use MMA matrix type drop the whole // chain. MMA matrix are stored in an opaque type so they cannot be used // by all operations. - if (llvm::any_of(dependentOps, - [](Operation *op) { return !supportsMMaMatrixType(op); })) + if (llvm::any_of(dependentOps, [useWmma](Operation *op) { + return !supportsMMaMatrixType(op, useWmma); + })) return; opToConvert.insert(dependentOps.begin(), dependentOps.end()); }); @@ -233,6 +278,62 @@ namespace { // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted // to MMA matmul. +struct PrepareContractToGPUMMASync + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc(); + + // Set up the parallel/reduction structure in right form. 
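+    // All indexing-map combinations other than (m, k) x (n, k) -> (m, n) are
+    // rewritten into that canonical form below by transposing and/or swapping
+    // the lhs/rhs operands.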
+ using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + static constexpr std::array perm = {1, 0}; + auto iteratorTypes = op.getIteratorTypes().getValue(); + SmallVector maps = op.getIndexingMaps(); + if (!(isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2]))) + return failure(); + + // The canonical form is "TNT" = A row-major, B col-major, C row-major. + const auto canonicalForm = infer({{m, k}, {n, k}, {m, n}}); + if (maps == canonicalForm) { + return failure(); + } + if (maps == infer({{m, k}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { + std::swap(rhs, lhs); + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { + std::swap(lhs, rhs); + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { + std::swap(lhs, rhs); + } else { + return failure(); + } + rewriter.replaceOpWithNewOp( + op, lhs, rhs, res, rewriter.getAffineMapArrayAttr(canonicalForm), + op.getIteratorTypes()); + return success(); + } +}; + struct PrepareContractToGPUMMA : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -351,7 +452,7 @@ static void convertTransferReadOp(vector::TransferReadOp op, llvm::DenseMap &valueMapping) { assert(op.getTransferRank() > 0 && "unexpected 0-d transfer"); - assert(transferReadSupportsMMAMatrixType(op)); + assert(transferReadSupportsMMAMatrixType(op, /*useWmma=*/true)); Optional stride = getMemrefConstantHorizontalStride(op.getShapedType()); AffineMap map = op.getPermutationMap(); @@ -386,6 +487,213 @@ op.erase(); } +/// Returns the vector type which represents a matrix fragment. +static VectorType getMmaSyncVectorOperandType( + gpu::MMAMatrixType fragType, + const gpu::NvvmMmaOperandBaseTileOperand8x128::FragmentElementInfo + ®Info) { + SmallVector shape{regInfo.numRegistersPerFragment, + regInfo.elementsPerRegister}; + Type elType = regInfo.registerLLVMType; + if (auto vecType = elType.dyn_cast()) + elType = vecType.getElementType(); + return VectorType::get(shape, elType); +} + +static LogicalResult +creatLdMatrixCompatibleLoads(vector::TransferReadOp op, OpBuilder &builder, + gpu::MMAMatrixType fragType, + llvm::DenseMap &valueMapping) { + Location loc = op->getLoc(); + + FailureOr + regInfo = + gpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType(fragType); + if (failed(regInfo)) + return failure(); + + auto params = gpu::NvvmMmaOperandBaseTileOperand8x128::getLdMatrixParams( + fragType, + /*transpose=*/!op.getPermutationMap().isMinorIdentity()); + if (failed(params)) + return failure(); + + // Adjust the load offset. 
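+  // Each lane contributes one 128-bit row pointer to the collective
+  // `ldmatrix`; the per-lane (row, col) offset derived from the lane id is
+  // folded into the transfer indices below.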
+ auto laneId = builder.create(loc); + FailureOr offsets = + gpu::NvvmMmaOperandBaseTileOperand8x128::getLaneIdToLdMatrixMatrixCoord( + loc, builder, *params); + if (failed(offsets)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(fragType, *regInfo); + + SmallVector indices; + getXferIndices(builder, op, *offsets, {laneId}, + indices); + gpu::MmaLdMatrixOp newOp = builder.create( + loc, vectorType, op.getSource(), indices, + !op.getPermutationMap().isMinorIdentity(), params->numTiles); + valueMapping[op] = newOp->getResult(0); + return success(); +} + +static LogicalResult +createNonLdMatrixLoads(vector::TransferReadOp op, OpBuilder &builder, + gpu::MMAMatrixType fragmentType, + llvm::DenseMap &valueMapping) { + Location loc = op.getLoc(); + FailureOr + regInfo = gpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType( + fragmentType); + if (failed(regInfo)) + return failure(); + + NVVM::MMALayout targetLayout = fragmentType.getOperand() == "BOp" + ? NVVM::MMALayout::col + : NVVM::MMALayout::row; + + Value laneId = builder.create(loc); + SmallVector elements; + + // This is the individual element type. + Type loadedElType = regInfo->registerLLVMType; + VectorType vectorType = getMmaSyncVectorOperandType(fragmentType, *regInfo); + + Value fill = builder.create( + op.getLoc(), fragmentType.getElementType(), + builder.getZeroAttr(fragmentType.getElementType())); + Value result = builder.create(op.getLoc(), fill, vectorType); + + bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity(); + + // Vectorized loads. + if (!isTransposeLoad && targetLayout == NVVM::MMALayout::row) { + if (!loadedElType.isa()) { + loadedElType = VectorType::get({1}, loadedElType); + } + + for (int i = 0; i < vectorType.getShape()[0]; i++) { + FailureOr coords = gpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), builder, fragmentType); + if (failed(coords)) + return failure(); + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister)); + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + + Value el = builder.create(loc, loadedElType, + op.getSource(), newIndices); + result = builder.create(loc, el, result, + builder.getI64ArrayAttr(i)); + } + } else if (isTransposeLoad && targetLayout == NVVM::MMALayout::col) { + if (auto vecType = loadedElType.dyn_cast()) { + loadedElType = vecType.getElementType(); + } + // Load each element individually. 
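+    // For transposed operands the elements of a fragment register are not
+    // contiguous in memory, so each one is fetched with a scalar load.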
+ for (int i = 0; i < vectorType.getShape()[0]; i++) { + for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1]; + innerIdx++) { + + Value logicalValueId = builder.create( + loc, builder.getIndexType(), + builder.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx)); + FailureOr coords = gpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), builder, + fragmentType); + if (failed(coords)) + return failure(); + + SmallVector newIndices; + getXferIndices( + builder, op, *coords, {laneId, logicalValueId}, newIndices); + Value el = builder.create(op.getLoc(), loadedElType, + op.getSource(), newIndices); + result = builder.create( + op.getLoc(), el, result, builder.getI64ArrayAttr({i, innerIdx})); + } + } + } else { + return failure(); + } + + valueMapping[op.getResult()] = result; + return success(); +} + +LogicalResult +convertTransferReadToLoads(vector::TransferReadOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + const char *fragType = inferFragType(op); + gpu::MMAMatrixType type = + gpu::MMAMatrixType::get(op.getVectorType().getShape(), + op.getVectorType().getElementType(), fragType); + + bool isLdMatrixCompatible = true; + if (op.getSource().getType().cast().getMemorySpaceAsInt() != 3) { + isLdMatrixCompatible = false; + } + if (gpu::NvvmMmaOperandBaseTileOperand8x128::inferTileWidthInBits( + type.getElementType(), type.isAccOrResult()) != 128) { + isLdMatrixCompatible = false; + } + if (!op.getPermutationMap().isMinorIdentity() && + (type.getOperand() == "BOp") && type.getElementTypeBitWidth() < 16) { + isLdMatrixCompatible = false; + } + if ((type.getOperand() == "COp") && type.getElementTypeBitWidth() < 16) { + isLdMatrixCompatible = false; + } + + if (!isLdMatrixCompatible) + return createNonLdMatrixLoads(op, b, type, valueMapping); + + return creatLdMatrixCompatibleLoads(op, b, type, valueMapping); +} + +LogicalResult +convertTransferWriteToStores(vector::TransferWriteOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Location loc = op->getLoc(); + Value matrix = valueMapping.find(op.getVector())->second; + + gpu::MMAMatrixType matType = + gpu::MMAMatrixType::get(op.getVectorType().getShape(), + op.getVectorType().getElementType(), "COp"); + FailureOr + regInfo = + gpu::NvvmMmaOperandBaseTileOperand8x128::getRegisterType(matType); + if (failed(regInfo)) + return failure(); + + VectorType vectorType = getMmaSyncVectorOperandType(matType, *regInfo); + Value laneId = b.create(loc); + + for (unsigned i = 0; i < vectorType.getShape()[0]; i++) { + Value logicalValueId = b.create( + loc, b.getIndexType(), + b.getIndexAttr(i * regInfo->elementsPerRegister)); + FailureOr coords = gpu::NvvmMmaOperandBaseTileOperand8x128:: + getLaneIdAndValueIdToOperandCoord(op.getLoc(), b, matType); + if (failed(coords)) + return failure(); + + Value el = b.create(loc, matrix, ArrayRef{i}); + SmallVector newIndices; + getXferIndices( + b, op, *coords, {laneId, logicalValueId}, newIndices); + b.create(loc, el, op.getSource(), newIndices); + } + op->erase(); + return success(); +} + static void convertContractOp(vector::ContractionOp op, llvm::DenseMap &valueMapping) { OpBuilder b(op); @@ -397,6 +705,22 @@ valueMapping[op.getResult()] = matmul; } +LogicalResult +convertContractOpToMmaSync(vector::ContractionOp op, + llvm::DenseMap &valueMapping) { + OpBuilder b(op); + Value opA = valueMapping.find(op.getLhs())->second; + Value opB = valueMapping.find(op.getRhs())->second; + Value opC = valueMapping.find(op.getAcc())->second; + int64_t m = 
op.getLhs().getType().cast().getShape()[0]; + int64_t n = op.getRhs().getType().cast().getShape()[0]; + int64_t k = op.getLhs().getType().cast().getShape()[1]; + Value matmul = b.create(op.getLoc(), opC.getType(), opA, opB, + opC, b.getI64ArrayAttr({m, n, k})); + valueMapping[op.getResult()] = matmul; + return success(); +} + /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op. static void convertConstantOp(arith::ConstantOp op, llvm::DenseMap &valueMapping) { @@ -509,13 +833,19 @@ valueMapping[op->getResult(0)] = newOp; } -void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns) { - patterns.add( +void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, + bool useWmma) { + if (useWmma) { + patterns.add( + patterns.getContext()); + return; + } + patterns.add( patterns.getContext()); } -void mlir::convertVectorToMMAOps(Operation *rootOp) { - SetVector ops = getOpToConvert(rootOp); +void convertVectorToMMAOps(Operation *rootOp) { + SetVector ops = getOpToConvert(rootOp, /*useWmma=*/true); llvm::DenseMap valueMapping; for (Operation *op : ops) { if (auto transferRead = dyn_cast(op)) { @@ -538,21 +868,56 @@ } } +LogicalResult convertVectorToNVVMCompatibleMMASync(Operation *rootOp) { + SetVector ops = getOpToConvert(rootOp, /*useWmma=*/false); + llvm::DenseMap valueMapping; + for (Operation *op : ops) { + if (llvm::TypeSwitch(op) + .Case([&](vector::TransferReadOp transferReadOp) { + return convertTransferReadToLoads(transferReadOp, valueMapping); + }) + .Case([&](vector::TransferWriteOp transferWriteOp) { + return convertTransferWriteToStores(transferWriteOp, + valueMapping); + }) + .Case([&](vector::ContractionOp contractionOp) { + return convertContractOpToMmaSync(contractionOp, valueMapping); + }) + .Case([&](arith::ConstantOp constantOp) { return success(); }) + .Default([&](Operation *op) { return failure(); }) + .failed()) { + return failure(); + } + } + return success(); +} + namespace { struct ConvertVectorToGPUPass : public ConvertVectorToGPUBase { + + explicit ConvertVectorToGPUPass(bool _useWmma) { useWmma.setValue(_useWmma); } + void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populatePrepareVectorToMMAPatterns(patterns); + populatePrepareVectorToMMAPatterns(patterns, useWmma.getValue()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - convertVectorToMMAOps(getOperation()); + if (useWmma.getValue()) { + convertVectorToMMAOps(getOperation()); + return; + } + + if (failed(convertVectorToNVVMCompatibleMMASync(getOperation()))) + return signalPassFailure(); } }; } // namespace -std::unique_ptr mlir::createConvertVectorToGPUPass() { - return std::make_unique(); +std::unique_ptr createConvertVectorToGPUPass(bool useWmma) { + return std::make_unique(useWmma); } + +} // namespace mlir diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -60,7 +60,8 @@ StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32(); + return elementType.isF16() || elementType.isF32() || elementType.isF64() || + elementType.isInteger(8) || elementType.isInteger(32); } LogicalResult diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ 
b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -6,7 +6,8 @@ // CHECK32-LABEL: func @gpu_index_ops() func.func @gpu_index_ops() -> (index, index, index, index, index, index, - index, index, index, index, index, index) { + index, index, index, index, index, index, + index) { // CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64 // CHECK: = nvvm.read.ptx.sreg.tid.x : i32 @@ -49,10 +50,17 @@ // CHECK: = llvm.sext %{{.*}} : i32 to i64 %gDimZ = gpu.grid_dim z + + // CHECK: = nvvm.read.ptx.sreg.laneid : i32 + // CHECK: = llvm.sext %{{.*}} : i32 to i64 + %laneId = gpu.lane_id : index + func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ, - %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ + %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ, + %laneId : index, index, index, index, index, index, - index, index, index, index, index, index + index, index, index, index, index, index, + index } } diff --git a/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s + +gpu.module @test_module { + // CHECK-LABEL: @m16n8k16_fp16 + func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { + // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>> + + // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>> + + // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>> + // CHECK-NOT llvm.extractvalue + // CHECK: [[d:%.+]] = nvvm.mma.sync + // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32} + %d = "gpu.mma.sync"(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> + // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>> + // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>> + return %d : vector<2x2xf16> + } + + // CHECK-LABEL: @m16n8k32_int8 + func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> { + + // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + + // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x 
vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>> + // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 + + // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>> + // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>> + + // CHECK: [[d:%.+]] = nvvm.mma.sync + // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow + // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type + // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type + // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32} + %d = "gpu.mma.sync"(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> + + // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>> + return %d : vector<2x2xi32> + } + + // CHECK-LABEL: @m8n8k4_f64 + func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> { + // CHECK: llvm.extractvalue %arg0 + // CHECK: llvm.extractvalue %arg1 + // CHECK: llvm.extractvalue %arg2 + + // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}] + // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32} + %d = "gpu.mma.sync"(%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> + // CHECK: llvm.mlir.undef : vector<2xf64> + // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)> + // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)> + // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64> + // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>> + // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>> + return %d : vector<1x2xf64> + } + + func @ldmatrix(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { + %c0 = arith.constant 0 : index + // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32) + %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> + // CHECK: llvm.extractvalue + // CHECK: llvm.bitcast + // CHECK: llvm.insertvalue + // CHECK: llvm.extractvalue + // CHECK: llvm.bitcast + // CHECK: llvm.insertvalue + // CHECK: llvm.extractvalue + // CHECK: llvm.bitcast + // CHECK: llvm.insertvalue + // CHECK: llvm.extractvalue + // CHECK: llvm.bitcast + // CHECK: llvm.insertvalue + return %a : vector<4x2xf16> + } +} diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops-mma-sync.mlir @@ -0,0 +1,286 @@ +// RUN: mlir-opt %s -split-input-file -pass-pipeline="func.func(convert-vector-to-gpu{use-wmma=false})" | FileCheck %s + +//######################################################### +// INT8 row-row-row +//######################################################### + +// CHECK-DAG: [[rowA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)> +// CHECK-DAG: [[colA0_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 16 + 1)> + +// CHECK-DAG: [[rowB0_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 39)> +// CHECK-DAG: [[colB0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)> +// CHECK-DAG: [[rowB1_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 
16 + 40)> +// CHECK-DAG: [[rowB2_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 41)> +// CHECK-DAG: [[rowB3_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 42)> +// CHECK-DAG: [[rowB4_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 55)> +// CHECK-DAG: [[rowB5_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 56)> +// CHECK-DAG: [[rowB6_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 57)> +// CHECK-DAG: [[rowB7_map:#.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16 + 58)> + +// CHECK-DAG: [[rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)> +// CHECK-DAG: [[colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)> +// CHECK-DAG: [[rowC8_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 57)> + + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map3 = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK: func @m16n8k32_int8_row_row_row +func @m16n8k32_int8_row_row_row(%arg0: memref<128x128xi8, 3>, %arg1: memref<128x128xi8>, %arg2: memref<128x128xi32>) { + %cst_0 = arith.constant dense<0> : vector<32x8xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c17 = arith.constant 17 : index + %c39 = arith.constant 39 : index + %c40 = arith.constant 40 : index + %c49 = arith.constant 49 : index + %c50 = arith.constant 50 : index + %cst = arith.constant 0 : i8 + %cst0 = arith.constant 0 : i32 + + // Verify that the operand A is distributed to loads correctly. + + // CHECK: [[row:%.+]] = affine.apply [[rowA0_map]]()[{{%.+}}] + // CHECK: [[col:%.+]] = affine.apply [[colA0_map]]()[{{%.+}}] + // CHECK: gpu.mma.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<128x128xi8, 3> -> vector<4x4xi8> + + // Verify that the operand B is distributed to loads correctly. It's elements + // must be loaded in a non-vectorized manner to do the transpose. 
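+  // (the B fragment is vector<2x4xi8>, i.e. 2 registers of 4 i8 elements
+  // each, so 8 scalar loads are expected below)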
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB0_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB1_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB2_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB3_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB4_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB5_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB6_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB7_map]]()[{{%.+}}]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB0_map]]()[{{%.+}}]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xi8>
+ // CHECK-NOT: memref.load %arg1
+
+ // Verify that the operand C is distributed to loads correctly.
+ // CHECK: [[row:%.+]] = affine.apply [[rowC0_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK: [[row:%.+]] = affine.apply [[rowC8_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK-NOT: vector.load %arg2{{.*}}
+
+ %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xi8, 3>, vector<16x32xi8>
+ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xi8>, vector<8x32xi8>
+ %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xi32>, vector<16x8xi32>
+ // CHECK: [[d:%.+]] = gpu.mma.sync({{.*}}) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x32xi8>, vector<8x32xi8> into vector<16x8xi32>
+
+ // CHECK: [[row:%.+]] = affine.apply [[rowC0_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ // CHECK: [[row:%.+]] = affine.apply [[rowC8_map]]()[{{%.+}}]
+ // CHECK: [[col:%.+]] = affine.apply [[colC0_map]]()[{{%.+}}]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xi32>, vector<2xi32>
+ vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<16x8xi32>, memref<128x128xi32>
+ return
+}
+
+// -----
+
+//#################################################################
+// f64 row-row-row
+//#################################################################
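+// For the 8x8x4 f64 tile each lane holds one element of A, one element of B,
+// and two elements of C (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>
+// below), so A and C use vector.load/vector.store while the transposed B
+// operand falls back to a scalar memref.load.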
+// CHECK-DAG: [[rowA0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 1)>
+// CHECK-DAG: [[colA0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 1)>
+
+// CHECK-DAG: [[rowb0_map:#.+]] = affine_map<()[s0] -> (s0 mod 4 + 39)>
+// CHECK-DAG: [[colb0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 40)>
+
+// CHECK-DAG: [[rowC0_map:#.+]] = affine_map<()[s0] -> (s0 floordiv 4 + 49)>
+// CHECK-DAG: [[colC0_map:#.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 4) * 8 + 40)
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @m8n8k4_f64_row_row_row
+func @m8n8k4_f64_row_row_row(%arg0: memref<128x128xf64>, %arg1: memref<128x128xf64>, %arg2: memref<128x128xf64>) {
+ %cst_0 = arith.constant dense<0.0> : vector<4x8xf64>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c17 = arith.constant 17 : index
+ %c39 = arith.constant 39 : index
+ %c40 = arith.constant 40 : index
+ %c49 = arith.constant 49 : index
+ %c50 = arith.constant 50 : index
+ %cst = arith.constant 0.0 : f64
+ %cst0 = arith.constant 0.0 : f64
+
+ // Verify that the operand A is distributed to loads correctly.
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colA0_map]]
+ // CHECK: vector.load %arg0[[[row]], [[col]]] : memref<128x128xf64>, vector<1xf64>
+
+ // Verify that the operand B is distributed to loads correctly. Its elements
+ // must be loaded in a non-vectorized manner to do the transpose.
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowb0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colb0_map]]
+ // CHECK: memref.load %arg1[[[row]], [[col]]] : memref<128x128xf64>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowC0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colC0_map]]
+ // CHECK: vector.load %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+
+ %A = vector.transfer_read %arg0[%c1, %c1], %cst {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x4xf64>
+ %B = vector.transfer_read %arg1[%c39, %c40], %cst {in_bounds = [true, true], permutation_map = #map0} : memref<128x128xf64>, vector<8x4xf64>
+ %C = vector.transfer_read %arg2[%c49, %c40], %cst0 {in_bounds = [true, true]} : memref<128x128xf64>, vector<8x8xf64>
+ // CHECK: [[d:%.+]] = gpu.mma.sync({{.*}}) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<8x4xf64>, vector<8x4xf64> into vector<8x8xf64>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowC0_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colC0_map]]
+ // CHECK: vector.store {{%.+}}, %arg2[[[row]], [[col]]] : memref<128x128xf64>, vector<2xf64>
+ vector.transfer_write %D, %arg2[%c49, %c40] {in_bounds = [true, true]} : vector<8x8xf64>, memref<128x128xf64>
+ return
+}
+
+// -----
+
+//#################################################################
+// FP16 row-row-row
+//#################################################################
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
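+// Verify that for the f16 row-row case both A and B are loaded with
+// gpu.mma.ldmatrix (4 and 2 tiles respectively); since B is read through a
+// transposing permutation map, its ldmatrix uses transpose = true.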
+// CHECK-DAG: [[rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK-DAG: [[rowB_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+// CHECK-DAG: [[colB_map:#.+]] = affine_map<() -> (3)>
+
+// CHECK: func @m16n8k16_fp16_row_row_row
+func @m16n8k16_fp16_row_row_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]]
+ // CHECK: gpu.mma.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32, transpose = false}
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB_map]]
+ // CHECK: gpu.mma.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32, transpose = true}
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+ return
+}
+
+// -----
+
+// CHECK-DAG: [[Arow_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK-DAG: [[Acol_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+// CHECK-DAG: [[Bcol_map:#.+]] = affine_map<() -> (3)>
+// CHECK-DAG: [[Brow_map:#.+]] = affine_map<()[s0] -> (s0 + 3)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: func @batch_m16n8k16_fp16_row_row_row
+func @batch_m16n8k16_fp16_row_row_row(%arg0: memref<2x20x20xf16, 3>, %arg1: memref<2x20x20xf16, 3>, %arg2: memref<2x20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<20x20xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[Arow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[Acol_map]]
+ // CHECK: gpu.mma.ldmatrix %arg0[%c0, [[row]], [[col]]] {numTiles = 4 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<4x2xf16>
+ %A = vector.transfer_read %arg0[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x16xf16>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[Brow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[Bcol_map]]
+ // CHECK: gpu.mma.ldmatrix %arg1[%c0, [[row]], [[col]]] {numTiles = 2 : i32, transpose = true} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+ %B = vector.transfer_read %arg1[%c0, %c3, %c3], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<8x16xf16>
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[Arow_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[Acol_map]]
+ // CHECK: gpu.mma.ldmatrix %arg2[%c0, [[row]], [[col]]] {numTiles = 2 : i32, transpose = false} : memref<2x20x20xf16, 3> -> vector<2x2xf16>
+ %C = vector.transfer_read %arg2[%c0, %c1, %c3], %cst {in_bounds = [true, true]} : memref<2x20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c1, %c3] {in_bounds = [true, true]} : vector<16x8xf16>, memref<2x20x20xf16, 3>
+ return
+}
+
+// -----
+
+//#################################################################
+// FP16 row-col-row
+//#################################################################
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK: [[rowA_map:#.+]] = affine_map<()[s0] -> (s0 mod 16 + 1)>
+// CHECK: [[colA_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 16) * 8 + 3)>
+
+// CHECK: [[rowB_map:#.+]] = affine_map<()[s0] -> (s0 mod 8 + 1)>
+// CHECK: [[colB_map:#.+]] = affine_map<()[s0] -> ((s0 floordiv 8) * 8 + 3)>
+
+// CHECK: func @m16n8k16_fp16_row_col_row
+func @m16n8k16_fp16_row_col_row(%arg0: memref<20x20xf16, 3>, %arg1: memref<20x20xf16, 3>, %arg2: memref<20x20xf16, 3>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x8xf16>
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c3 = arith.constant 3 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]]
+ // CHECK: gpu.mma.ldmatrix %arg0[[[row]], [[col]]] {numTiles = 4 : i32
+ // CHECK-SAME: transpose = false
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowB_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colB_map]]
+ // CHECK: gpu.mma.ldmatrix %arg1[[[row]], [[col]]] {numTiles = 2 : i32
+ // CHECK-SAME: transpose = false
+
+ // CHECK-DAG: [[row:%.+]] = affine.apply [[rowA_map]]
+ // CHECK-DAG: [[col:%.+]] = affine.apply [[colA_map]]
+ // CHECK: gpu.mma.ldmatrix %arg2[[[row]], [[col]]] {numTiles = 2 : i32
+ // CHECK-SAME: transpose = false
+ %A = vector.transfer_read %arg0[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<8x16xf16>
+ %C = vector.transfer_read %arg2[%c1, %c3], %cst {in_bounds = [true, true]} : memref<20x20xf16, 3>, vector<16x8xf16>
+ %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<8x16xf16> into vector<16x8xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x8xf16>, memref<20x20xf16, 3>
+ return
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -476,16 +476,6 @@
 // -----
 
-func @mmamatrix_invalid_element_type(){
-  %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
-  %i = arith.constant 16 : index
-  // expected-error @+1 {{MMAMatrixType elements must be F16 or F32}}
-  %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xi32, "AOp">
-  return
-}
-
-// -----
-
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
 func @mmaLoadOp_identity_layout(){