Index: mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
@@ -0,0 +1,26 @@
+//===- NVGPUToNVVMPass.h - Convert NVGPU to NVVM dialect --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
+#define MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+
+std::unique_ptr<Pass> createConvertNVGPUToNVVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
+#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
 #include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
 #include "mlir/Conversion/OpenACCToSCF/ConvertOpenACCToSCF.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -506,6 +506,22 @@
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPUToNVVM
+//===----------------------------------------------------------------------===//
+
+def ConvertNVGPUToNVVM : Pass<"convert-nvgpu-to-nvvm"> {
+  let summary = "Convert NVGPU dialect to NVVM dialect";
+  let description = [{
+    This pass converts supported NVGPU ops to NVVM dialect intrinsics.
+  }];
+  let constructor = "mlir::createConvertNVGPUToNVVMPass()";
+  let dependentDialects = [
+    "NVVM::NVVMDialect",
+  ];
+}
+
+
 //===----------------------------------------------------------------------===//
 // OpenACCToSCF
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Dialect/GPU/GPUOps.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1369,58 +1369,4 @@
   }];
 }
 
-def GPU_MmaLdMatrixOp : GPU_Op<"mma.ldmatrix",
-                               [MemoryEffects<[MemRead]>]> {
-  let description = [{
-    The `gpu.mma.ldmatrix` op represents loading a matrix fragment from
-    memory. The load source and result type must be compatible with lowering
-    to the `nvvm.ldmatrix` instruction. This op is meant to represent
-    the distributed version of a `vector.transfer_read` as an intermediate
-    step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
-
-    Example:
-
-    ```mlir
-      gpu.mma.ldmatrix %shm_buffer[%c0, %c0] : memref<16x16xf16, 3> -> vector<4x2xf16>
-    ```
-  }];
-
-  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
-                   Variadic<Index>:$indices, BoolAttr:$transpose,
-                   I32Attr:$numTiles);
-  let results = (outs AnyVector:$res);
-  let assemblyFormat = [{
-    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
-  }];
-}
-
-def GPU_MmaSyncOp : GPU_Op<"mma.sync", [NoSideEffect]> {
-  let description = [{
-    The `gpu.mma.sync` op represents the distributed form of a collective
-    matrix-multiply-and-accumulate (mma) operation that is compatible with
-    `nvvm.mma.sync`. The operands and results are fragments of the full matrix
-    operands. The full shape of the distributed mma operation is given by the
-    `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
-
-    This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
-    is an intermediate point between lowering from `vector.contract` to
-    `nvvm.mma.sync`.
-
-    Example:
-
-    ```mlir
-      gpu.mma.sync (%a, %b, %c) : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
-    ```
-  }];
-  let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, AnyVector:$matrixC,
-                   I64ArrayAttr:$mmaShape);
-
-  let results = (outs AnyVector:$res);
-
-  let assemblyFormat = [{
-    `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
-    `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
-  }];
-}
-
 #endif // GPU_OPS
Index: mlir/include/mlir/Dialect/NVGPU/NVGPU.td
===================================================================
--- mlir/include/mlir/Dialect/NVGPU/NVGPU.td
+++ mlir/include/mlir/Dialect/NVGPU/NVGPU.td
@@ -69,4 +69,37 @@
   }];
 }
 
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+  let description = [{
+    The `nvgpu.mma.sync` op represents the distributed form of a collective
+    matrix-multiply-and-accumulate (mma) operation that is compatible with
+    `nvvm.mma.sync`. The operands and results are fragments of the full matrix
+    operands. The full shape of the distributed mma operation is given by the
+    `mmaShape` attribute in the form of a list of dimensions `[m, n, k]`.
+
+    This operation is meant to be lowered to the `nvvm.mma.sync` instruction, and
+    is an intermediate point between lowering from `vector.contract` to
+    `nvvm.mma.sync`.
+ + This operation is meant to follow the semantic of described here: + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma + + Example: + + ```mlir + nvgpu.mma.sync (%a, %b, %c) : + (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + ``` + }]; + let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB, + AnyVector:$matrixC, I64ArrayAttr:$mmaShape); + + let results = (outs AnyVector:$res); + + let assemblyFormat = [{ + `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict + `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res) + }]; +} + #endif // NVGPU Index: mlir/lib/Conversion/CMakeLists.txt =================================================================== --- mlir/lib/Conversion/CMakeLists.txt +++ mlir/lib/Conversion/CMakeLists.txt @@ -24,6 +24,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) +add_subdirectory(NVGPUToNVVM) add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -208,290 +208,6 @@ } }; -struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::MmaLdMatrixOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = getContext(); - Location loc = op->getLoc(); - - // The result type of ldmatrix will always be a struct of 32bit integer - // registers if more than one 32bit value is returned. Otherwise, the result - // is a single i32. The result type of the GPU operation is always a vector - // of shape (NumRegisters, VectorRegister) where VectorRegister is the - // vector type of the result and always 32 bits long. We bitcast the result - // of the NVVM::LdMatrix to this vector type. - auto vectorResultType = op->getResultTypes()[0].dyn_cast(); - if (!vectorResultType) { - return failure(); - } - Type innerVectorType = LLVM::getFixedVectorType( - vectorResultType.getElementType(), vectorResultType.getDimSize(1)); - - int64_t num32BitRegs = vectorResultType.getDimSize(0); - - Type ldMatrixResultType; - if (num32BitRegs > 1) { - ldMatrixResultType = LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(num32BitRegs, rewriter.getI32Type())); - } else { - ldMatrixResultType = rewriter.getI32Type(); - } - - auto srcMemrefType = op.srcMemref().getType().cast(); - Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(), - adaptor.indices(), rewriter); - Value ldMatrixResult = rewriter.create( - loc, ldMatrixResultType, srcPtr, - /*num=*/op.numTiles(), - /*layout=*/op.transpose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); - - // The ldmatrix operation returns either a single i32 value or a struct of - // i32 values. Here we unpack those values and cast them back to their - // actual vector type (still of width 32b) and repack them into a result - // struct. - Type finalResultType = typeConverter->convertType(vectorResultType); - Value result = rewriter.create(loc, finalResultType); - for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { - Value i32Register = num32BitRegs > 1 - ? 
rewriter.create( - loc, rewriter.getI32Type(), ldMatrixResult, - rewriter.getI64ArrayAttr(i)) - : ldMatrixResult; - Value casted = - rewriter.create(loc, innerVectorType, i32Register); - result = rewriter.create( - loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i)); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// Checks if all the operands of the op being lowered are of LLVM Types. The -/// types are expected to be converted by the `LLVMTypeConverter` before the -/// op is actually lowered. If the type of an operands is not already -/// converted it hints a missing typeConversion and failure is returned in -/// that case. -LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, - ConversionPatternRewriter &rewriter) { - if (!llvm::all_of(operands, [](Value value) { - return LLVM::isCompatibleType(value.getType()); - })) { - return rewriter.notifyMatchFailure( - op, "cannot convert if operands aren't of LLVM type."); - } - - return success(); -} - -/// Returns the type for the intrinsic given the vectorResultType of the -/// `gpu.mma.sync` operation. -Type inferIntrinsicResultType(Type vectorResultType) { - MLIRContext *ctx = vectorResultType.getContext(); - auto a = vectorResultType.cast(); - auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); - auto i32Ty = IntegerType::get(ctx, 32); - auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); - Type f64Ty = Float64Type::get(ctx); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); - if (a.getElementType() == f16x2Ty) { - return LLVM::LLVMStructType::getLiteral( - ctx, SmallVector(a.getNumElements(), f16x2Ty)); - } - if (a.getElementType() == i32x2Ty) { - return LLVM::LLVMStructType::getLiteral( - ctx, - SmallVector(static_cast(a.getNumElements()) * 2, i32Ty)); - } - if (a.getElementType() == f64x2Ty) { - return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); - } - return vectorResultType; -} - -/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is -/// always an LLVM struct) into a fragment that is compatible with the vector -/// type of this operation. This involves extracting elements from the struct -/// and inserting them into an LLVM array. These extra data-movement -/// operations should be canonicalized away by the LLVM backend. -Value convertIntrinsicResult(Location loc, Type intrinsicResultType, - Type resultType, Value intrinsicResult, - RewriterBase &rewriter) { - MLIRContext *ctx = rewriter.getContext(); - auto structType = intrinsicResultType.dyn_cast(); - auto arrayType = resultType.dyn_cast(); - Type i32Ty = rewriter.getI32Type(); - Type f64Ty = rewriter.getF64Type(); - Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); - Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); - Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); - - auto makeConst = [&](int32_t index) -> Value { - return rewriter.create(loc, IntegerType::get(ctx, 32), - rewriter.getI32IntegerAttr(index)); - }; - - if (arrayType) { - SmallVector elements; - - if (arrayType.getElementType() == f16x2Ty) { - for (unsigned i = 0; i < structType.getBody().size(); i++) { - elements.push_back(rewriter.create( - loc, structType.getBody()[i], intrinsicResult, - rewriter.getI64ArrayAttr(i))); - } - } - - // The intrinsic returns i32 and f64 values as individual scalars. We need - // to extract them from the struct and pack them into vectors. 
- if (arrayType.getElementType() == i32x2Ty || - arrayType.getElementType() == f64x2Ty) { - Value vec = - rewriter.create(loc, arrayType.getElementType()); - for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { - Value x1 = rewriter.create( - loc, structType.getBody()[i * 2], intrinsicResult, - rewriter.getI64ArrayAttr(i * 2)); - Value x2 = rewriter.create( - loc, structType.getBody()[i * 2 + 1], intrinsicResult, - rewriter.getI64ArrayAttr(i * 2 + 1)); - vec = rewriter.create(loc, vec.getType(), vec, - x1, makeConst(0)); - vec = rewriter.create(loc, vec.getType(), vec, - x2, makeConst(1)); - } - elements.push_back(vec); - } - - // Create the final vectorized result. - Value result = rewriter.create(loc, arrayType); - for (const auto &el : llvm::enumerate(elements)) { - result = rewriter.create( - loc, arrayType, result, el.value(), - rewriter.getI64ArrayAttr(el.index())); - } - return result; - } - - return intrinsicResult; -} - -/// The `gpu.mma.sync` converter below expects matrix fragment operands to be -/// given as 2D `vectors` where the rows are 32b or 64b wide. The -/// `nvvm.mma.sync` op expects these argments to be a given in a long list of -/// scalars of certain types. This function helps unpack the `vector` arguments -/// and cast them to the types expected by `nvvm.mma.sync`. -SmallVector unpackOperandVector(RewriterBase &rewriter, Location loc, - Value operand) { - SmallVector result; - Type i32Ty = rewriter.getI32Type(); - Type f64Ty = rewriter.getF64Type(); - Type i8Ty = rewriter.getI8Type(); - Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); - auto arrayTy = operand.getType().cast(); - - for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { - Value toUse = rewriter.create( - loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i)); - - // For 4xi8 vectors, the intrinsic expects these to be provided as i32 - // scalar types. - if (arrayTy.getElementType() == i8x4Ty) { - result.push_back( - rewriter.create(loc, rewriter.getI32Type(), toUse)); - continue; - } - - // For some element types (i32, f64), we need to unpack the inner - // vector/array type as well because the intrinsic expects individual - // scalars to be provided. - VectorType innerArrayTy = arrayTy.getElementType().dyn_cast(); - if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || - innerArrayTy.getElementType() == f64Ty)) { - for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); - idx < innerSize; idx++) { - result.push_back(rewriter.create( - loc, toUse, - rewriter.create( - loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx)))); - } - continue; - } - result.push_back(toUse); - } - return result; -} - -struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(gpu::MmaSyncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) { - return failure(); - } - - // Get the shapes of the MMAMatrix type being used. The shapes will - // choose which intrinsic this op will be lowered to. 
- auto aType = op.matrixA().getType().cast(); - - int64_t m = op.mmaShape()[0].cast().getInt(); - int64_t n = op.mmaShape()[1].cast().getInt(); - int64_t k = op.mmaShape()[2].cast().getInt(); - std::array gemmShape{m, n, k}; - - SmallVector matA = - unpackOperandVector(rewriter, loc, adaptor.matrixA()); - SmallVector matB = - unpackOperandVector(rewriter, loc, adaptor.matrixB()); - SmallVector matC = - unpackOperandVector(rewriter, loc, adaptor.matrixC()); - - NVVM::MMATypes ptxTypeA; - NVVM::MMATypes ptxTypeB; - Optional overflow(llvm::None); - if (aType.getElementType().isInteger(8)) { - ptxTypeA = NVVM::MMATypes::s8; - ptxTypeB = NVVM::MMATypes::s8; - overflow = NVVM::MMAIntOverflow::satfinite; - - } else if (aType.getElementType().isF16()) { - ptxTypeA = NVVM::MMATypes::f16; - ptxTypeB = NVVM::MMATypes::f16; - } else if (aType.getElementType().isF64()) { - ptxTypeA = NVVM::MMATypes::f64; - ptxTypeB = NVVM::MMATypes::f64; - } else { - return op->emitError("could not deduce operand PTX types"); - } - - Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); - Type intrinsicResTy = inferIntrinsicResultType( - typeConverter->convertType(op->getResultTypes()[0])); - Value intrinsicResult = rewriter.create( - op.getLoc(), intrinsicResTy, matA, matB, matC, - /*shape=*/gemmShape, - /*b1Op=*/llvm::None, - /*intOverflow=*/overflow, - /*multiplicandPtxTypes=*/ - std::array{ptxTypeA, ptxTypeB}, - /*multiplicandLayouts=*/ - std::array{NVVM::MMALayout::row, - NVVM::MMALayout::col}); - rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, - desiredRetTy, intrinsicResult, - rewriter)); - return success(); - } -}; - struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -611,8 +327,8 @@ NVVM::BlockIdYOp, NVVM::BlockIdZOp>, GPUIndexIntrinsicOpLowering, - GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering, - MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM>(converter); + GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>( + converter); // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default Index: mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Conversion/NVGPUToNVVM/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(MLIRNVGPUToNVVM + NVGPUToNVVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/NVGPUToNVVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIRLLVMIR + MLIRNVVMIR + MLIRNVGPU + MLIRPass + MLIRTransforms + ) Index: mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp =================================================================== --- /dev/null +++ mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -0,0 +1,308 @@ +//===- NVGPUToNVVM.cpp - NVGPU to NVVM dialect conversion -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h" +#include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/NVGPU/NVGPUDialect.h" + +using namespace mlir; + +/// Returns the type for the intrinsic given the vectorResultType of the +/// `gpu.mma.sync` operation. +static Type inferIntrinsicResultType(Type vectorResultType) { + MLIRContext *ctx = vectorResultType.getContext(); + auto a = vectorResultType.cast(); + auto f16x2Ty = LLVM::getFixedVectorType(Float16Type::get(ctx), 2); + auto i32Ty = IntegerType::get(ctx, 32); + auto i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + Type f64Ty = Float64Type::get(ctx); + Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + if (a.getElementType() == f16x2Ty) { + return LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(a.getNumElements(), f16x2Ty)); + } + if (a.getElementType() == i32x2Ty) { + return LLVM::LLVMStructType::getLiteral( + ctx, + SmallVector(static_cast(a.getNumElements()) * 2, i32Ty)); + } + if (a.getElementType() == f64x2Ty) { + return LLVM::LLVMStructType::getLiteral(ctx, {f64Ty, f64Ty}); + } + return vectorResultType; +} + +/// Convert the SSA result of the NVVM intrinsic `nvvm.mma.sync` (which is +/// always an LLVM struct) into a fragment that is compatible with the vector +/// type of this operation. This involves extracting elements from the struct +/// and inserting them into an LLVM array. These extra data-movement +/// operations should be canonicalized away by the LLVM backend. +static Value convertIntrinsicResult(Location loc, Type intrinsicResultType, + Type resultType, Value intrinsicResult, + RewriterBase &rewriter) { + MLIRContext *ctx = rewriter.getContext(); + auto structType = intrinsicResultType.dyn_cast(); + auto arrayType = resultType.dyn_cast(); + Type i32Ty = rewriter.getI32Type(); + Type f64Ty = rewriter.getF64Type(); + Type f16x2Ty = LLVM::getFixedVectorType(rewriter.getF16Type(), 2); + Type i32x2Ty = LLVM::getFixedVectorType(i32Ty, 2); + Type f64x2Ty = LLVM::getFixedVectorType(f64Ty, 2); + + auto makeConst = [&](int32_t index) -> Value { + return rewriter.create(loc, IntegerType::get(ctx, 32), + rewriter.getI32IntegerAttr(index)); + }; + + if (arrayType) { + SmallVector elements; + + if (arrayType.getElementType() == f16x2Ty) { + for (unsigned i = 0; i < structType.getBody().size(); i++) { + elements.push_back(rewriter.create( + loc, structType.getBody()[i], intrinsicResult, + rewriter.getI64ArrayAttr(i))); + } + } + + // The intrinsic returns i32 and f64 values as individual scalars. We need + // to extract them from the struct and pack them into vectors. 
+ if (arrayType.getElementType() == i32x2Ty || + arrayType.getElementType() == f64x2Ty) { + Value vec = + rewriter.create(loc, arrayType.getElementType()); + for (unsigned i = 0, e = structType.getBody().size() / 2; i < e; i++) { + Value x1 = rewriter.create( + loc, structType.getBody()[i * 2], intrinsicResult, + rewriter.getI64ArrayAttr(i * 2)); + Value x2 = rewriter.create( + loc, structType.getBody()[i * 2 + 1], intrinsicResult, + rewriter.getI64ArrayAttr(i * 2 + 1)); + vec = rewriter.create(loc, vec.getType(), vec, + x1, makeConst(0)); + vec = rewriter.create(loc, vec.getType(), vec, + x2, makeConst(1)); + } + elements.push_back(vec); + } + + // Create the final vectorized result. + Value result = rewriter.create(loc, arrayType); + for (const auto &el : llvm::enumerate(elements)) { + result = rewriter.create( + loc, arrayType, result, el.value(), + rewriter.getI64ArrayAttr(el.index())); + } + return result; + } + + return intrinsicResult; +} + +/// The `gpu.mma.sync` converter below expects matrix fragment operands to be +/// given as 2D `vectors` where the rows are 32b or 64b wide. The +/// `nvvm.mma.sync` op expects these argments to be a given in a long list of +/// scalars of certain types. This function helps unpack the `vector` arguments +/// and cast them to the types expected by `nvvm.mma.sync`. +static SmallVector unpackOperandVector(RewriterBase &rewriter, + Location loc, Value operand) { + SmallVector result; + Type i32Ty = rewriter.getI32Type(); + Type f64Ty = rewriter.getF64Type(); + Type i8Ty = rewriter.getI8Type(); + Type i8x4Ty = LLVM::getFixedVectorType(i8Ty, 4); + auto arrayTy = operand.getType().cast(); + + for (unsigned i = 0, e = arrayTy.getNumElements(); i < e; ++i) { + Value toUse = rewriter.create( + loc, arrayTy.getElementType(), operand, rewriter.getI64ArrayAttr(i)); + + // For 4xi8 vectors, the intrinsic expects these to be provided as i32 + // scalar types. + if (arrayTy.getElementType() == i8x4Ty) { + result.push_back( + rewriter.create(loc, rewriter.getI32Type(), toUse)); + continue; + } + + // For some element types (i32, f64), we need to unpack the inner + // vector/array type as well because the intrinsic expects individual + // scalars to be provided. + VectorType innerArrayTy = arrayTy.getElementType().dyn_cast(); + if (innerArrayTy && (innerArrayTy.getElementType() == i32Ty || + innerArrayTy.getElementType() == f64Ty)) { + for (unsigned idx = 0, innerSize = innerArrayTy.getNumElements(); + idx < innerSize; idx++) { + result.push_back(rewriter.create( + loc, toUse, + rewriter.create( + loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(idx)))); + } + continue; + } + result.push_back(toUse); + } + return result; +} + +namespace { + +struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::LdMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *ctx = getContext(); + Location loc = op->getLoc(); + + // The result type of ldmatrix will always be a struct of 32bit integer + // registers if more than one 32bit value is returned. Otherwise, the result + // is a single i32. The result type of the GPU operation is always a vector + // of shape (NumRegisters, VectorRegister) where VectorRegister is the + // vector type of the result and always 32 bits long. We bitcast the result + // of the NVVM::LdMatrix to this vector type. 
+ auto vectorResultType = op->getResultTypes()[0].dyn_cast(); + if (!vectorResultType) { + return failure(); + } + Type innerVectorType = LLVM::getFixedVectorType( + vectorResultType.getElementType(), vectorResultType.getDimSize(1)); + + int64_t num32BitRegs = vectorResultType.getDimSize(0); + + Type ldMatrixResultType; + if (num32BitRegs > 1) { + ldMatrixResultType = LLVM::LLVMStructType::getLiteral( + ctx, SmallVector(num32BitRegs, rewriter.getI32Type())); + } else { + ldMatrixResultType = rewriter.getI32Type(); + } + + auto srcMemrefType = op.srcMemref().getType().cast(); + Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.srcMemref(), + adaptor.indices(), rewriter); + Value ldMatrixResult = rewriter.create( + loc, ldMatrixResultType, srcPtr, + /*num=*/op.numTiles(), + /*layout=*/op.transpose() ? NVVM::MMALayout::col + : NVVM::MMALayout::row); + + // The ldmatrix operation returns either a single i32 value or a struct of + // i32 values. Here we unpack those values and cast them back to their + // actual vector type (still of width 32b) and repack them into a result + // struct. + Type finalResultType = typeConverter->convertType(vectorResultType); + Value result = rewriter.create(loc, finalResultType); + for (int64_t i = 0, e = vectorResultType.getDimSize(0); i < e; i++) { + Value i32Register = num32BitRegs > 1 + ? rewriter.create( + loc, rewriter.getI32Type(), ldMatrixResult, + rewriter.getI64ArrayAttr(i)) + : ldMatrixResult; + Value casted = + rewriter.create(loc, innerVectorType, i32Register); + result = rewriter.create( + loc, finalResultType, result, casted, rewriter.getI64ArrayAttr(i)); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(nvgpu::MmaSyncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + // Get the shapes of the MMAMatrix type being used. The shapes will + // choose which intrinsic this op will be lowered to. 
+ auto aType = op.matrixA().getType().cast(); + + int64_t m = op.mmaShape()[0].cast().getInt(); + int64_t n = op.mmaShape()[1].cast().getInt(); + int64_t k = op.mmaShape()[2].cast().getInt(); + std::array gemmShape{m, n, k}; + + SmallVector matA = + unpackOperandVector(rewriter, loc, adaptor.matrixA()); + SmallVector matB = + unpackOperandVector(rewriter, loc, adaptor.matrixB()); + SmallVector matC = + unpackOperandVector(rewriter, loc, adaptor.matrixC()); + + NVVM::MMATypes ptxTypeA; + NVVM::MMATypes ptxTypeB; + Optional overflow(llvm::None); + if (aType.getElementType().isInteger(8)) { + ptxTypeA = NVVM::MMATypes::s8; + ptxTypeB = NVVM::MMATypes::s8; + overflow = NVVM::MMAIntOverflow::satfinite; + + } else if (aType.getElementType().isF16()) { + ptxTypeA = NVVM::MMATypes::f16; + ptxTypeB = NVVM::MMATypes::f16; + } else if (aType.getElementType().isF64()) { + ptxTypeA = NVVM::MMATypes::f64; + ptxTypeB = NVVM::MMATypes::f64; + } else { + return op->emitError("could not deduce operand PTX types"); + } + + Type desiredRetTy = typeConverter->convertType(op->getResultTypes()[0]); + Type intrinsicResTy = inferIntrinsicResultType( + typeConverter->convertType(op->getResultTypes()[0])); + Value intrinsicResult = rewriter.create( + op.getLoc(), intrinsicResTy, matA, matB, matC, + /*shape=*/gemmShape, + /*b1Op=*/llvm::None, + /*intOverflow=*/overflow, + /*multiplicandPtxTypes=*/ + std::array{ptxTypeA, ptxTypeB}, + /*multiplicandLayouts=*/ + std::array{NVVM::MMALayout::row, + NVVM::MMALayout::col}); + rewriter.replaceOp(op, convertIntrinsicResult(op.getLoc(), intrinsicResTy, + desiredRetTy, intrinsicResult, + rewriter)); + return success(); + } +}; + +struct ConvertNVGPUToNVVMPass + : public ConvertNVGPUToNVVMBase { + ConvertNVGPUToNVVMPass() = default; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + LLVMTypeConverter converter(&getContext()); + populateNVGPUToNVVMConversionPatterns(converter, patterns); + LLVMConversionTarget target(getContext()); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace +void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add(converter); +} + +std::unique_ptr mlir::createConvertNVGPUToNVVMPass() { + return std::make_unique(); +} Index: mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/GPUToNVVM/mma-sync-to-nvvm.mlir +++ /dev/null @@ -1,129 +0,0 @@ -// RUN: mlir-opt --convert-gpu-to-nvvm --split-input-file %s | FileCheck %s - -gpu.module @test_module { - // CHECK-LABEL: @m16n8k16_fp16 - func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { - // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<4 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<4 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg0[2] : !llvm.array<4 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg0[3] : !llvm.array<4 x vector<2xf16>> - - // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg1[1] : !llvm.array<2 x vector<2xf16>> - - // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>> - // CHECK-NOT 
llvm.extractvalue - // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32} - %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> - // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> - // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> - // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>> - // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>> - // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>> - return %d : vector<2x2xf16> - } - - // CHECK-LABEL: @m16n8k8_fp16 - func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { - // CHECK: llvm.extractvalue %arg0[0] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg0[1] : !llvm.array<2 x vector<2xf16>> - - // CHECK: llvm.extractvalue %arg1[0] : !llvm.array<1 x vector<2xf16>> - - // CHECK: llvm.extractvalue %arg2[0] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.extractvalue %arg2[1] : !llvm.array<2 x vector<2xf16>> - // CHECK-NOT llvm.extractvalue - // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32} - %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> - // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> - // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)> - // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>> - // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>> - // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>> - // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xf16>> - return %d : vector<2x2xf16> - } - - // CHECK-LABEL: @m16n8k32_int8 - func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> { - - // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - // CHECK: [[el:%.+]] = llvm.extractvalue %arg0[{{.*}}] : !llvm.array<4 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - - // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - // CHECK: [[el:%.+]] = llvm.extractvalue %arg1[{{.*}}] : !llvm.array<2 x vector<4xi8>> - // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32 - - // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>> - // CHECK: [[el:%.+]] = llvm.extractvalue %arg2[{{.*}}] : !llvm.array<2 x vector<2xi32>> - - // CHECK: [[d:%.+]] = nvvm.mma.sync - // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow - // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type - // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type - // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32} - %d = 
gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32> - - // CHECK: llvm.return {{%.+}} : !llvm.array<2 x vector<2xi32>> - return %d : vector<2x2xi32> - } - - // CHECK-LABEL: @m8n8k4_f64 - func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> { - // CHECK: llvm.extractvalue %arg0 - // CHECK: llvm.extractvalue %arg1 - // CHECK: llvm.extractvalue %arg2 - - // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}] - // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32} - %d = gpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64> - // CHECK: llvm.mlir.undef : vector<2xf64> - // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)> - // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)> - // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64> - // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>> - // CHECK: llvm.return {{%.+}} : !llvm.array<1 x vector<2xf64>> - return %d : vector<1x2xf64> - } - - // CHECK-LABEL: @ldmatrix_x4 - func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> { - %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32) - %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16> - // CHECK: llvm.extractvalue - // CHECK: llvm.bitcast - // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue - // CHECK: llvm.bitcast - // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue - // CHECK: llvm.bitcast - // CHECK: llvm.insertvalue - // CHECK: llvm.extractvalue - // CHECK: llvm.bitcast - // CHECK: llvm.insertvalue - return %a : vector<4x2xf16> - } - - // CHECK-LABEL: @ldmatrix_x1 - func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> { - %c0 = arith.constant 0 : index - // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout, num = 1 : i32} {{.*}} -> i32 - %a = gpu.mma.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16> - // CHECK: llvm.bitcast - // CHECK: llvm.insertvalue - return %a : vector<1x2xf16> - } -} Index: mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir =================================================================== --- /dev/null +++ mlir/test/Conversion/NVGPUToNVVM/mma-sync-to-nvvm.mlir @@ -0,0 +1,127 @@ +// RUN: mlir-opt --convert-nvgpu-to-nvvm --split-input-file %s | FileCheck %s + +// CHECK-LABEL: @m16n8k16_fp16 +func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[2] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[3] : !llvm.array<4 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>> + // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>> + // CHECK-NOT llvm.extractvalue + // CHECK: [[d:%.+]] = nvvm.mma.sync + // CHECK-SAME: shape = {k = 16 : i32, m = 16 
: i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @m16n8k8_fp16
+func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<1 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-NOT: llvm.extractvalue
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
+  // CHECK: llvm.mlir.undef : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<2 x vector<2xf16>>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[1] : !llvm.array<2 x vector<2xf16>>
+  // CHECK: return
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+
+// CHECK-LABEL: @m16n8k32_int8
+func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<4 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<4xi8>>
+  // CHECK: llvm.bitcast [[el]] : vector<4xi8> to i32
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[el:%.+]] = llvm.extractvalue %{{.*}}[{{.*}}] : !llvm.array<2 x vector<2xi32>>
+  // CHECK: [[d:%.+]] = nvvm.mma.sync
+  // CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
+  // CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
+  // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
+  // CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @m8n8k4_f64
+func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
+  // CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  // CHECK: llvm.mlir.undef : vector<2xf64>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f64, f64)>
+  // CHECK-COUNT-2: llvm.insertelement {{.*}} : vector<2xf64>
+  // CHECK-DAG: llvm.insertvalue {{%.+}}, {{%.+}}[0] : !llvm.array<1 x vector<2xf64>>
+  // CHECK: return
+  return %d : vector<1x2xf64>
+}
+
+// -----
+
+
+// CHECK-LABEL: @ldmatrix_x4
+func @ldmatrix_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
+  %c0 = arith.constant 0 : index
+  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 4 : i32} {{.*}} -> !llvm.struct<(i32, i32, i32, i32)
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  // CHECK: llvm.extractvalue
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  return %a : vector<4x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @ldmatrix_x1
+func @ldmatrix_x1(%arg0: memref<128x128xf16, 3>) -> vector<1x2xf16> {
+  %c0 = arith.constant 0 : index
+  // CHECK: nvvm.ldmatrix {{%.+}} {layout = #nvvm.mma_layout<row>, num = 1 : i32} {{.*}} -> i32
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 1 : i32} : memref<128x128xf16, 3> -> vector<1x2xf16>
+  // CHECK: llvm.bitcast
+  // CHECK: llvm.insertvalue
+  return %a : vector<1x2xf16>
+}
Index: mlir/test/Dialect/NVGPU/roundtrip.mlir
===================================================================
--- mlir/test/Dialect/NVGPU/roundtrip.mlir
+++ mlir/test/Dialect/NVGPU/roundtrip.mlir
@@ -8,3 +8,13 @@
     memref -> vector<4x2xf16>
   return
 }
+
+// CHECK-LABEL: func @mma_sync(
+func @mma_sync(%arg0: vector<4x2xf16>,
+               %arg1: vector<2x2xf16>,
+               %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+// CHECK: nvgpu.mma.sync(%{{.*}}, %{{.*}}, %{{.*}}) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  %d = nvgpu.mma.sync(%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} :
+    (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
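
Usage note (not part of the patch): besides running the standalone `-convert-nvgpu-to-nvvm` pass, downstream pipelines can reuse the exported `populateNVGPUToNVVMConversionPatterns` entry point so that `nvgpu` ops are lowered together with other dialects in a single conversion. The sketch below is only illustrative; the helper names `addNVGPULowering` and `lowerWithNVGPUPatterns` are made up for this example, while the populate/create functions, `LLVMConversionTarget`, and `applyPartialConversion` usage mirror `ConvertNVGPUToNVVMPass::runOnOperation` in this change.

```c++
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Option 1: append the standalone pass to an existing pass pipeline.
static void addNVGPULowering(OpPassManager &pm) {
  pm.addPass(createConvertNVGPUToNVVMPass());
}

// Option 2: fold the NVGPU-to-NVVM patterns into a larger LLVM lowering so
// nvgpu ops convert alongside other dialects in one partial conversion.
static LogicalResult lowerWithNVGPUPatterns(Operation *op) {
  MLIRContext *ctx = op->getContext();
  LLVMTypeConverter converter(ctx);
  RewritePatternSet patterns(ctx);
  populateNVGPUToNVVMConversionPatterns(converter, patterns);
  // ...other populate*ConversionPatterns calls would go here...
  LLVMConversionTarget target(*ctx);
  target.addLegalDialect<NVVM::NVVMDialect>();
  return applyPartialConversion(op, target, std::move(patterns));
}
```

Using partial conversion here keeps ops outside the registered patterns untouched, which is why the pass only needs to declare the LLVM and NVVM dialects legal.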