Index: mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
===================================================================
--- mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -10,7 +10,6 @@
 // loading/storing from/to the original memref.
 //
 //===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -19,6 +18,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/AffineMap.h"
@@ -26,6 +26,10 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "fold-memref-alias-ops"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
 
 namespace mlir {
 namespace memref {
@@ -293,6 +297,17 @@
     return success();
   }
 };
+
+/// Folds nvgpu.device_async_copy subviews into the copy itself. This pattern
+/// folds subviews on the src and dst memrefs of the copy.
+class NvgpuAsyncCopyOpSubViewOpFolder final
+    : public OpRewritePattern<nvgpu::DeviceAsyncCopyOp> {
+public:
+  using OpRewritePattern<nvgpu::DeviceAsyncCopyOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(nvgpu::DeviceAsyncCopyOp copyOp,
+                                PatternRewriter &rewriter) const override;
+};
 } // namespace
 
 static SmallVector<Value>
@@ -590,6 +605,60 @@
   return success();
 }
 
+LogicalResult NvgpuAsyncCopyOpSubViewOpFolder::matchAndRewrite(
+    nvgpu::DeviceAsyncCopyOp copyOp, PatternRewriter &rewriter) const {
+
+  LLVM_DEBUG(DBGS() << "copyOp : " << copyOp << "\n");
+
+  Location loc = copyOp.getLoc();
+  auto srcSubViewOp =
+      copyOp.getSrc().template getDefiningOp<memref::SubViewOp>();
+  auto dstSubViewOp =
+      copyOp.getDst().template getDefiningOp<memref::SubViewOp>();
+
+  if (!(srcSubViewOp || dstSubViewOp))
+    return rewriter.notifyMatchFailure(copyOp, "does not use subview ops for "
+                                               "source or destination");
+
+  // If the source is a subview, we need to resolve the indices.
+  SmallVector<Value> srcindices(copyOp.getSrcIndices().begin(),
+                                copyOp.getSrcIndices().end());
+  SmallVector<Value> foldedSrcIndices(srcindices);
+
+  if (srcSubViewOp) {
+    LLVM_DEBUG(DBGS() << "srcSubViewOp : " << srcSubViewOp << "\n");
+    resolveSourceIndicesOffsetsAndStrides(
+        rewriter, loc, srcSubViewOp.getMixedOffsets(),
+        srcSubViewOp.getMixedStrides(), srcSubViewOp.getDroppedDims(),
+        srcindices, foldedSrcIndices);
+  }
+
+  // If the destination is a subview, we need to resolve the indices.
+  SmallVector<Value> dstindices(copyOp.getDstIndices().begin(),
+                                copyOp.getDstIndices().end());
+  SmallVector<Value> foldedDstIndices(dstindices);
+
+  if (dstSubViewOp) {
+    LLVM_DEBUG(DBGS() << "dstSubViewOp : " << dstSubViewOp << "\n");
+    resolveSourceIndicesOffsetsAndStrides(
+        rewriter, loc, dstSubViewOp.getMixedOffsets(),
+        dstSubViewOp.getMixedStrides(), dstSubViewOp.getDroppedDims(),
+        dstindices, foldedDstIndices);
+  }
+
+  // Replace the copy op with a new copy op that uses the sources of the
+  // subviews together with the folded indices.
+  rewriter.replaceOpWithNewOp<nvgpu::DeviceAsyncCopyOp>(
+      copyOp, nvgpu::DeviceAsyncTokenType::get(copyOp.getContext()),
+      (dstSubViewOp ? dstSubViewOp.getSource() : copyOp.getDst()),
+      foldedDstIndices,
+      (srcSubViewOp ? srcSubViewOp.getSource() : copyOp.getSrc()),
+      foldedSrcIndices, copyOp.getDstElements(), copyOp.getSrcElements(),
+      copyOp.getBypassL1Attr());
+
+  return success();
+}
+
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewOpFolder<affine::AffineLoadOp>,
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
@@ -607,7 +676,8 @@
                LoadOpOfCollapseShapeOpFolder<memref::LoadOp>,
                StoreOpOfCollapseShapeOpFolder<affine::AffineStoreOp>,
                StoreOpOfCollapseShapeOpFolder<memref::StoreOp>,
-               SubViewOfSubViewFolder>(patterns.getContext());
+               SubViewOfSubViewFolder, NvgpuAsyncCopyOpSubViewOpFolder>(
+      patterns.getContext());
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
===================================================================
--- mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -541,3 +541,18 @@
   gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[64, 1], offset: ?>>
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_nvgpu_device_async_copy
+// CHECK-SAME: (%[[GMEM_MEMREF_3d:.+]]: memref<2x128x768xf16>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index, %[[IDX_3:.+]]: index)
+func.func @fold_nvgpu_device_async_copy(%gmem_memref_3d : memref<2x128x768xf16>, %idx_1 : index, %idx_2 : index, %idx_3 : index) {
+  // CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-DAG: %[[SMEM_MEMREF_4d:.+]] = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %smem_memref_4d = memref.alloc() : memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %gmem_memref_subview_2d = memref.subview %gmem_memref_3d[%idx_1, %idx_2, %idx_3] [1, 1, 8] [1, 1, 1] : memref<2x128x768xf16> to memref<1x8xf16, strided<[98304, 1], offset: ?>>
+  // CHECK: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[IDX_1]], %[[IDX_2]], %[[IDX_3]]], %[[SMEM_MEMREF_4d]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  %async_token = nvgpu.device_async_copy %gmem_memref_subview_2d[%c0, %c0], %smem_memref_4d[%c0, %c0, %c0, %c0], 8 {bypassL1} : memref<1x8xf16, strided<[98304, 1], offset: ?>> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+  return
+}
\ No newline at end of file