Review notes for the following hunks (text before the first "diff --git" header is ignored by patch tools):
  * mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt: adds MLIRGPUOps to the library list next to MLIRFuncDialect / MLIRInferTypeOpInterface.
  * FoldMemRefAliasOps.cpp: includes mlir/Dialect/GPU/IR/GPUDialect.h and adds two getMemRefOperand overloads:
    gpu::SubgroupMmaLoadMatrixOp returns its source memref (getSrcMemref), gpu::SubgroupMmaStoreMatrixOp returns its destination memref (getDstMemref).
  * NOTE(review): the new #include is inserted above what appears to be the file's main header (mlir/Dialect/MemRef/Transforms/Passes.h).
    LLVM include style keeps the main header first and sorts the remaining includes; consider moving it into the sorted mlir/Dialect/... group below.
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt @@ -23,6 +23,7 @@ MLIRArithTransforms MLIRBufferizationDialect MLIRFuncDialect + MLIRGPUOps MLIRInferTypeOpInterface MLIRLoopLikeInterface MLIRMemRefDialect diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -210,6 +211,14 @@ return op.getSource(); } +static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) { + return op.getSrcMemref(); +} + +static Value getMemRefOperand(gpu::SubgroupMmaStoreMatrixOp op) { + return op.getDstMemref(); +} + /// Given the permutation map of the original /// `vector.transfer_read`/`vector.transfer_write` operations compute the /// permutation map to use after the subview is folded with it. 
@@ -401,6 +410,13 @@ transferReadOp.getPadding(), /*mask=*/Value(), transferReadOp.getInBoundsAttr()); }) + .Case([&](gpu::SubgroupMmaLoadMatrixOp op) { + // Re-create the op on the subview's source memref at sourceIndices; + // leadDimension and the transpose attribute are passed through as-is. + rewriter.replaceOpWithNewOp<gpu::SubgroupMmaLoadMatrixOp>( + op, op.getType(), subViewOp.getSource(), sourceIndices, + op.getLeadDimension(), op.getTransposeAttr()); + }) .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); return success(); } @@ -496,11 +512,11 @@ llvm::TypeSwitch<Operation *, void>(storeOp) .Case([&](AffineStoreOp op) { rewriter.replaceOpWithNewOp<AffineStoreOp>( - storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices); + op, op.getValue(), subViewOp.getSource(), sourceIndices); }) .Case([&](memref::StoreOp op) { rewriter.replaceOpWithNewOp<memref::StoreOp>( - storeOp, storeOp.getValue(), subViewOp.getSource(), sourceIndices, + op, op.getValue(), subViewOp.getSource(), sourceIndices, op.getNontemporal()); }) .Case([&](vector::TransferWriteOp op) { @@ -510,6 +526,13 @@ op.getPermutationMap()), op.getInBoundsAttr()); }) + .Case([&](gpu::SubgroupMmaStoreMatrixOp op) { + // Re-create the op on the subview's source memref at sourceIndices; the + // stored matrix, leadDimension and transpose attr pass through as-is. + rewriter.replaceOpWithNewOp<gpu::SubgroupMmaStoreMatrixOp>( + op, op.getSrc(), subViewOp.getSource(), sourceIndices, + op.getLeadDimension(), op.getTransposeAttr()); + }) .Default([](Operation *) { llvm_unreachable("unexpected operation."); }); return success(); } @@ -584,9 +607,11 @@ patterns.add<LoadOpOfSubViewOpFolder<AffineLoadOp>, LoadOpOfSubViewOpFolder<memref::LoadOp>, LoadOpOfSubViewOpFolder<vector::TransferReadOp>, + LoadOpOfSubViewOpFolder<gpu::SubgroupMmaLoadMatrixOp>, StoreOpOfSubViewOpFolder<AffineStoreOp>, StoreOpOfSubViewOpFolder<memref::StoreOp>, StoreOpOfSubViewOpFolder<vector::TransferWriteOp>, + StoreOpOfSubViewOpFolder<gpu::SubgroupMmaStoreMatrixOp>, LoadOpOfExpandShapeOpFolder<AffineLoadOp>, LoadOpOfExpandShapeOpFolder<memref::LoadOp>, StoreOpOfExpandShapeOpFolder<AffineStoreOp>, diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -524,3 +524,56 @@ memref.store %arg5, %0[%arg3, %arg4] {nontemporal=true}: memref<4x4xf32, strided<[64, 3], offset: ?>> return } + +// ----- + +func.func @fold_gpu_subgroup_mma_load_matrix_1d(%src: memref<?xvector<4xf32>>, %offset: 
index, %i: index) -> !gpu.mma_matrix<16x16xf16, "COp"> { + %subview = memref.subview %src[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>> + %matrix = gpu.subgroup_mma_load_matrix %subview[%i] {leadDimension = 160 : index} : memref<81920xvector<4xf32>, strided<[1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp"> + return %matrix: !gpu.mma_matrix<16x16xf16, "COp"> +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func.func @fold_gpu_subgroup_mma_load_matrix_1d +// CHECK-SAME: (%[[SRC:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I:.+]]: index) +// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I]])[%[[OFFSET]]] +// CHECK: %[[LOAD:.+]] = gpu.subgroup_mma_load_matrix %[[SRC]][%[[APPLY]]] {leadDimension = 160 : index} : memref<?xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp"> +// CHECK: return %[[LOAD]] + +// ----- + +func.func @fold_gpu_subgroup_mma_store_matrix_1d(%dst: memref<?xvector<4xf32>>, %offset: index, %i: index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) { + %subview = memref.subview %dst[%offset] [81920] [1] : memref<?xvector<4xf32>> to memref<81920xvector<4xf32>, strided<[1], offset: ?>> + gpu.subgroup_mma_store_matrix %matrix, %subview[%i] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<81920xvector<4xf32>, strided<[1], offset: ?>> + return +} + +// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func.func @fold_gpu_subgroup_mma_store_matrix_1d +// CHECK-SAME: (%[[DST:.+]]: memref<?xvector<4xf32>>, %[[OFFSET:.+]]: index, %[[I0:.+]]: index, %[[VAL:.+]]: !gpu.mma_matrix<16x16xf16, "COp">) +// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[I0]])[%[[OFFSET]]] +// CHECK: gpu.subgroup_mma_store_matrix %[[VAL]], %[[DST]][%[[APPLY]]] {leadDimension = 160 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<?xvector<4xf32>> + +// ----- + +// 2-D subview with non-unit strides [2, 1] over memref<128x128xf32> (strides [128, 1]) has result strides [256, 1]; the load should read the original memref with leadDimension unchanged. +// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_load_matrix_2d +// CHECK-SAME: %[[SRC:.+]]: memref<128x128xf32> +func.func @fold_gpu_subgroup_mma_load_matrix_2d(%arg0 : memref<128x128xf32>, 
%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) -> !gpu.mma_matrix<16x16xf16, "COp"> { + %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>> + // CHECK: gpu.subgroup_mma_load_matrix %[[SRC]][{{.+}}] {leadDimension = 32 : index} : memref<128x128xf32> -> !gpu.mma_matrix<16x16xf16, "COp"> + %matrix = gpu.subgroup_mma_load_matrix %subview[%arg3, %arg4] {leadDimension = 32 : index} : memref<64x32xf32, strided<[256, 1], offset: ?>> -> !gpu.mma_matrix<16x16xf16, "COp"> + return %matrix : !gpu.mma_matrix<16x16xf16, "COp"> +} + +// ----- + +// Same strided 2-D subview, folded into a store that writes the original memref with leadDimension unchanged. +// CHECK-LABEL: func.func @fold_gpu_subgroup_mma_store_matrix_2d +// CHECK-SAME: %[[DST:.+]]: memref<128x128xf32> +func.func @fold_gpu_subgroup_mma_store_matrix_2d(%arg0 : memref<128x128xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %matrix: !gpu.mma_matrix<16x16xf16, "COp">) { + %subview = memref.subview %arg0[%arg1, %arg2][64, 32][2, 1] : memref<128x128xf32> to memref<64x32xf32, strided<[256, 1], offset: ?>> + // CHECK: gpu.subgroup_mma_store_matrix %{{.+}}, %[[DST]][{{.+}}] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<128x128xf32> + gpu.subgroup_mma_store_matrix %matrix, %subview[%arg3, %arg4] {leadDimension = 32 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<64x32xf32, strided<[256, 1], offset: ?>> + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9931,6 +9931,7 @@ ":ControlFlowDialect", ":DialectUtils", ":FuncDialect", + ":GPUDialect", ":IR", ":InferTypeOpInterface", ":LoopLikeInterface",