diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -181,6 +181,24 @@
       StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
 }
 
+/// Casts the given memref to a compatible memref type. If the source memref has
+/// a different address space than the target type, a `memref.memory_space_cast`
+/// is first inserted, followed by a `memref.cast`.
+static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
+                                        MemRefType compatibleMemRefType) {
+  MemRefType sourceType = memref.getType().cast<MemRefType>();
+  Value res = memref;
+  if (sourceType.getMemorySpace() != compatibleMemRefType.getMemorySpace()) {
+    sourceType = MemRefType::get(
+        sourceType.getShape(), sourceType.getElementType(),
+        sourceType.getLayout(), compatibleMemRefType.getMemorySpace());
+    res = b.create<memref::MemorySpaceCastOp>(memref.getLoc(), sourceType, res);
+  }
+  if (sourceType == compatibleMemRefType)
+    return res;
+  return b.create<memref::CastOp>(memref.getLoc(), compatibleMemRefType, res);
+}
+
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
@@ -230,6 +248,7 @@
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
+///      (memref.memory_space_cast %A: memref<A...> to memref<B...>)
 ///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
@@ -252,9 +271,7 @@
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                               xferOp.getIndices().end());
@@ -271,7 +288,7 @@
             alloc);
         b.create<memref::CopyOp>(loc, copyArgs.first, copyArgs.second);
         Value casted =
-            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -285,6 +302,7 @@
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
+///      (memref.memory_space_cast %A: memref<A...> to memref<B...>)
 ///      memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
@@ -307,9 +325,7 @@
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
-        Value res = memref;
-        if (compatibleMemRefType != xferOp.getShapedType())
-          res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+        Value res = castToCompatibleMemRefType(b, memref, compatibleMemRefType);
         scf::ValueVector viewAndIndices{res};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getIndices().begin(),
                               xferOp.getIndices().end());
@@ -324,7 +340,7 @@
                 loc, MemRefType::get({}, vector.getType()), alloc));
         Value casted =
-            b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+            castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
         scf::ValueVector viewAndIndices{casted};
         viewAndIndices.insert(viewAndIndices.end(), xferOp.getTransferRank(),
                               zero);
@@ -358,9 +374,8 @@
       .create<scf::IfOp>(
           loc, inBoundsCond,
          [&](OpBuilder &b, Location loc) {
-            Value res = memref;
-            if (compatibleMemRefType != xferOp.getShapedType())
-              res = b.create<memref::CastOp>(loc, compatibleMemRefType, memref);
+            Value res =
+                castToCompatibleMemRefType(b, memref, compatibleMemRefType);
             scf::ValueVector viewAndIndices{res};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getIndices().begin(),
@@ -369,7 +384,7 @@
           },
           [&](OpBuilder &b, Location loc) {
             Value casted =
-                b.create<memref::CastOp>(loc, compatibleMemRefType, alloc);
+                castToCompatibleMemRefType(b, alloc, compatibleMemRefType);
             scf::ValueVector viewAndIndices{casted};
             viewAndIndices.insert(viewAndIndices.end(),
                                   xferOp.getTransferRank(), zero);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -101,6 +101,37 @@
   return %1 : vector<4x8xf32>
 }
 
+func.func @split_vector_transfer_read_mem_space(%A: memref<?x8xf32, 3>, %i: index, %j: index) -> vector<4x8xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.0 : f32
+
+  // CHECK: scf.if {{.*}} -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
+  //      inBounds with a different memory space
+  // CHECK:   %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+  // CHECK-SAME:     memref<?x8xf32, 3> to memref<?x8xf32>
+  // CHECK:   %[[cast:.*]] = memref.cast %[[space_cast]] :
+  // CHECK-SAME:     memref<?x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
+  // CHECK:   scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
+  // CHECK: } else {
+  //      slow path, fill tmp alloc and yield a memref_casted version of it
+  // CHECK:   %[[slow:.*]] = vector.transfer_read %[[A]][%[[i]], %[[j]]], %cst :
+  // CHECK-SAME:     memref<?x8xf32, 3>, vector<4x8xf32>
+  // CHECK:   %[[cast_alloc:.*]] = vector.type_cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<vector<4x8xf32>>
+  // CHECK:   store %[[slow]], %[[cast_alloc]][] : memref<vector<4x8xf32>>
+  // CHECK:   %[[yielded:.*]] = memref.cast %[[alloc]] :
+  // CHECK-SAME:     memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
+  // CHECK:   scf.yield %[[yielded]], %[[c0]], %[[c0]] :
+  // CHECK-SAME:     memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
+  // CHECK: }
+  // CHECK: %[[res:.*]] = vector.transfer_read %[[ifres]]#0[%[[ifres]]#1, %[[ifres]]#2], %cst
+  // CHECK-SAME:   {in_bounds = [true, true]} : memref<?x8xf32, strided<[?, 1], offset: ?>>, vector<4x8xf32>
+
+  %1 = vector.transfer_read %A[%i, %j], %f0 : memref<?x8xf32, 3>, vector<4x8xf32>
+
+  return %1: vector<4x8xf32>
+}
+
 transform.sequence failures(propagate) {
 ^bb1(%func_op: !transform.op<"func.func">):
   transform.apply_patterns to %func_op {
     transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
   } : !transform.op<"func.func">
 }
@@ -228,6 +259,40 @@
   } : !transform.op<"func.func">
 }
 
+// -----
+
+func.func @split_vector_transfer_write_mem_space(%V: vector<4x8xf32>, %A: memref<?x8xf32, 3>, %i: index, %j: index) {
+  vector.transfer_write %V, %A[%i, %j] :
+    vector<4x8xf32>, memref<?x8xf32, 3>
+  return
+}
+
+// CHECK: func @split_vector_transfer_write_mem_space(
+// CHECK:       scf.if {{.*}} -> (memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index) {
+// CHECK:         %[[space_cast:.*]] = memref.memory_space_cast %{{.*}} :
+// CHECK-SAME:        memref<?x8xf32, 3> to memref<?x8xf32>
+// CHECK:         %[[cast:.*]] = memref.cast %[[space_cast]] :
+// CHECK-SAME:        memref<?x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
+// CHECK:         scf.yield %[[cast]], {{.*}} : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
+// CHECK:       } else {
+// CHECK:         %[[VAL_15:.*]] = memref.cast %[[TEMP]]
+// CHECK-SAME:        : memref<4x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
+// CHECK:         scf.yield %[[VAL_15]], %[[C0]], %[[C0]]
+// CHECK-SAME:        : memref<?x8xf32, strided<[?, 1], offset: ?>>, index, index
+// CHECK:       }
+// CHECK:       vector.transfer_write %[[VEC]],
+// CHECK-SAME:      %[[IN_BOUND_DEST:.*]]#0[%[[IN_BOUND_DEST]]#1, %[[IN_BOUND_DEST]]#2]
+// CHECK-SAME:      {in_bounds = [true, true]} : vector<4x8xf32>, memref<?x8xf32, strided<[?, 1], offset: ?>>
+
+
+transform.sequence failures(propagate) {
+^bb1(%func_op: !transform.op<"func.func">):
+  transform.apply_patterns to %func_op {
+    transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer"
+  } : !transform.op<"func.func">
+}
+
+
 // -----
 
 func.func private @fake_side_effecting_fun(%0: vector<2x2xf32>) -> ()
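Note on the cast chain: when the source memref lives in a non-default memory space, `castToCompatibleMemRefType` emits a `memref.memory_space_cast` before the usual `memref.cast`. Below is a minimal standalone sketch of the fast-path cast sequence, assuming a source in memory space 3 and a default-space strided compatible type; the concrete types here mirror the tests above and are chosen for illustration, not copied from pass output.

```mlir
// Illustrative only: the memory space (3) and the strided target type are
// assumptions mirroring the tests in this patch.
func.func @fast_path_cast_sketch(%A: memref<?x8xf32, 3>)
    -> memref<?x8xf32, strided<[?, 1], offset: ?>> {
  // First hop: change only the memory space; shape, element type, and layout
  // stay the same.
  %0 = memref.memory_space_cast %A : memref<?x8xf32, 3> to memref<?x8xf32>
  // Second hop: relax the layout to the cast-compatible strided form used as
  // the scf.if result type.
  %1 = memref.cast %0
      : memref<?x8xf32> to memref<?x8xf32, strided<[?, 1], offset: ?>>
  return %1 : memref<?x8xf32, strided<[?, 1], offset: ?>>
}
```

The helper changes the memory space first while keeping the original layout (since `memref.memory_space_cast` only allows the space to differ), so a single trailing `memref.cast` can then bridge any remaining layout difference; when the spaces already match, only the `memref.cast` is emitted, and it is skipped entirely if the types are already equal.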