diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -681,4 +681,72 @@ let hasVerifier = 1; } +def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> { + let description = [{ + The `nvgpu.wargroup.mma` op represents the warpgroup-level (4 warps) + matrix-multiply-and-accumulate (mma) operation that results in + `nvvm.wgmma.mma_async`. + + The operands are `descriptorA` and `descriptorB`, wgmma matrix descriptors + that describe the properties of the matrices in shared memory. The result is + the thread-level ownership of the warpgroup-level mma result; its shape is + deduced from the descriptor types and the output vector. + + The op lowers to multiple `nvvm.wgmma.mma_async` operations that together cover + the given shape. Because `nvvm.wgmma.mma_async` is asynchronous, this op groups + the generated `nvvm.wgmma.mma_async` operations and surrounds them with + `nvvm.wgmma.fence.aligned`, `nvvm.wgmma.commit.group.sync.aligned` and + `nvvm.wgmma.wait.group.sync.aligned` ops. + + Example: + ```mlir + %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + vector<128x128xf32> -> vector<128x128xf32> + ``` + + This lowers to the following IR in the NVVM dialect: + ```mlir + %undefStruct = llvm.mlir.undef : !llvm.struct<...> + nvvm.wgmma.fence.aligned + %32 = nvvm.wgmma.mma_async %wgmmaDescA,%wgmmaDescB,, + D[%undefStruct,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %35 = nvvm.wgmma.mma_async %33,%34,, + D[%32,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %38 = nvvm.wgmma.mma_async %36,%37,, + D[%35,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %41 = nvvm.wgmma.mma_async %39,%40,, + D[%38,,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %43 = nvvm.wgmma.mma_async %42,%wgmmaDescB,, + D[%undefStruct,,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %45 = nvvm.wgmma.mma_async %44,%34,, + D[%43,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %47 = nvvm.wgmma.mma_async %46,%37,, + D[%45,,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + %49 = nvvm.wgmma.mma_async %48,%40,, + D[%47,,],A[,,],B[,,] + : !llvm.struct<(...)> -> !llvm.struct<(...)> + nvvm.wgmma.commit.group.sync.aligned + nvvm.wgmma.wait.group.sync.aligned 1 + ``` + }]; + let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, + NVGPU_WarpgroupMatrixDescriptor:$descriptorB, + AnyVector:$matrixD); + let results = (outs AnyVector:$res); + let assemblyFormat = [{ + $descriptorA`,` $descriptorB`,` $matrixD attr-dict + `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixD) `->` type($res) + }]; + let hasVerifier = 1; +} + #endif // NVGPU diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -21,6 +21,8 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc" +constexpr int kWarpSize = 32; + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -17,11 +17,17 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include 
"mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#define DEBUG_TYPE "nvgpu-to-nvvm" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << (X)) + namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS #include "mlir/Conversion/Passes.h.inc" @@ -398,8 +404,8 @@ using Base::Base; void getDependentDialects(DialectRegistry ®istry) const override { - registry - .insert(); + registry.insert(); } void runOnOperation() override { @@ -432,6 +438,7 @@ LLVMConversionTarget target(getContext()); target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::memref::MemRefDialect>(); + target.addLegalDialect<::mlir::vector::VectorDialect>(); target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -1115,6 +1122,226 @@ } }; +struct NVGPUWarpgroupMmaOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType, + int &wgmmaShapeM, int &wgmmaShapeN, + int &wgmmaShapeK) const { + wgmmaShapeM = 64; + wgmmaShapeN = sizeN; + if (inputElemType.isTF32()) { + wgmmaShapeK = 8; + } else if (inputElemType.isF16() || inputElemType.isBF16()) { + wgmmaShapeK = 16; + } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() || + inputElemType.isInteger(16)) { + wgmmaShapeK = 32; + } else if (inputElemType.isInteger(1)) { + wgmmaShapeK = 256; + } else { + return failure(); + } + LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM + << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK + << "]\n"); + return success(); + } + + Value generateNVVMWgmmaOp(MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Location loc, + int m, int n, int k, Type resultStructType, + Value inout, Value descriptorA, + Value descriptorB) const { + TypeRange resultTypes = {resultStructType}; + auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k); + auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one); + auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one); + auto layout = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row); + // todo input type + auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16); + auto overflow = + NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped); + Value res = rewriter.create( + loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype, + scaleOut, scaleIn, scaleIn, layout, layout, overflow); + return res; + } + + SmallVector generateWgmmaGroup(ConversionPatternRewriter &rewriter, + nvgpu::WarpgroupMmaOp op, int sizeN, + int sizeM, int sizeK, int wgmmaShapeM, + int wgmmaShapeN, int wgmmaShapeK, + Value descriptorA, + Value descriptorB) const { + auto loc = op->getLoc(); + VectorType outVtype = op.getMatrixD().getType(); + Type outElemType = outVtype.getElementType(); + SmallVector structBody; + for (int i = 0; i < wgmmaShapeM; i++) + structBody.push_back(outElemType); + auto stype = LLVM::LLVMStructType::getLiteral(op->getContext(), structBody); + + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + + 
rewriter.create(loc); + SmallVector results; + for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) { + Value undefOp = rewriter.create(loc, stype); + Value inout = undefOp; + LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":" + << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0 + << ":" << wgmmaShapeN << "] += \n"); + for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) { + LLVM_DEBUG(DBGS() << "\t wgmma." + << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k" + << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM) + << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" + << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * " + << " B[" << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 + << ":" << wgmmaShapeN << "])\n"); + Value descA = + makeAdd(descriptorA, + makeI64Const(rewriter, op, (2 * iterK) + (512 * iterM))); + Value descB = + makeAdd(descriptorB, makeI64Const(rewriter, op, 2 * iterK)); + + inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, + wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, + stype, inout, descA, descB); + } + results.push_back(inout); + } + rewriter.create(loc); + rewriter.create(loc, 1); + + return results; + } + + Value generateStoreResult(ConversionPatternRewriter &rewriter, + nvgpu::WarpgroupMmaOp op, int sizeN, int sizeM, + int sizeK, int wgmmaShapeM, int wgmmaShapeN, + int wgmmaShapeK, + const SmallVector &wgmmaResults) const { + auto loc = op->getLoc(); + VectorType outVtype = op.getMatrixD().getType(); + Type outElemType = outVtype.getElementType(); + + auto makeConst = [&](int32_t index) -> Value { + return rewriter.create( + loc, IntegerType::get(op->getContext(), 32), + rewriter.getI32IntegerAttr(index)); + }; + Value c4 = makeConst(4); + Value c32 = makeConst(32); + Value c8 = makeConst(8); + Value c2 = makeConst(2); + Value c1 = makeConst(1); + Value c16 = makeConst(16); + + auto makeMul = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + auto makeExtractInsert = [&](int index, Value wgmmaResult, + Value resultVectorIndex, + Value resultVector) -> Value { + Value extracted1 = + rewriter.create(loc, wgmmaResult, index); + Value extracted2 = + rewriter.create(loc, wgmmaResult, index + 1); + Value res = rewriter.create( + loc, extracted1, resultVector, resultVectorIndex); + Value nextIndex = makeAdd(resultVectorIndex, c1); + return rewriter.create(loc, extracted2, res, + nextIndex); + }; + + Value tidx = rewriter.create(loc, rewriter.getI32Type()); + Value laneId = + rewriter.create(loc, rewriter.getI32Type(), tidx, c32); + Value warpId = + rewriter.create(loc, rewriter.getI32Type(), tidx, c32); + Value lane4Id = + rewriter.create(loc, rewriter.getI32Type(), laneId, c4); + Value lane4modId = + rewriter.create(loc, rewriter.getI32Type(), laneId, c4); + + SmallVector resultVectors; + for (Value wgmmaResult : wgmmaResults) { + Value rv = rewriter.create( + loc, VectorType::get({wgmmaShapeM * wgmmaShapeN}, outElemType)); + Value idx = makeMul(lane4modId, c2); + Value valSizeN = makeConst(wgmmaShapeN); + idx = makeAdd(idx, makeMul(lane4Id, valSizeN)); + idx = makeAdd(idx, makeMul(c16, makeMul(warpId, valSizeN))); + for (int j = 0; j < 2; ++j) { + Value vj = makeMul(makeConst(j), makeMul(c8, valSizeN)); + for (int i = 0; i < 16; ++i) { + idx = makeAdd(idx, vj); + idx = makeAdd(idx, makeMul(makeConst(i), c8)); + rv = 
makeExtractInsert(i, wgmmaResult, idx, rv); + } + } + resultVectors.push_back(rv); + } + + Value rv = rewriter.create( + loc, VectorType::get({sizeM * sizeN}, outElemType)); + int offset = 0; + SmallVector stride = {1}; + for (Value resultVector : resultVectors) { + SmallVector offsets = {offset}; + rv = rewriter.create(loc, resultVector, rv, + offsets, stride); + offset += resultVector.getType().cast().getNumElements(); + } + + return rewriter.create(loc, outVtype, rv); + } + + LogicalResult + matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t sizeM = op.getMatrixD().getType().getDimSize(0); + int64_t sizeN = op.getMatrixD().getType().getDimSize(1); + int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1); + + LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A[" + << sizeM << "][" << sizeK << "] * B[" << sizeK << "][" + << sizeN << "] ---===\n"); + + int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK; + if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK))) { + return failure(); + } + + Value descriptorA = adaptor.getDescriptorA(); + Value descriptorB = adaptor.getDescriptorB(); + + // Generate wgmma group + SmallVector wgmmaResults = + generateWgmmaGroup(rewriter, op, sizeN, sizeM, sizeK, wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK, descriptorA, descriptorB); + + // Store back result structs to vector + Value result = + generateStoreResult(rewriter, op, sizeN, sizeM, sizeK, wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK, wgmmaResults); + + rewriter.replaceOp(op, result); + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, @@ -1131,6 +1358,7 @@ NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUTmaAsyncLoadOpLowering, // nvgpu.tma.async.load NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor + NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -151,7 +151,6 @@ // - For F32 (TF32), F16, S8, and S4 data // types the fundamental tensor core operation is of shape 8-by-8-by-128b. // - F64 is an exception and is of shape 8-by-8-by-256b. - constexpr int kThreads = 32; // 32 threads per warp int64_t shapeM = 8; int64_t shapeN = 8; int64_t shapeK; // set based on data type (128b for all data types except F64) @@ -206,17 +205,17 @@ // verify warp-wide size for vector a int64_t sparseFactor = sparse ? 
2 : 1; - if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor) + if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) return op->emitOpError() << "expected " << m * k << " warp-wide matrix A elements"; // verify warp-wide size for vector b - if (bShape[0] * bShape[1] * kThreads != k * n) + if (bShape[0] * bShape[1] * kWarpSize != k * n) return op->emitOpError() << "expected " << k * n << " warp-wide matrix B elements"; // verify warp-wide size for vector c - if (cShape[0] * cShape[1] * kThreads != m * n) + if (cShape[0] * cShape[1] * kWarpSize != m * n) return op->emitOpError() << "expected " << m * n << " warp-wide matrix C elements"; @@ -402,6 +401,94 @@ return success(); } +//===----------------------------------------------------------------------===// +// WarpgroupMmaOp +//===----------------------------------------------------------------------===// + +LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { + // F32 += F16 + F16 + // F16 += F16 + F16 + if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F32 += TF32 + TF32 + if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) + return success(); + // s32 += i8 + i8 + if (typeA.isInteger(8) && typeB.isInteger(8) && typeD.isInteger(32)) + return success(); + // s32 += i1 + i1 + if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32)) + return success(); + // F32 += BF16 + BF16 + // F16 += BF16 + BF16 + if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F16 += f8 + f8 + // F32 += f8 + f8 + if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && + (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && + (typeD.isF32() || typeD.isF16())) + return success(); + + return failure(); +} + +LogicalResult isAllowedSizeN(int sizeN, Type typeA) { + SmallVector<int> allowedN = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector<int> allowedNshort = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() || + typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) + if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; })) + return success(); + + if (typeA.isInteger(8) || typeA.isInteger(1)) + if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; })) + return success(); + return failure(); +} + +LogicalResult WarpgroupMmaOp::verify() { + auto matrixA = getDescriptorA().getType().getTensor(); + auto matrixB = getDescriptorB().getType().getTensor(); + auto matrixD = getMatrixD().getType(); + if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || + matrixD.getRank() != 2) + return emitOpError() << "matrix A, B and D must be 2 dimensional"; + if (matrixA.getShape()[1] != matrixB.getShape()[0]) + return emitOpError() << "has matrix A " << matrixA << " whose K dimension does not match matrix B " << matrixB; + if (matrixA.getShape()[0] != matrixD.getShape()[0]) + return emitOpError() << "has matrix A " << matrixA << " whose M dimension does not match matrix D " << matrixD; + if (matrixB.getShape()[1] != matrixD.getShape()[1]) + return emitOpError() << "has matrix B " << matrixB << " whose N dimension does not match matrix D " << matrixD; + if (failed(isAllowedWGMMADataType(matrixD.getElementType(), + matrixA.getElementType(), + matrixB.getElementType()))) + return emitOpError() << matrixD.getElementType() + << " += " << 
matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported."; + // Check N + if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { + return emitOpError() << "has input type " << matrixB << " n is set to " + << matrixB.getDimSize(1) << ", it is not supported."; + } + + // Currently, f16/bf16 supported + if (!matrixA.getElementType().isF16() && !matrixA.getElementType().isBF16()) { + return emitOpError() << "hit a limitation: " << matrixD.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported yet."; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd dialect, type, and op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -convert-nvgpu-to-nvvm='use-opaque-pointers=1' | FileCheck %s -// RUN: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s +// RUN1: mlir-opt %s -test-transform-dialect-interpreter | FileCheck %s // CHECK-LABEL: @m16n8k16_fp16 func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { @@ -673,6 +673,68 @@ } +// CHECK-LABEL: @warpgroup_mma_128_128_64( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>) +func.func @warpgroup_mma_128_128_64( + %descA: !nvgpu.wgmma.descriptor>, + %descB: !nvgpu.wgmma.descriptor>, + %D: memref<128x128xf32,3>) +{ +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: nvvm.wgmma.fence.aligned +// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S4:.+]] = llvm.mlir.constant(0 : i32) : i64 +// CHECK: %[[S5:.+]] = llvm.add %[[S0]], %[[S4]] : i64 +// CHECK: %[[S6:.+]] = llvm.mlir.constant(0 : i32) : i64 +// CHECK: %[[S7:.+]] = llvm.add %[[S1]], %[[S6]] : i64 +// CHECK: %[[S8:.+]] = nvvm.wgmma.mma_async %5, %7, , D[%[[S3]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S9:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S10:.+]] = llvm.add %[[S0]], %[[S9]] : i64 +// CHECK: %[[S11:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S12:.+]] = llvm.add %[[S1]], %[[S11]] : i64 +// CHECK: %[[S13:.+]] = nvvm.wgmma.mma_async %10, %12, , D[%[[S8]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S14:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S15:.+]] = llvm.add %[[S0]], %[[S14]] : i64 +// CHECK: %[[S16:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S17:.+]] = llvm.add %[[S1]], %[[S16]] : i64 +// CHECK: %[[S18:.+]] = nvvm.wgmma.mma_async %[[S15]], %[[S17]], , D[%[[S13]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S19:.+]] = 
llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S20:.+]] = llvm.add %[[S0]], %[[S19]] : i64 +// CHECK: %[[S21:.+]] = llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S22:.+]] = llvm.add %[[S1]], %[[S21]] : i64 +// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %20, %22, , D[%[[S18]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S24:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S25:.+]] = llvm.mlir.constant(512 : i32) : i64 +// CHECK: %[[S26:.+]] = llvm.add %[[S0]], %[[S25]] : i64 +// CHECK: %[[S27:.+]] = llvm.mlir.constant(0 : i32) : i64 +// CHECK: %[[S28:.+]] = llvm.add %[[S1]], %[[S27]] : i64 +// CHECK: %[[S29:.+]] = nvvm.wgmma.mma_async %26, %28, , D[%[[S24]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S30:.+]] = llvm.mlir.constant(514 : i32) : i64 +// CHECK: %[[S31:.+]] = llvm.add %[[S0]], %[[S30]] : i64 +// CHECK: %[[S32:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S33:.+]] = llvm.add %[[S1]], %[[S32]] : i64 +// CHECK: %[[S34:.+]] = nvvm.wgmma.mma_async %[[S31]], %[[S33]], , D[%[[S29]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S35:.+]] = llvm.mlir.constant(516 : i32) : i64 +// CHECK: %[[S36:.+]] = llvm.add %[[S0]], %[[S35]] : i64 +// CHECK: %[[S37:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S38:.+]] = llvm.add %[[S1]], %[[S37]] : i64 +// CHECK: %[[S39:.+]] = nvvm.wgmma.mma_async %[[S36]], %[[S38]], , D[%[[S34]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S40:.+]] = llvm.mlir.constant(518 : i32) : i64 +// CHECK: %[[S41:.+]] = llvm.add %[[S0]], %[[S40]] : i64 +// CHECK: %[[S42:.+]] = llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S43:.+]] = llvm.add %[[S1]], %[[S42]] : i64 +// CHECK: %[[S44:.+]] = nvvm.wgmma.mma_async %[[S41]], %[[S43]], , D[%[[S39]], , ], A[, , ], B[, , ] +// CHECK: nvvm.wgmma.commit.group.sync.aligned +// CHECK: nvvm.wgmma.wait.group.sync.aligned 1 + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32> + %wgmmaResult = nvgpu.wargroup.mma %descA, %descB, %acc : !nvgpu.wgmma.descriptor>, !nvgpu.wgmma.descriptor>, vector<128x128xf32> -> vector<128x128xf32> + vector.transfer_write %wgmmaResult, %D[%c0, %c0] : vector<128x128xf32>, memref<128x128xf32,3> + return +} + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 @@ -682,5 +744,5 @@ } with type_converter { transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true} - } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op + } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op }
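For context, here is a minimal standalone usage sketch of the new `nvgpu.wargroup.mma` op as exercised by the test above. The `tensor = memref<...>` descriptor parameters, the shared-memory space `3`, and the function name are reconstructed assumptions (the exact descriptor types are elided in the hunks above); only the op syntax follows the TableGen assembly format introduced in this patch.

```mlir
// Hypothetical function: a 128x128x64 GEMM accumulated in f32 from f16 inputs.
// %descA / %descB are assumed to describe 128x64 and 64x128 f16 tiles in
// shared memory (memory space 3); the shapes are taken from the test case above.
func.func @warpgroup_mma_sketch(
    %descA : !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
    %descB : !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
    %acc   : vector<128x128xf32>) -> vector<128x128xf32> {
  // A single op at the nvgpu level; the lowering added in this patch emits a
  // fenced group of nvvm.wgmma.mma_async ops (2 tiles along M x 4 along K = 8
  // for this shape), followed by commit.group and wait.group.
  %d = nvgpu.wargroup.mma %descA, %descB, %acc :
         !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
         !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
         vector<128x128xf32> -> vector<128x128xf32>
  return %d : vector<128x128xf32>
}
```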