diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -330,16 +330,13 @@ if (failed(resolveSourceIndicesSubView(loadOp.getLoc(), rewriter, subViewOp, indices, sourceIndices))) return failure(); + llvm::TypeSwitch(loadOp) .Case([&](auto op) { rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), sourceIndices); }) .Case([&](vector::TransferReadOp transferReadOp) { - if (transferReadOp.getTransferRank() == 0) { - // TODO: Propagate the error. - return; - } rewriter.replaceOpWithNewOp( transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), sourceIndices, @@ -439,15 +436,13 @@ if (failed(resolveSourceIndicesSubView(storeOp.getLoc(), rewriter, subViewOp, indices, sourceIndices))) return failure(); + llvm::TypeSwitch(storeOp) .Case([&](auto op) { rewriter.replaceOpWithNewOp( storeOp, storeOp.getValue(), subViewOp.source(), sourceIndices); }) .Case([&](vector::TransferWriteOp op) { - // TODO: support 0-d corner case. - if (op.getTransferRank() == 0) - return; rewriter.replaceOpWithNewOp( op, op.getValue(), subViewOp.source(), sourceIndices, getPermutationMapAttr(rewriter.getContext(), subViewOp, diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir --- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir +++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir @@ -81,6 +81,28 @@ // ----- +func.func @fold_subview_with_transfer_read_0d( + %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) + -> vector { + %f1 = arith.constant 1.0 : f32 + %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref> + %1 = vector.transfer_read %0[], %f1 : memref>, vector + return %1 : vector +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> +// CHECK: func @fold_subview_with_transfer_read_0d +// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32> +// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]] +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]] +// CHECK: vector.transfer_read %[[MEM]][%[[I1]], %[[I2]]] + +// ----- + func.func @fold_subview_with_transfer_read(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) -> vector<4xf32> { %f1 = arith.constant 1.0 : f32 %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>> @@ -102,6 +124,29 @@ // ----- +func.func @fold_static_stride_subview_with_transfer_write_0d( + %arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, + %v : vector) { + %f1 = arith.constant 1.0 : f32 + %0 = memref.subview %arg0[%arg1, %arg2][1, 1][2, %arg3] : memref<12x32xf32> to memref> + vector.transfer_write %v, %0[] {in_bounds = []} : vector, memref> + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)> +// CHECK: func @fold_static_stride_subview_with_transfer_write_0d +// CHECK-SAME: %[[MEM:[a-zA-Z0-9_]+]]: memref<12x32xf32> +// CHECK-SAME: %[[SZ0:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[SZ1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ST1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[V:[a-zA-Z0-9_]+]]: vector +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[I1:.+]] = affine.apply #[[MAP0]](%[[C0]])[%[[SZ0]]] +// CHECK-DAG: %[[I2:.+]] = affine.apply #[[MAP1]](%[[C0]])[%[[ST1]], %[[SZ1]]] +// CHECK: vector.transfer_write %[[V]], %[[MEM]][%[[I1]], %[[I2]]] + +// ----- + func.func @fold_static_stride_subview_with_transfer_write(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5: index, %arg6 : index, %arg7 : vector<4xf32>) { %0 = memref.subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, strided<[?, ?], offset: ?>>