diff --git a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td --- a/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td +++ b/mlir/include/mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td @@ -15,6 +15,47 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" +//===----------------------------------------------------------------------===// +// CreateAsyncGroupsOp +//===----------------------------------------------------------------------===// + +def CreateAsyncGroupsOp : + Op, + TransformEachOpTrait, + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Look for global to shared memory copies within the targeted op in the form + of vector transfer ops and convert them to async copies when possible. + Consecutive copies are put into the same group. A "wait" operation is put + right at the of the group. + + `use_mma_sync` specifies whether `bypassL1` attributes should be added to + the async copies. + + #### Return modes + + This op reads the `target` handle and modifies the payload. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + UnitAttr:$use_mma_sync); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + //===----------------------------------------------------------------------===// // RewriteMatmulAsMmaSyncOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h @@ -17,6 +17,8 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { +class RewriterBase; + namespace nvgpu { /// @@ -68,6 +70,12 @@ RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32); +/// Convert global->shared vector transfers to async device copies. This +/// function looks for suitable vector transfers within the specified op and +/// converts them to "nvgpu.device_async_copy" ops. Consecutive copies are put +/// into the same sync group. +void createAsyncGroups(RewriterBase &rewriter, Operation *op, bool useMMASync); + } // namespace nvgpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h --- a/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h +++ b/mlir/include/mlir/Dialect/NVGPU/Transforms/Utils.h @@ -17,5 +17,12 @@ /// Set the indices that the given load/store operation is operating on. void setIndices(Operation *op, ArrayRef indices); +/// Get the value that is stored by the given store operation. +Value getValueStored(Operation *op); + +/// Get the memref that is loaded from/stored into by the given load/store +/// operation. +Value getMemrefOperand(Operation *op); + } // namespace nvgpu } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2311,6 +2311,13 @@ OpBuilder<(ins "VectorType":$type, "ArrayRef":$mixedOperands)> ]; + let extraClassDeclaration = [{ + /// Return the result type of this op. + VectorType getVectorType() { + return cast(getOperation()->getResultTypes()[0]); + } + }]; + let hasCanonicalizer = 1; let hasVerifier = 1; let assemblyFormat = "$operands attr-dict `:` type(results)"; diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/NVGPU/TransformOps/CMakeLists.txt @@ -13,6 +13,7 @@ MLIRIR MLIRLinalgDialect MLIRNVGPUDialect + MLIRNVGPUTransforms MLIRParser MLIRSideEffectInterfaces MLIRTransformDialect diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" @@ -39,6 +40,23 @@ #define DBGSNL() (llvm::dbgs() << "\n") #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") +//===---------------------------------------------------------------------===// +// CreateAsyncGroupsOp +//===---------------------------------------------------------------------===// + +void transform::CreateAsyncGroupsOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + +DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( + TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { + nvgpu::createAsyncGroups(rewriter, target, getUseMmaSync()); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // RewriteMatmulAsMmaSyncOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_dialect_library(MLIRNVGPUTransforms + CreateAsyncGroups.cpp OptimizeSharedMemory.cpp MmaSyncTF32Transform.cpp Utils.cpp diff --git a/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/NVGPU/Transforms/CreateAsyncGroups.cpp @@ -0,0 +1,214 @@ +//===- CreateAsyncGroups.cpp - Create async device copies -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; + +/// Return "true" if the given op is a contiguous vector.transfer_write or +/// vector.store op. +static bool isContiguousStore(Operation *write) { + if (auto transferWrite = dyn_cast(write)) { + return transferWrite.getPermutationMap().isMinorIdentity() && + transferWrite.isDimInBounds(0) && !transferWrite.getMask() && + transferWrite.hasBufferSemantics() && + isLastMemrefDimUnitStride(cast( + nvgpu::getMemrefOperand(transferWrite).getType())); + } + // vector.store are always contiguous. + return isa(write); +} + +/// Return "true" if the given op is a contiguous vector.transfer_read or +/// vector.load op. +static bool isContiguousRead(Operation *read) { + if (auto transferRead = dyn_cast(read)) { + return transferRead.isDimInBounds(0) && + transferRead.getPermutationMap().isMinorIdentity() && + transferRead.hasBufferSemantics() && + isLastMemrefDimUnitStride(cast( + nvgpu::getMemrefOperand(transferRead).getType())); + } + // vector.load are always contiguous. + return isa(read); +} + +/// If the given vector load op has a mask that is defined by +/// vector.create_mask, return that op. +static vector::CreateMaskOp getMaskOp(Operation *loadOp) { + // TODO: Support 2D masks and higher. + auto transferRead = dyn_cast(loadOp); + if (!transferRead || !transferRead.getMask()) + return {}; + return transferRead.getMask().getDefiningOp(); +} + +/// Return "true" if the conversion to async copy is legal. +static bool resultsInSupportedAsyncCopy(MemRefType memrefType, + Operation::operand_range indices, + VectorType vecType) { + assert(vecType.getRank() == 1 && "expected 1-D vector"); + constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16}; + + // Condition 1: the copy size must be supported. + bool supportedCopySize = false; + int64_t numElements = vecType.getNumElements(); + Type elementType = vecType.getElementType(); + for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) { + if (alignmentInBytes * 8 == + numElements * elementType.getIntOrFloatBitWidth()) { + supportedCopySize = true; + break; + } + } + if (!supportedCopySize) + return false; + + // TODO: Condition 2: the alignments must be supported. For cp.async the + // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to + // a multiple of the access size. If an address is not properly aligned, the + // resulting behavior is undefined.". + return true; +} + +void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op, + bool useMMASync) { + llvm::SmallSetVector copyToSharedMem; + + // Look for all the copy that can be converted to async copy ops. + op->walk([&](Operation *writeOp) { + // Look for contiguous 1D vector store into shared memory. + if (!isContiguousStore(writeOp)) + return WalkResult::advance(); + Value vectorVal = nvgpu::getValueStored(writeOp); + if (llvm::cast(vectorVal.getType()).getRank() != 1) + return WalkResult::advance(); + Value storeBase = nvgpu::getMemrefOperand(writeOp); + if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( + llvm::cast(storeBase.getType()))) + return WalkResult::advance(); + + // The stored vector must originate from a contiguous 1D vector load. + Operation *readOp = vectorVal.getDefiningOp(); + if (readOp == nullptr || !isContiguousRead(readOp)) + return WalkResult::advance(); + Value loadBase = nvgpu::getMemrefOperand(readOp); + // Should be reading from global memory (not shared memory). + if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( + llvm::cast(loadBase.getType()))) + return WalkResult::advance(); + + // Look for compatible mask and padding. + if (auto transferRead = dyn_cast(readOp)) { + if (Value mask = transferRead.getMask()) { + if (getConstantIntValue(transferRead.getPadding()) == + static_cast(0)) + return WalkResult::advance(); + if (!getMaskOp(readOp)) + return WalkResult::advance(); + } + } + + // Check whether both accesses are supported before we emit: this is + // necessary to ensure the correctness of DeviceAsyncCopyOp. + VectorType vecType = llvm::cast(vectorVal.getType()); + + if (!resultsInSupportedAsyncCopy(cast(loadBase.getType()), + nvgpu::getIndices(readOp), vecType) || + !resultsInSupportedAsyncCopy(cast(storeBase.getType()), + nvgpu::getIndices(writeOp), vecType)) + return WalkResult::advance(); + + copyToSharedMem.insert(writeOp); + return WalkResult::advance(); + }); + + while (!copyToSharedMem.empty()) { + // Start a group with the first write. + SmallVector group; + Operation *writeOp = *copyToSharedMem.begin(); + copyToSharedMem.remove(writeOp); + group.push_back(writeOp); + Operation *nextNode = writeOp; + + // Look in the next nodes for more copies to add to the same group. + while ((nextNode = nextNode->getNextNode())) { + // Ignore ops without side effects. + auto memInterface = dyn_cast(nextNode); + if (memInterface && memInterface.hasNoEffect() && + !nextNode->hasTrait()) + continue; + // Ignore read from a different address space. + if (isa(nextNode)) { + Operation *readOp = nextNode; + Value memrefOperand = nvgpu::getMemrefOperand(readOp); + if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace( + llvm::cast(memrefOperand.getType()))) { + continue; + } + } + if (copyToSharedMem.count(nextNode)) { + // Found another copy, add it to the group. + copyToSharedMem.remove(nextNode); + group.push_back(nextNode); + continue; + } + // If the op is something else stop the accumulating op in the group. + break; + } + + // Emit the group. + SmallVector tokens; + for (Operation *writeOp : group) { + rewriter.setInsertionPoint(writeOp); + Value vectorVal = nvgpu::getValueStored(writeOp); + auto vectorType = llvm::cast(vectorVal.getType()); + int64_t numElements = vectorType.getNumElements(); + Operation *readOp = vectorVal.getDefiningOp(); + Value storeBase = nvgpu::getMemrefOperand(writeOp); + Value loadBase = nvgpu::getMemrefOperand(readOp); + Value numReadElements; + if (vector::CreateMaskOp maskOp = getMaskOp(readOp)) { + assert(maskOp.getNumOperands() == 1 && "expected single operand"); + numReadElements = maskOp.getOperand(0); + } + auto dstMemref = llvm::cast(storeBase.getType()); + int64_t sizeInBytes = + (dstMemref.getElementTypeBitWidth() * numElements) / 8; + UnitAttr bypassL1 = + useMMASync && sizeInBytes == 16 ? rewriter.getUnitAttr() : UnitAttr(); + Value token = rewriter.create( + writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), + /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp), + /*src=*/loadBase, + /*srcIndices=*/nvgpu::getIndices(readOp), + /*dstElements=*/rewriter.getIndexAttr(numElements), + /*srcElements=*/numReadElements, + /*bypassL1=*/bypassL1); + tokens.push_back(token); + } + + // Create the group and wait for it right after. + Value groupToken = rewriter.create( + op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()), + tokens); + rewriter.create(op->getLoc(), groupToken, + nullptr); + // Clean up old stores. + for (Operation *writeOp : group) + rewriter.eraseOp(writeOp); + } +} diff --git a/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp --- a/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/Utils.cpp @@ -28,6 +28,10 @@ return vectorReadOp.getIndices(); if (auto vectorStoreOp = dyn_cast(op)) return vectorStoreOp.getIndices(); + if (auto transferReadOp = dyn_cast(op)) + return transferReadOp.getIndices(); + if (auto transferWriteOp = dyn_cast(op)) + return transferWriteOp.getIndices(); llvm_unreachable("unsupported op type"); } @@ -44,5 +48,35 @@ return vectorReadOp.getIndicesMutable().assign(indices); if (auto vectorStoreOp = dyn_cast(op)) return vectorStoreOp.getIndicesMutable().assign(indices); + if (auto transferReadOp = dyn_cast(op)) + return transferReadOp.getIndicesMutable().assign(indices); + if (auto transferWriteOp = dyn_cast(op)) + return transferWriteOp.getIndicesMutable().assign(indices); + llvm_unreachable("unsupported op type"); +} + +Value nvgpu::getValueStored(Operation *op) { + if (auto storeOp = dyn_cast(op)) + return storeOp.getValueToStore(); + if (auto transferWrite = dyn_cast(op)) + return transferWrite.getValue(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getValueToStore(); + llvm_unreachable("unsupported op type"); +} + +Value nvgpu::getMemrefOperand(Operation *op) { + if (auto loadOp = dyn_cast(op)) + return loadOp.getMemref(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getMemref(); + if (auto transferWrite = dyn_cast(op)) + return transferWrite.getSource(); + if (auto transferRead = dyn_cast(op)) + return transferRead.getSource(); + if (auto storeOp = dyn_cast(op)) + return storeOp.getBase(); + if (auto loadOp = dyn_cast(op)) + return loadOp.getBase(); llvm_unreachable("unsupported op type"); } diff --git a/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/NVGPU/transform-create-async-groups.mlir @@ -0,0 +1,153 @@ +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file --verify-diagnostics | FileCheck %s + +// Check that we produce async copies from the vector.transfer_xxx operations. +builtin.module { + // CHECK-LABEL: @copies_to_asyncs + func.func @copies_to_asyncs(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + // Make sure we emit the bypassL1. + // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4 {bypassL1} : + %1 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space> + // CHECK-NOT: nvgpu.device_async_create_group + + // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1 + %2 = vector.transfer_read %a[%c0, %c4], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<1xf32> + vector.transfer_write %2, %0[%c0, %c4, %c0] {in_bounds = [true]} : vector<1xf32>, memref<4x32x16xf32, #gpu.address_space> + // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]] + // CHECK: nvgpu.device_async_wait %[[G]] + return + } + + transform.sequence failures(propagate) { + ^bb1(%variant_op: !transform.any_op): + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.nvgpu.create_async_groups %top_level_func {use_mma_sync} : (!transform.any_op) -> () + } +} + +// ----- + +// Check that we properly take `use_mma_sync = false` into account. +// I.e., we shouldn't be generating bypassL1 attributes. +builtin.module { + // CHECK-LABEL: @copies_to_asyncs_no_mma + func.func @copies_to_asyncs_no_mma(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + // Make sure we don't emit the bypassL1. + // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4 : + %1 = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space> + // CHECK-NOT: nvgpu.device_async_create_group + + // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1 : + %2 = vector.transfer_read %a[%c0, %c4], %cst_0 {in_bounds = [true]} : memref<1024x1024xf32>, vector<1xf32> + vector.transfer_write %2, %0[%c0, %c4, %c0] {in_bounds = [true]} : vector<1xf32>, memref<4x32x16xf32, #gpu.address_space> + // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]] + // CHECK: nvgpu.device_async_wait %[[G]] + return + } + + transform.sequence failures(propagate) { + ^bb1(%variant_op: !transform.any_op): + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> () + } +} + +// ----- + +// Check that pattern works with vector.load/vector.store. +builtin.module { + // CHECK-LABEL: @copies_to_asyncs_load_store + func.func @copies_to_asyncs_load_store(%a: memref<1024x1024xf32>) { + %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + // CHECK: %[[CP0:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 4 : + %1 = vector.load %a[%c0, %c0] : memref<1024x1024xf32>, vector<4xf32> + vector.store %1, %0[%c0, %c0, %c0] : memref<4x32x16xf32, #gpu.address_space>, vector<4xf32> + // CHECK-NOT: nvgpu.device_async_create_group + + // CHECK: %[[CP1:.*]] = nvgpu.device_async_copy {{.*}}, {{.*}}, 1 : + %2 = vector.load %a[%c0, %c4] : memref<1024x1024xf32>, vector<1xf32> + vector.store %2, %0[%c0, %c4, %c0] : memref<4x32x16xf32, #gpu.address_space>, vector<1xf32> + // CHECK: %[[G:.*]] = nvgpu.device_async_create_group %[[CP0]], %[[CP1]] + // CHECK: nvgpu.device_async_wait %[[G]] + return + } + + transform.sequence failures(propagate) { + ^bb1(%variant_op: !transform.any_op): + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> () + } +} + +// ----- + +// Check that pattern skips unaligned and unsupported sizes. +builtin.module { + // CHECK-LABEL: @copies_to_asyncs_load_store + func.func @copies_to_asyncs_load_store(%a: memref<1024x1024xf32>, %b: memref<1024x1024xf16>) { + %alloc = memref.alloc() : memref<4x32x16xf32, #gpu.address_space> + %alloc_1 = memref.alloc() : memref<4x32x16xf16, #gpu.address_space> + %c0 = arith.constant 0 : index + %c4 = arith.constant 4 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + + // Requires 1-D vector load + // CHECK-NOT: nvgpu.device_async_copy + // CHECK: vector.load + // CHECK: vector.store + %1 = vector.load %a[%c0, %c4] : memref<1024x1024xf32>, vector<2x2xf32> + vector.store %1, %alloc[%c0, %c4, %c0] : memref<4x32x16xf32, #gpu.address_space>, vector<2x2xf32> + // CHECK-NOT: nvgpu.device_async_create_group + + // CHECK-NOT: nvgpu.device_async_copy + // CHECK: vector.load + // CHECK: vector.store + %2 = vector.load %b[%c0, %c4] : memref<1024x1024xf16>, vector<1xf16> + vector.store %2, %alloc_1[%c0, %c4, %c0] : memref<4x32x16xf16, #gpu.address_space>, vector<1xf16> + // CHECK-NOT: nvgpu.device_async_create_group + return + } + + transform.sequence failures(propagate) { + ^bb1(%variant_op: !transform.any_op): + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.nvgpu.create_async_groups %top_level_func : (!transform.any_op) -> () + } +} + +// ----- + +// vector.transfer_read with a mask. +builtin.module { + // CHECK-LABEL: @read_with_mask( + // CHECK-SAME: %{{.*}}: memref<1024x1024xf32>, %[[sz:.*]]: index + func.func @read_with_mask(%a: memref<1024x1024xf32>, %sz: index) { + %0 = memref.alloc() : memref<4x32x16xf32, #gpu.address_space> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant 0.000000e+00 : f32 + // CHECK: nvgpu.device_async_copy {{.*}}, {{.*}}, 4, %[[sz]] {bypassL1} : + %mask = vector.create_mask %sz : vector<4xi1> + %1 = vector.transfer_read %a[%c0, %c0], %cst_0, %mask {in_bounds = [true]} : memref<1024x1024xf32>, vector<4xf32> + vector.transfer_write %1, %0[%c0, %c0, %c0] {in_bounds = [true]} : vector<4xf32>, memref<4x32x16xf32, #gpu.address_space> + + return + } + + transform.sequence failures(propagate) { + ^bb1(%variant_op: !transform.any_op): + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.nvgpu.create_async_groups %top_level_func {use_mma_sync} : (!transform.any_op) -> () + } +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2795,6 +2795,7 @@ ":MemRefDialect", ":NVGPUDialect", ":NVGPUTransformOpsIncGen", + ":NVGPUTransforms", ":Support", ":TransformDialect", ":VectorDialect",