diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -713,4 +713,23 @@
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"wargroup.mma.store"> {
+  let description = [{
+    The `nvgpu.wargroup.mma.store` op performs the store of the fragmented
+    result held in `$matrixD` to the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+
+    Note that the op must be run by a whole warpgroup.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupResult>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,7 +410,7 @@
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<memref::MemRefDialect, NVVM::NVVMDialect,
-                    vector::VectorDialect>();
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -424,6 +425,9 @@
     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 32));
     });
+    converter.addConversion([&](nvgpu::WarpgroupResultType type) -> Type {
+      return converter.convertType(type.getTensor());
+    });
     converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
       return converter.convertType(IntegerType::get(type.getContext(), 64));
     });
@@ -441,8 +445,8 @@
   populateNVGPUToNVVMConversionPatterns(converter, patterns);
   LLVMConversionTarget target(getContext());
   target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+  target.addLegalDialect<::mlir::arith::ArithDialect>();
   target.addLegalDialect<::mlir::memref::MemRefDialect>();
-  target.addLegalDialect<::mlir::vector::VectorDialect>();
   target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
   if (failed(applyPartialConversion(getOperation(), target,
                                     std::move(patterns))))
@@ -1291,11 +1295,88 @@
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
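+    // A warpgroup is 4 warps (128 threads). Following the wgmma accumulator
+    // fragment layout linked in the op description, each lane owns pairs of
+    // consecutive f32 values: its base row is warpId * 16 + laneId / 4, its
+    // base column is (laneId % 4) * 2, and the owned pairs repeat every
+    // 8 rows and every 8 columns.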
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      // The next result is stored getBody().size() rows further down.
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                  RewritePatternSet &patterns) {
   patterns.add<
+      NVGPUWarpgroupMmaStoreOpLowering,    // nvgpu.wargroup.mma.store
       NVGPUMBarrierCreateLowering,         // nvgpu.mbarrier.create
       NVGPUMBarrierInitLowering,           // nvgpu.mbarrier.init
       NVGPUMBarrierArriveLowering,         // nvgpu.mbarrier.arrive
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -22,6 +23,8 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -502,6 +505,34 @@
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype =
+      getMatrixD().front().getType().cast<WarpgroupResultType>().getTensor();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupResultType>()
+                           .getTensor()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but it must be an LLVM struct type";
+    if (stype != resultStype)
+      return emitOpError() << "all results must have the same type";
+
+    // TODO: improve this limitation
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supports only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -732,10 +732,93 @@
     !nvgpu.wgmma.descriptor>, !nvgpu.wgmma.descriptor>, vector<128x128xf32>
     -> !nvgpu.warpgroup.result, !nvgpu.warpgroup.result
-
   return
 }
 
+// CHECK-LABEL: @warpgroup_mma_store(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.result>, %[[arg1:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>
+func.func @warpgroup_mma_store(%result1 : !nvgpu.warpgroup.result, %matrixD: memref<128x128xf32,3>) {
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] :
+// CHECK: %[[S2:.+]] = llvm.mlir.constant(4 : i32) : i32
+// CHECK: %[[S3:.+]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[S4:.+]] = llvm.mlir.constant(8 : i32) : i32
+// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S6:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S7:.+]] = llvm.mlir.constant(16 : i32) : i32
+
+// ### Store {d0, d1} of each thread ###
+
+// CHECK: %[[S8:.+]] = nvvm.read.ptx.sreg.tid.x : i32
+// CHECK: %[[S9:.+]] = llvm.urem %[[S8]], %[[S3]] : i32
+// CHECK: %[[S10:.+]] = llvm.udiv %[[S8]], %[[S3]] : i32
+// CHECK: %[[S11:.+]] = llvm.udiv %[[S9]], %[[S2]] : i32
+// CHECK: %[[S12:.+]] = llvm.urem %[[S9]], %[[S2]] : i32
+// CHECK: %[[S13:.+]] = llvm.mul %[[S12]], %[[S5]] : i32
+// CHECK: %[[S14:.+]] = llvm.mul %[[S10]], %[[S7]] : i32
+// CHECK: %[[S15:.+]] = llvm.add %[[S11]], %[[S14]] : i32
+// CHECK: %[[S16:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S17:.+]] = llvm.mul %[[S16]], %[[S4]] : i32
+// CHECK: %[[S18:.+]] = llvm.add %[[S15]], %[[S17]] : i32
+// CHECK: %[[S19:.+]] = llvm.mlir.constant(0 : i32) : i32
+// CHECK: %[[S20:.+]] = llvm.mul %[[S19]], %[[S4]] : i32
+// CHECK: %[[S21:.+]] = llvm.add %[[S13]], %[[S20]] : i32
+// CHECK: %[[S22:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S23:.+]] = arith.index_cast %[[S21]] : i32 to index
+// CHECK: %[[S24:.+]] = llvm.add %[[S21]], %[[S6]] : i32
+// CHECK: %[[S25:.+]] = arith.index_cast %[[S24]] : i32 to index
+// CHECK: %[[S26:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct
+// CHECK: %[[S27:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct
+// CHECK: memref.store %[[S26]], %[[arg1]][%[[S22]], %[[S23]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S27]], %[[arg1]][%[[S22]], %[[S25]]] : memref<128x128xf32, 3>
+
+// ### Store {d2, d3} of each thread ###
+
+// CHECK: %[[S28:.+]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[S29:.+]] = llvm.mul %[[S28]], %[[S4]] : i32
+// CHECK: %[[S30:.+]] = llvm.add %[[S13]], %[[S29]] : i32
+// CHECK: %[[S31:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S32:.+]] = arith.index_cast %[[S30]] : i32 to index
+// CHECK: %[[S33:.+]] = llvm.add %[[S30]], %[[S6]] : i32
+// CHECK: %[[S34:.+]] = arith.index_cast %[[S33]] : i32 to index
+// CHECK: %[[S35:.+]] = llvm.extractvalue %[[S0]][4] : !llvm.struct<
+// CHECK: %[[S36:.+]] = llvm.extractvalue %[[S0]][5] : !llvm.struct<
+// CHECK: memref.store %[[S35]], %[[arg1]][%[[S31]], %[[S32]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S36]], %[[arg1]][%[[S31]], %[[S34]]] : memref<128x128xf32, 3>
+
+// ### Store {d4, d5} of each thread ###
+
+// CHECK: %[[S37:.+]] = llvm.mlir.constant(2 : i32) : i32
+// CHECK: %[[S38:.+]] = llvm.mul %[[S37]], %[[S4]] : i32
+// CHECK: %[[S39:.+]] = llvm.add %[[S13]], %[[S38]] : i32
+// CHECK: %[[S40:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S41:.+]] = arith.index_cast %[[S39]] : i32 to index
+// CHECK: %[[S42:.+]] = llvm.add %[[S39]], %[[S6]] : i32
+// CHECK: %[[S43:.+]] = arith.index_cast %[[S42]] : i32 to index
+// CHECK: %[[S44:.+]] = llvm.extractvalue %[[S0]][8] : !llvm.struct<
+// CHECK: %[[S45:.+]] = llvm.extractvalue %[[S0]][9] : !llvm.struct<
+// CHECK: memref.store %[[S44]], %[[arg1]][%[[S40]], %[[S41]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S45]], %[[arg1]][%[[S40]], %[[S43]]] : memref<128x128xf32, 3>
+
+// ### Store {d6, d7} of each thread ###
+
+// CHECK: %[[S46:.+]] = llvm.mlir.constant(3 : i32) : i32
+// CHECK: %[[S47:.+]] = llvm.mul %[[S46]], %[[S4]] : i32
+// CHECK: %[[S48:.+]] = llvm.add %[[S13]], %[[S47]] : i32
+// CHECK: %[[S49:.+]] = arith.index_cast %[[S18]] : i32 to index
+// CHECK: %[[S50:.+]] = arith.index_cast %[[S48]] : i32 to index
+// CHECK: %[[S51:.+]] = llvm.add %[[S48]], %[[S6]] : i32
+// CHECK: %[[S52:.+]] = arith.index_cast %[[S51]] : i32 to index
+// CHECK: %[[S53:.+]] = llvm.extractvalue %[[S0]][12] : !llvm.struct<
+// CHECK: %[[S54:.+]] = llvm.extractvalue %[[S0]][13] : !llvm.struct<
+// CHECK: memref.store %[[S53]], %[[arg1]][%[[S49]], %[[S50]]] : memref<128x128xf32, 3>
+// CHECK: memref.store %[[S54]], %[[arg1]][%[[S49]], %[[S52]]] : memref<128x128xf32, 3>
+
+// The pattern continues similarly 60 more times, until {d126, d127}.
+
+  nvgpu.wargroup.mma.store [%result1], %matrixD : !nvgpu.warpgroup.result to memref<128x128xf32,3>
+  return
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["func.func"]} in %arg1
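For reference, a usage sketch of the op in its variadic form (hypothetical SSA names; the `tensor = ...` parameter of the `!nvgpu.warpgroup.result` types is elided, as in the test above). Each additional result fragment is stored the corresponding number of rows further down in the destination memref:

  %d1, %d2 = ...  // fragments produced by the warpgroup MMA op
  nvgpu.wargroup.mma.store [%d1, %d2], %dstMemref :
      !nvgpu.warpgroup.result, !nvgpu.warpgroup.result to memref<128x128xf32, 3>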