diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,15 @@
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> {
+  let parameters = (ins "Type":$tensor);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    This type represents the fragmented result matrix produced by `nvgpu.wargroup.mma`.
+    [See the details of the register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +673,44 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
+  let description = [{
+    The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that results in
+    `nvvm.wgmma.mma_async`.
+
+    The operands are `descriptorA` and `descriptorB`, wgmma matrix descriptors
+    that describe the properties of the matrices in shared memory. The results
+    represent each thread's ownership of a fragment of the warpgroup-level mma
+    result. The shape is deduced from the descriptor types and the output vector.
+
+    The op lowers to multiple `nvvm.wgmma.mma_async` operations that together
+    compute the given shape. Because `nvvm.wgmma.mma_async` is asynchronous,
+    this op groups the generated `nvvm.wgmma.mma_async` operations and surrounds
+    them with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and
+    `wgmma.wait.group.sync.aligned` ops.
+
+    Example:
+    ```mlir
+      %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc:
+                 !nvgpu.wgmma.descriptor>,
+                 !nvgpu.wgmma.descriptor>,
+                 vector<128x128xf32> -> !nvgpu.warpgroup.result
+    ```
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
+                       AnyVector:$matrixC,
+                       DefaultValuedOptionalAttr:$waitGroup,
+                       OptionalAttr:$transposeA,
+                       OptionalAttr:$transposeB);
+  let results = (outs Variadic:$matrixD);
+  let assemblyFormat = [{
+    $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )?
attr-dict + `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD) + }]; + let hasVerifier = 1; +} #endif // NVGPU diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -21,6 +21,8 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc" +constexpr int kWarpSize = 32; + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -17,11 +17,17 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "nvgpu-to-nvvm" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define DBGSE() (llvm::dbgs()) + namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -29,6 +35,10 @@ using namespace mlir; +/// Number of bits that needs to excluded when building matrix descriptor for +/// wgmma operations. +constexpr int exclude4LSB = 4; + /// GPU has 32 bit registers, this function truncates values when larger width /// is not needed. static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc, @@ -398,8 +408,8 @@ using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override { @@ -432,6 +442,7 @@ LLVMConversionTarget target(getContext()); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::memref::MemRefDialect>(); + target.addLegalDialect<::mlir::vector::VectorDialect>(); target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -979,10 +990,9 @@ shiftLeft(val, startBit)); }; - int ex4LSB = 4; - Value strideDim = makeConst((layout << 3) >> ex4LSB); + Value strideDim = makeConst((layout << 3) >> exclude4LSB); int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); - Value leadDim = makeConst((sizeN * layout) >> ex4LSB); + Value leadDim = makeConst((sizeN * layout) >> exclude4LSB); Value baseAddr = getStridedElementPtr( op->getLoc(), cast(op.getTensor().getType()), adaptor.getTensor(), {}, rewriter); @@ -1123,6 +1133,164 @@ } }; +struct NVGPUWarpgroupMmaOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType, + int &wgmmaShapeM, int &wgmmaShapeN, + int &wgmmaShapeK) const { + wgmmaShapeM = 64; + wgmmaShapeN = sizeN; + if (inputElemType.isTF32()) { + wgmmaShapeK = 8; + } else if (inputElemType.isF16() || inputElemType.isBF16()) { + wgmmaShapeK = 16; + } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() || + inputElemType.isInteger(16)) { + wgmmaShapeK = 32; + } else if (inputElemType.isInteger(1)) { + wgmmaShapeK = 256; + } else { + return failure(); + } + LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << 
wgmmaShapeM + << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK + << "]\n"); + return success(); + } + + Value generateNVVMWgmmaOp(MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Location loc, + int m, int n, int k, Type resultStructType, + Value inout, Value descriptorA, + Value descriptorB) const { + TypeRange resultTypes = {resultStructType}; + auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k); + auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one); + auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one); + auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row); + auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col); + // todo input type + auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16); + auto overflow = + NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped); + Value res = rewriter.create( + loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype, + scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); + return res; + } + + static Type buildOutputStructType(MLIRContext *ctx, Type outElemType, + int sizeN) { + int outputElements = 0; + if (outElemType.isF32() || outElemType.isInteger(32)) + outputElements = sizeN / 2; + if (outElemType.isF16()) + outputElements = sizeN / 4; + SmallVector structBody; + for (int i = 0; i < outputElements; i++) + structBody.push_back(outElemType); + return LLVM::LLVMStructType::getLiteral(ctx, structBody); + } + + LogicalResult + matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector wgmmaResults; + + int64_t sizeM = op.getMatrixC().getType().getDimSize(0); + int64_t sizeN = op.getMatrixC().getType().getDimSize(1); + int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1); + + LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A[" + << sizeM << "][" << sizeK << "] * B[" << sizeK << "][" + << sizeN << "] ---===\n"); + + int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK; + if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK))) { + return failure(); + } + + Value descriptorA = adaptor.getDescriptorA(); + Value descriptorB = adaptor.getDescriptorB(); + + // Generate wgmma group + + auto loc = op->getLoc(); + Type outElemType = op.getMatrixC().getType().getElementType(); + Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN); + MemRefType typeTensorA = op.getDescriptorA().getType().getTensor(); + MemRefType typeTensorB = op.getDescriptorB().getType().getTensor(); + + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + + auto iterateDescA = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle column major + int byte = typeTensorA.getElementTypeBitWidth() / 8; + int tileShapeA = typeTensorA.getDimSize(1); + int incrementVal = + ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: " + << iterK << "] [wgmma descriptors] Descriptor A + " + << incrementVal << " | \t "); + return incrementVal + ? 
makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) + : desc; + }; + + auto iterateDescB = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle row major + int byte = typeTensorB.getElementTypeBitWidth() / 8; + int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + return incrementVal + ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) + : desc; + }; + + rewriter.create(loc); + for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) { + Value undefOp = rewriter.create(loc, stype); + Value inout = undefOp; + LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":" + << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0 + << ":" << wgmmaShapeN << "] += \n"); + for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) { + Value descA = iterateDescA(descriptorA, iterM, 0, iterK); + Value descB = iterateDescB(descriptorB, iterM, 0, iterK); + LLVM_DEBUG(DBGS() << "\t wgmma." + << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k" + << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM) + << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" + << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * " + << " B[" << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 + << ":" << wgmmaShapeN << "])\n"); + inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, + wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, + stype, inout, descA, descB); + } + wgmmaResults.push_back(inout); + } + + rewriter.create(loc); + rewriter.create(loc, op.getWaitGroup()); + + ValueRange myres(wgmmaResults); + rewriter.replaceOp(op, myres); + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, @@ -1138,6 +1306,7 @@ NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor + NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -151,7 +151,6 @@ // - For F32 (TF32), F16, S8, and S4 data // types the fundamental tensor core operation is of shape 8-by-8-by-128b. // - F64 is an exception and is of shape 8-by-8-by-256b. - constexpr int kThreads = 32; // 32 threads per warp int64_t shapeM = 8; int64_t shapeN = 8; int64_t shapeK; // set based on data type (128b for all data types except F64) @@ -206,17 +205,17 @@ // verify warp-wide size for vector a int64_t sparseFactor = sparse ? 
2 : 1; - if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor) + if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) return op->emitOpError() << "expected " << m * k << " warp-wide matrix A elements"; // verify warp-wide size for vector b - if (bShape[0] * bShape[1] * kThreads != k * n) + if (bShape[0] * bShape[1] * kWarpSize != k * n) return op->emitOpError() << "expected " << k * n << " warp-wide matrix B elements"; // verify warp-wide size for vector c - if (cShape[0] * cShape[1] * kThreads != m * n) + if (cShape[0] * cShape[1] * kWarpSize != m * n) return op->emitOpError() << "expected " << m * n << " warp-wide matrix C elements"; @@ -402,6 +401,107 @@ return success(); } +//===----------------------------------------------------------------------===// +// WarpgroupMmaOp +//===----------------------------------------------------------------------===// + +LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { + // F32 += F16 + F16 + // F16 += F16 + F16 + if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F32 += TF32 + TF32 + if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) + return success(); + // s32 += i8 + i8 + if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32)) + return success(); + // s32 += i1 + i1 + if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32)) + return success(); + // F32 += BF16 + BF16 + // F16 += BF16 + BF16 + if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F16 += f8 + f8 + // F32 += f8 + f8 + if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && + (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && + (typeD.isF32() || typeD.isF16())) + return success(); + + return failure(); +} + +LogicalResult isAllowedSizeN(int sizeN, Type typeA) { + SmallVector allowedN = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector allowedNshort = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() || + typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) + if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; })) + return success(); + + if (typeA.isInteger(8) || typeA.isInteger(1)) + if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; })) + return success(); + return failure(); +} + +LogicalResult WarpgroupMmaOp::verify() { + if (getTransposeA() && !getTransposeB()) + return emitOpError() << "supports non-transpose A (Row Major) " + "and transpose B (Column Major) for the time being"; + auto matrixA = getDescriptorA().getType().getTensor(); + auto matrixB = getDescriptorB().getType().getTensor(); + auto matrixC = getMatrixC().getType(); + if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || + matrixC.getRank() != 2) + return emitOpError() + << "has input matrices A, B and D, they must be 2 dimensional"; + + if (matrixA.getShape()[1] != matrixB.getShape()[0]) + return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1] + << ")!= 1st dim matrix-B (" << matrixB.getShape()[0] + << " )"; + if (matrixA.getShape()[0] != matrixC.getShape()[0]) + return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0] + << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0] + << " )"; + if (matrixB.getShape()[1] != matrixC.getShape()[1]) + return emitOpError() << "2nd dim matrix-B ( " << 
matrixB.getShape()[1] + << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1] + << " )"; + + if (failed(isAllowedWGMMADataType(matrixC.getElementType(), + matrixA.getElementType(), + matrixB.getElementType()))) + return emitOpError() << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported."; + // Check N + if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { + return emitOpError() << "has input type " << matrixB << " n is set to " + << matrixB.getDimSize(1) << ", it is not supported"; + } + + // Currently, f16/bf16 supported + if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && + !matrixA.getElementType().isBF16()) { + return emitOpError() << "hit a limitation: " << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported yet"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd dialect, type, and op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -672,6 +672,70 @@ func.return %descA : !nvgpu.wgmma.descriptor> } +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @warpgroup_mma_128_128_64( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>) +func.func @warpgroup_mma_128_128_64( + %descA: !nvgpu.wgmma.descriptor>, + %descB: !nvgpu.wgmma.descriptor>, + %D: memref<128x128xf32,3>) +{ +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor> to i64 +// CHECK: nvvm.wgmma.fence.aligned +// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], , D[%3, , ], A[, , ], B[, , ] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64 +// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64 +// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], , D[%[[S4]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64 +// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64 +// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], , D[%[[S9]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64 +// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64 +// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], , D[%[[S14]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64 +// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64 +// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], , D[%[[S20]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64 +// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64 +// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64 +// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], , D[%[[S23]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64 +// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64 +// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64 +// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], , D[%[[S28]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64 +// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64 +// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64 +// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], , D[%[[S33]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: nvvm.wgmma.commit.group.sync.aligned +// CHECK: nvvm.wgmma.wait.group.sync.aligned 1 + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32> + %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + vector<128x128xf32> -> !nvgpu.warpgroup.result, !nvgpu.warpgroup.result + + return +} + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 @@ -681,5 +745,5 @@ } with type_converter { transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = 
true} - } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op + } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op } \ No newline at end of file diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -221,3 +221,64 @@ %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 3: memref<128x128xf64> to memref<3x16x128xf64, 3> return } + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> + +func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_large_shape(%descA: 
!tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
+  // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
+  %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+  return
+}