Fold memref.subview into nvgpu.ldmatrix in FoldMemRefAliasOps.

What this patch does:
  * Adds a getMemRefOperand overload for nvgpu::LdMatrixOp that returns
    op.getSrcMemref().
  * Adds a dispatch case that replaces an nvgpu::LdMatrixOp reading from a
    subview with a new nvgpu::LdMatrixOp reading from the subview's source at
    `sourceIndices`, forwarding the original transpose and numTiles values.
  * Registers LoadOpOfSubViewOpFolder for nvgpu::LdMatrixOp in
    populateFoldMemRefAliasOpPatterns.
  * Adds a lit test checking that the rewritten nvgpu.ldmatrix reads from the
    function argument at [%arg1, %arg2, %arg3].

NOTE(review): this patch text was recovered from a copy in which line breaks,
indentation, and angle-bracket argument lists had been stripped. The added
lines were rebuilt from what remained visible; the template arguments and
indentation of the unchanged context lines were reconstructed and should be
confirmed against the target tree. If context whitespace differs, apply with
`git apply --ignore-whitespace`.

diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -169,6 +169,10 @@
   return op.getSource();
 }
 
+static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
+  return op.getSrcMemref();
+}
+
 static Value getMemRefOperand(vector::TransferWriteOp op) {
   return op.getSource();
 }
@@ -406,6 +410,11 @@
             op, op.getType(), subViewOp.getSource(), sourceIndices,
             op.getLeadDimension(), op.getTransposeAttr());
       })
+      .Case([&](nvgpu::LdMatrixOp op) {
+        rewriter.replaceOpWithNewOp<nvgpu::LdMatrixOp>(
+            op, op.getType(), subViewOp.getSource(), sourceIndices,
+            op.getTranspose(), op.getNumTiles());
+      })
       .Default([](Operation *) { llvm_unreachable("unexpected operation."); });
   return success();
 }
@@ -658,6 +667,7 @@
 void memref::populateFoldMemRefAliasOpPatterns(RewritePatternSet &patterns) {
   patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>,
                LoadOpOfSubViewOpFolder<memref::LoadOp>,
+               LoadOpOfSubViewOpFolder<nvgpu::LdMatrixOp>,
                LoadOpOfSubViewOpFolder<vector::TransferReadOp>,
                LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>,
                StoreOpOfSubViewOpFolder<AffineStoreOp>,
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -599,3 +599,25 @@
 // CHECK-DAG: %[[RESOLVED_DST_IDX_1:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_1]], %[[DEST_SUB_IDX_0]]]
 // CHECK-DAG: %[[RESOLVED_DST_IDX_3:.+]] = affine.apply #[[MAP]]()[%[[DEST_IDX_3]], %[[DEST_SUB_IDX_1]]]
 // CHECK-DAG: nvgpu.device_async_copy %[[GMEM_MEMREF_3d]][%[[RESOLVED_SRC_IDX_0]], %[[SRC_IDX_1]], %[[RESOLVED_SRC_IDX_1]]], %[[SMEM_MEMREF_4d]][%[[DEST_IDX_0]], %[[RESOLVED_DST_IDX_1]], %[[DEST_IDX_2]], %[[RESOLVED_DST_IDX_3]]], 8 {bypassL1} : memref<2x128x768xf16> to memref<5x1x64x64xf16, #gpu.address_space<workgroup>>
+
+// -----
+
+#map = affine_map<()[s0] -> (-s0 + 4)>
+#map1 = affine_map<()[s0] -> (-s0 + 32)>
+
+func.func @test_ldmatrix(%arg0: memref<4x32x32xf16, 3>, %arg1: index, %arg2: index, %arg3: index) -> vector<4x2xf16> {
+  %c0 = arith.constant 0 : index
+  %0 = affine.apply #map()[%arg1]
+  %1 = affine.apply #map1()[%arg2]
+  %2 = affine.apply #map1()[%arg3]
+  %subview = memref.subview %arg0[%arg1, %arg2, %arg3] [%0, %1, %2] [1, 1, 1] : memref<4x32x32xf16, 3> to memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3>
+  %3 = nvgpu.ldmatrix %subview[%c0, %c0, %c0] {numTiles = 4 : i32, transpose = false} : memref<?x?x?xf16, strided<[1024, 32, 1], offset: ?>, 3> -> vector<4x2xf16>
+  return %3 : vector<4x2xf16>
+}
+
+//      CHECK: func @test_ldmatrix
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: memref<4x32x32xf16, 3>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME:   %[[ARG3:[a-zA-Z0-9_]+]]: index
+//      CHECK:   nvgpu.ldmatrix %[[ARG0]][%[[ARG1]], %[[ARG2]], %[[ARG3]]] {numTiles = 4 : i32, transpose = false} : memref<4x32x32xf16, 3> -> vector<4x2xf16>