diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1835,9 +1835,8 @@
 /// Operates under a scoped context to build the intersection between the
 /// view `xferOp.source()` @ `xferOp.indices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
-static Value createSubViewIntersection(OpBuilder &b,
-                                       VectorTransferOpInterface xferOp,
-                                       Value alloc) {
+static std::pair<Value, Value> createSubViewIntersection(
+    OpBuilder &b, VectorTransferOpInterface xferOp, Value alloc) {
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   int64_t memrefRank = xferOp.getShapedType().getRank();
   // TODO: relax this precondition, will require rank-reducing subviews.
@@ -1864,11 +1863,15 @@
         sizes.push_back(affineMin);
       });
 
-  SmallVector<OpFoldResult> indices = llvm::to_vector<4>(llvm::map_range(
+  SmallVector<OpFoldResult> srcIndices = llvm::to_vector<4>(llvm::map_range(
       xferOp.indices(), [](Value idx) -> OpFoldResult { return idx; }));
-  return lb.create<memref::SubViewOp>(
-      isaWrite ? alloc : xferOp.source(), indices, sizes,
-      SmallVector<OpFoldResult>(memrefRank, OpBuilder(xferOp).getIndexAttr(1)));
+  SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
+  SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
+  auto copySrc = lb.create<memref::SubViewOp>(
+      isaWrite ? alloc : xferOp.source(), srcIndices, sizes, strides);
+  auto copyDest = lb.create<memref::SubViewOp>(
+      isaWrite ? xferOp.source() : alloc, destIndices, sizes, strides);
+  return std::make_pair(copySrc, copyDest);
 }
 
 /// Given an `xferOp` for which:
@@ -1877,14 +1880,15 @@
 /// Produce IR resembling:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      memref.cast %A: memref<A...> to compatibleMemRefType
+///      %view = memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view, ... : compatibleMemRefType, index, index
 ///    } else {
 ///      %2 = linalg.fill(%pad, %alloc)
 ///      %3 = subview %view [...][...][...]
-///      linalg.copy(%3, %alloc)
-///      memref.cast %alloc: memref<B...> to compatibleMemRefType
-///      scf.yield %4, ... : compatibleMemRefType, index, index
+///      %4 = subview %alloc [0, 0] [...] [...]
+///      linalg.copy(%3, %4)
+///      %5 = memref.cast %alloc: memref<B...> to compatibleMemRefType
+///      scf.yield %5, ... : compatibleMemRefType, index, index
 ///    }
 /// ```
 /// Return the produced scf::IfOp.
@@ -1910,9 +1914,9 @@
           b.create<linalg::FillOp>(loc, xferOp.padding(), alloc);
           // Take partial subview of memref which guarantees no dimension
           // overflows.
-          Value memRefSubView = createSubViewIntersection(
+          std::pair<Value, Value> copyArgs = createSubViewIntersection(
               b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-          b.create<linalg::CopyOp>(loc, memRefSubView, alloc);
+          b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
           Value casted =
               b.create<memref::CastOp>(loc, alloc, compatibleMemRefType);
           scf::ValueVector viewAndIndices{casted};
@@ -2030,7 +2034,8 @@
 ///    %notInBounds = xor %inBounds, %true
 ///    scf.if (%notInBounds) {
 ///      %3 = subview %alloc [...][...][...]
-///      linalg.copy(%3, %view)
+///      %4 = subview %view [0, 0][...][...]
+///      linalg.copy(%3, %4)
 ///    }
 /// ```
 static void createFullPartialLinalgCopy(OpBuilder &b,
@@ -2040,9 +2045,9 @@
   auto notInBounds =
       lb.create<XOrOp>(inBoundsCond, lb.create<ConstantIntOp>(true, 1));
   lb.create<scf::IfOp>(notInBounds, [&](OpBuilder &b, Location loc) {
-    Value memRefSubView = createSubViewIntersection(
+    std::pair<Value, Value> copyArgs = createSubViewIntersection(
         b, cast<VectorTransferOpInterface>(xferOp.getOperation()), alloc);
-    b.create<linalg::CopyOp>(loc, memRefSubView, xferOp.source());
+    b.create<linalg::CopyOp>(loc, copyArgs.first, copyArgs.second);
     b.create<scf::YieldOp>(loc, ValueRange{});
   });
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split.mlir
@@ -81,7 +81,8 @@
 // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
 // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // LINALG-SAME: memref<?x8xf32> to memref<?x?xf32, #[[$map_2d_stride_8x1]]>
-// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<4x8xf32>
+// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_8x1]]>, memref<?x?xf32>
 // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
 // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32>
 // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -172,7 +173,8 @@
 // LINALG: %[[sv1:.*]] = affine.min #[[$bounds_map_8]](%[[c8]], %[[j]], %[[c8]])
 // LINALG: %[[sv:.*]] = memref.subview %[[A]][%[[i]], %[[j]]] [%[[sv0]], %[[sv1]]] [1, 1]
 // LINALG-SAME: memref<7x8xf32, #[[$map_2d_stride_1]]> to memref<?x?xf32, #[[$map_2d_stride_1]]>
-// LINALG: linalg.copy(%[[sv]], %[[alloc]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<4x8xf32>
+// LINALG: %[[alloc_view:.*]] = memref.subview %[[alloc]][0, 0] [%[[sv0]], %[[sv1]]] [1, 1]
+// LINALG: linalg.copy(%[[sv]], %[[alloc_view]]) : memref<?x?xf32, #[[$map_2d_stride_1]]>, memref<?x?xf32>
 // LINALG: %[[yielded:.*]] = memref.cast %[[alloc]] :
 // LINALG-SAME: memref<4x8xf32> to memref<?x8xf32, #[[$map_2d_stride_1]]>
 // LINALG: scf.yield %[[yielded]], %[[c0]], %[[c0]] :
@@ -276,8 +278,9 @@
 // LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP1]]>
-// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x8xf32>
+// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>
 // LINALG: }
 // LINALG: return
 // LINALG: }
@@ -384,8 +387,9 @@
 // LINALG: %[[VAL_22:.*]] = memref.subview %[[TEMP]]
 // LINALG-SAME: [%[[I]], %[[J]]] [%[[VAL_20]], %[[VAL_21]]]
 // LINALG-SAME: [1, 1] : memref<4x8xf32> to memref<?x?xf32, #[[MAP1]]>
-// LINALG: linalg.copy(%[[VAL_22]], %[[DEST]])
-// LINALG-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<7x8xf32, #[[MAP0]]>
+// LINALG: %[[DEST_VIEW:.*]] = memref.subview %[[DEST]][0, 0] [%[[VAL_20]], %[[VAL_21]]] [1, 1]
+// LINALG: linalg.copy(%[[VAL_22]], %[[DEST_VIEW]])
+// LINALG-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP0]]>
 // LINALG: }
 // LINALG: return
 // LINALG: }
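Note on the fix: `createSubViewIntersection` now returns both copy operands, so the `linalg.copy` destination becomes a subview with the same dynamic sizes as the source rectangle, anchored at the origin of the staging buffer, instead of the whole buffer. A minimal sketch of the slow path for a read, assuming a `memref<?x8xf32>` source and a `memref<4x8xf32>` staging alloc; the SSA names `%src`, `%dst` and map names `#map_src`, `#map_dst` are illustrative, not taken from the test file:

```mlir
// Pad the staging buffer, then copy only the in-bounds rectangle.
linalg.fill(%pad, %alloc) : f32, memref<4x8xf32>
// Source rectangle at offset [%i, %j], clamped sizes [%sv0, %sv1].
%src = memref.subview %A[%i, %j] [%sv0, %sv1] [1, 1]
    : memref<?x8xf32> to memref<?x?xf32, #map_src>
// Destination rectangle of the same sizes, anchored at [0, 0] in %alloc.
%dst = memref.subview %alloc[0, 0] [%sv0, %sv1] [1, 1]
    : memref<4x8xf32> to memref<?x?xf32, #map_dst>
// The operand shapes now match; previously the destination was all of %alloc.
linalg.copy(%src, %dst) : memref<?x?xf32, #map_src>, memref<?x?xf32, #map_dst>
%cast = memref.cast %alloc : memref<4x8xf32> to memref<?x8xf32>
```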