Index: mlir/include/mlir/Dialect/GPU/GPUBase.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUBase.td
+++ mlir/include/mlir/Dialect/GPU/GPUBase.td
@@ -60,6 +60,13 @@
   GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::AsyncTokenType>()">, "async token type">,
   BuildableType<"mlir::gpu::AsyncTokenType::get($_builder.getContext())">;
 
+/// Device-side synchronization token.
+def GPU_DeviceAsyncToken : DialectType<
+  GPU_Dialect, CPred<"$_self.isa<::mlir::gpu::DeviceAsyncTokenType>()">,
+  "device async token type">,
+  BuildableType<
+    "mlir::gpu::DeviceAsyncTokenType::get($_builder.getContext())">;
+
 // Predicat to check if type is gpu::MMAMatrixType.
 def IsMMAMatrixTypePred : CPred<"$_self.isa<::mlir::gpu::MMAMatrixType>()">;

Index: mlir/include/mlir/Dialect/GPU/GPUDialect.h
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUDialect.h
+++ mlir/include/mlir/Dialect/GPU/GPUDialect.h
@@ -46,6 +46,14 @@
   using Base::Base;
 };
 
+/// Device-side token storage type. There is only one type of device-side token.
+class DeviceAsyncTokenType
+    : public Type::TypeBase<DeviceAsyncTokenType, Type, TypeStorage> {
+public:
+  // Used for generic hooks in TypeBase.
+  using Base::Base;
+};
+
 /// MMAMatrixType storage and uniquing. Array is uniqued based on its shape
 /// and type.
 struct MMAMatrixStorageType : public TypeStorage {

Index: mlir/include/mlir/Dialect/GPU/GPUOps.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -1225,4 +1225,119 @@
   }];
 }
 
+def GPU_DeviceAsyncCopyOp : GPU_Op<"device_async_copy",
+                                   [AttrSizedOperandSegments]> {
+  let summary = "device-side asynchronous copy";
+  let description = [{
+    The `gpu.device_async_copy` op initiates an asynchronous copy operation of
+    `$size` bytes from source to the destination without blocking the thread.
+    The destination has to be in shared (workgroup) memory.
+
+    This memory access will be pending until it is added to a group.
+
+    This op is meant to be used with `gpu.device_async_create_group` and
+    `gpu.device_async_wait` to synchronize copies as explained in those ops'
+    descriptions.
+
+    In order to do a copy and wait for the result, we need the following
+    combination:
+    ```
+    // copy 1.
+    gpu.device_async_copy %A[%c0], %B[%c0], 16 : memref<16xf32> to memref<16xf32, 3>
+    // copy 2.
+    gpu.device_async_copy %C[%c0], %D[%c0], 16 : memref<16xf32> to memref<16xf32, 3>
+    // group 1 contains copy 1 and copy 2.
+    %token1 = gpu.device_async_create_group
+    // copy 3.
+    gpu.device_async_copy %E[%c0], %F[%c0], 16 : memref<16xf32> to memref<16xf32, 3>
+    // group 2 contains copy 3.
+    %token2 = gpu.device_async_create_group
+    // after the wait, copy 1 and copy 2 are complete.
+    gpu.device_async_wait %token1
+    // after the wait, copy 3 is complete.
+    gpu.device_async_wait %token2
+    ```
+
+    Example:
+
+    ```mlir
+    gpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 16 :
+      memref<4x5xf32> to memref<2x7x5xf32, 3>
+    ```
+  }];
+  let arguments = (ins Arg<AnyMemRef, "", [MemWrite]>:$dst,
+                       Variadic<Index>:$dstIndices,
+                       Arg<AnyMemRef, "", [MemRead]>:$src,
+                       Variadic<Index>:$srcIndices,
+                       IndexAttr:$size);
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` `,` $size attr-dict
+      `:` type($src) `to` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
+def GPU_DeviceAsyncCreateGroupOp : GPU_Op<"device_async_create_group", []> {
+  let summary = "device-side asynchronous create group operation";
+  let description = [{
+    The `gpu.device_async_create_group` op creates a group of memory accesses
+    containing all the pending memory accesses created by preceding
+    `device_async_copy` operations.
+    It returns a token that can be used to wait until the group fully
+    completes.
+
+    This is meant to be used with `gpu.device_async_wait` or
+    `gpu.device_async_wait_group` to synchronize copies as explained in those
+    ops' descriptions.
+
+    Groups are executed in the order they are created.
+
+    Example:
+
+    ```mlir
+    %0 = gpu.device_async_create_group
+    ```
+  }];
+  let results = (outs GPU_DeviceAsyncToken:$asyncToken);
+  let assemblyFormat = [{
+    attr-dict
+  }];
+}
+
+def GPU_DeviceAsyncWaitTokenOp : GPU_Op<"device_async_wait", []> {
+  let summary = "wait for a device-side async copy group to complete";
+  let description = [{
+    The `gpu.device_async_wait` op will block the execution thread until the
+    group associated with the source token is fully completed.
+
+    Example:
+
+    ```mlir
+    gpu.device_async_wait %0
+    ```
+  }];
+  let arguments = (ins GPU_DeviceAsyncToken:$asyncDependencies);
+  let assemblyFormat = [{
+    $asyncDependencies attr-dict
+  }];
+}
+
+def GPU_DeviceAsyncWaitGroupOp : GPU_Op<"device_async_wait_group", []> {
+  let summary = "wait until at most `numGroups` async copy groups are pending";
+  let description = [{
+    The `gpu.device_async_wait_group` op will block the execution thread until
+    the number of uncompleted groups becomes lower than or equal to the
+    integer attribute `numGroups`.
+
+    Example:
+
+    ```mlir
+    gpu.device_async_wait_group 1
+    ```
+  }];
+  let arguments = (ins I32Attr:$numGroups);
+  let assemblyFormat = [{
+    $numGroups attr-dict
+  }];
+}
+
 #endif // GPU_OPS

Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -40,6 +40,14 @@
 
 namespace {
 
+/// NVVM memory space identifiers.
+enum NVVMMemorySpace {
+  /// Global memory space identifier.
+  kGlobalMemorySpace = 1,
+  /// Shared memory space identifier.
+  kSharedMemorySpace = 3
+};
+
 /// Convert gpu dialect shfl mode enum to the equivalent nvvm one.
 static NVVM::ShflKind convertShflKind(gpu::ShuffleMode mode) {
   switch (mode) {
@@ -122,6 +130,74 @@
   }
 };
 
+struct GPUAsyncCopyLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp> {
+  using ConvertOpToLLVMPattern<gpu::DeviceAsyncCopyOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncCopyOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    auto dstMemrefType = op.dst().getType().cast<MemRefType>();
+    Value dstPtr = getStridedElementPtr(loc, dstMemrefType, adaptor.dst(),
+                                        adaptor.dstIndices(), rewriter);
+    auto i8Ty = IntegerType::get(op.getContext(), 8);
+    auto dstPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, dstMemrefType.getMemorySpaceAsInt());
+    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+
+    auto srcMemrefType = op.src().getType().cast<MemRefType>();
+
+    Value srcPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.src(),
+                                        adaptor.srcIndices(), rewriter);
+    auto srcPointerType =
+        LLVM::LLVMPointerType::get(i8Ty, srcMemrefType.getMemorySpaceAsInt());
+    srcPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, srcPtr);
+    // The intrinsic takes a global pointer, so we need an address space cast.
+    auto srcPointerGlobalType =
+        LLVM::LLVMPointerType::get(i8Ty, NVVMMemorySpace::kGlobalMemorySpace);
+    srcPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
+                                                    srcPtr);
+
+    rewriter.replaceOpWithNewOp<NVVM::CpAsyncOp>(
+        op, dstPtr, srcPtr,
+        rewriter.getI32IntegerAttr(adaptor.size().getZExtValue()));
+    return success();
+  }
+};
+
+struct GPUAsyncCreateGroupLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncCreateGroupOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DeviceAsyncCreateGroupOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncCreateGroupOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The token cannot be represented in NVVM; only lower the op if it has no
+    // uses.
+    if (!op.use_empty())
+      return failure();
+    rewriter.create<NVVM::CpAsyncCommitGroupOp>(op.getLoc());
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct GPUAsyncWaitLowering
+    : public ConvertOpToLLVMPattern<gpu::DeviceAsyncWaitGroupOp> {
+  using ConvertOpToLLVMPattern<
+      gpu::DeviceAsyncWaitGroupOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::DeviceAsyncWaitGroupOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int32_t numGroups = adaptor.numGroups();
+    rewriter.create<NVVM::CpAsyncWaitGroupOp>(op.getLoc(), numGroups);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -259,6 +335,8 @@
                                                    "__nv_sqrt");
   patterns.add<OpToFuncCallLowering<math::TanhOp>>(converter, "__nv_tanhf",
                                                    "__nv_tanh");
+  patterns.add<GPUAsyncCopyLowering, GPUAsyncCreateGroupLowering,
+               GPUAsyncWaitLowering>(converter);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>

Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -117,6 +117,7 @@
 void GPUDialect::initialize() {
   addTypes<AsyncTokenType>();
+  addTypes<DeviceAsyncTokenType>();
   addTypes<MMAMatrixType>();
   addOperations<
 #define GET_OP_LIST
@@ -139,6 +140,9 @@
   // Handle 'async token' types.
   if (keyword == "async.token")
     return AsyncTokenType::get(context);
+  // Handle 'device async token' types.
+  if (keyword == "device.async.token")
+    return DeviceAsyncTokenType::get(context);
 
   if (keyword == "mma_matrix") {
     SMLoc beginLoc = parser.getNameLoc();
@@ -179,6 +183,7 @@
 void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
       .Case<AsyncTokenType>([&](Type) { os << "async.token"; })
+      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();
@@ -1187,6 +1192,28 @@
   results.add<SimplifyDimOfAllocOp>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// GPU_DeviceAsyncCopyOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult DeviceAsyncCopyOp::verify() {
+  auto srcMemref = src().getType().cast<MemRefType>();
+  auto dstMemref = dst().getType().cast<MemRefType>();
+  unsigned workgroupAddressSpace = GPUDialect::getWorkgroupAddressSpace();
+  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
+    return emitError("destination memref must have memory space ")
+           << workgroupAddressSpace;
+  if (dstMemref.getElementType() != srcMemref.getElementType())
+    return emitError("source and destination must have the same element type");
+  if (size_t(srcMemref.getRank()) != srcIndices().size())
+    return emitOpError() << "expected " << srcMemref.getRank()
+                         << " source indices, got " << srcIndices().size();
+  if (size_t(dstMemref.getRank()) != dstIndices().size())
+    return emitOpError() << "expected " << dstMemref.getRank()
+                         << " destination indices, got " << dstIndices().size();
+  return success();
+}
+
 #include "mlir/Dialect/GPU/GPUOpInterfaces.cpp.inc"
 #include "mlir/Dialect/GPU/GPUOpsEnums.cpp.inc"

Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -479,3 +479,37 @@
     gpu.return
   }
 }
+
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: @async_cp(
+  // CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: i64)
+  gpu.func @async_cp(
+    %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index) kernel {
+    // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
+    // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
+    // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX]], %[[S0]] : i64
+    // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
+    // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX]], %[[S1]] : i64
+    // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
+    // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX]] : i64
+    // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
+    // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
+
+    // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+    // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
+    // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX]], %[[S3]] : i64
+    // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX]] : i64
+    // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+    // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+    // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
+    // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+    gpu.device_async_copy
+        %src[%i, %i], %dst[%i, %i, %i], 16 : memref<128x128xf32> to memref<3x16x128xf32, 3>
+    // CHECK: nvvm.cp.async.commit.group
+    %0 = gpu.device_async_create_group
+    // CHECK: nvvm.cp.async.wait.group 1
+    gpu.device_async_wait_group 1
+    gpu.return
+  }
+}

Index: mlir/test/Dialect/GPU/invalid.mlir
===================================================================
--- mlir/test/Dialect/GPU/invalid.mlir
+++ mlir/test/Dialect/GPU/invalid.mlir
@@ -555,3 +555,35 @@
     %D = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<16x32xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
     return
 }
+
+// -----
+
+func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
+  // expected-error @+1 {{destination memref must have memory space 3}}
+  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32>
+  return
+}
+
+// -----
+
+func @async_cp_memref_type(%dst : memref<16xi32, 3>, %src : memref<16xf32>, %i : index) -> () {
+  // expected-error @+1 {{source and destination must have the same element type}}
+  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xi32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_src_indices(%dst : memref<16xf32, 3>, %src : memref<16x16xf32>, %i : index) -> () {
+  // expected-error @+1 {{expected 2 source indices, got 1}}
+  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16x16xf32> to memref<16xf32, 3>
+  return
+}
+
+// -----
+
+func @async_cp_num_dst_indices(%dst : memref<16x16xf32, 3>, %src : memref<16xf32>, %i : index) -> () {
+  // expected-error @+1 {{expected 2 destination indices, got 1}}
+  gpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16x16xf32, 3>
+  return
+}

Index: mlir/test/Dialect/GPU/ops.mlir
===================================================================
--- mlir/test/Dialect/GPU/ops.mlir
+++ mlir/test/Dialect/GPU/ops.mlir
@@ -240,4 +240,18 @@
     %3 = gpu.subgroup_mma_elementwise maxf %2, %1 : (!gpu.mma_matrix<16x16xf32, "COp">, !gpu.mma_matrix<16x16xf32, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
     return
   }
+
+  func @async_cp(%dst : memref<2x7x5xf32, 3>, %src : memref<4x5xf32>) {
+    // CHECK-LABEL: func @async_cp
+    %c0 = arith.constant 0 : index
+    // CHECK: gpu.device_async_copy %{{.*}}[{{.*}}, {{.*}}], %{{.*}}[{{.*}}, {{.*}}, {{.*}}], 16 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+    gpu.device_async_copy %src[%c0, %c0], %dst[%c0, %c0, %c0], 16 : memref<4x5xf32> to memref<2x7x5xf32, 3>
+    // CHECK: %{{.*}} = gpu.device_async_create_group
+    %token = gpu.device_async_create_group
+    // CHECK: gpu.device_async_wait %{{.*}}
+    gpu.device_async_wait %token
+    // CHECK: gpu.device_async_wait_group 1
+    gpu.device_async_wait_group 1
+    return
+  }
 }
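
For context, here is a minimal sketch (not part of the patch) of how the new ops are intended to compose, based on the op descriptions above: two copies are placed in separate groups so that computation on the first buffer can overlap with the still-pending second copy. The values `%a`, `%b`, `%buf0`, `%buf1`, `%c0` and the `memref<64xf32>` shapes are made up for illustration.

```mlir
// Issue a copy into workgroup memory and close group 0 over it.
gpu.device_async_copy %a[%c0], %buf0[%c0], 16 : memref<64xf32> to memref<64xf32, 3>
%g0 = gpu.device_async_create_group
// Issue a second copy and close group 1 over it.
gpu.device_async_copy %b[%c0], %buf1[%c0], 16 : memref<64xf32> to memref<64xf32, 3>
%g1 = gpu.device_async_create_group
// Block until at most one group is still pending. Because groups complete in
// creation order, group 0 is done at this point and %buf0 may be read, while
// the copy into %buf1 can still be in flight.
gpu.device_async_wait_group 1
// ... compute on %buf0 ...
// Block until group 1 (and hence the copy into %buf1) has completed.
gpu.device_async_wait %g1
```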