# Review notes for this patch. Text placed before the first `diff --git`
# header is treated as commentary by `git apply` / `patch`.
#
# VectorTransforms.cpp hunk:
#   - Builds `inBounds` from the first `in_boundsAttr().size() - dimsToDrop`
#     entries of the original read's `in_bounds` attribute. That drops the
#     entries for the `dimsToDrop` trailing dimensions, matching the
#     `indices().drop_back(dimsToDrop)` applied to the indices.
#   - `inBounds` stays empty when the original read has no `in_bounds`.
#   - Passes `inBounds` to the new rank-reduced transfer_read. Before this
#     patch only the dropped-back indices were forwarded (see the `-` line),
#     so the in-bounds information was lost.
#
# Test hunk:
#   - Adds @contiguous_inner_most_dim_bounds: a vector<4x1xf32> read with
#     in_bounds = [true, true] from a strided 40x1 subview.
#   - CHECKs for two memref.subview ops followed by a vector<4xf32>
#     transfer_read that carries in_bounds = [true].
#
# NOTE(review): template arguments appear to have been stripped from this
# copy, along with most line breaks. Examples: `SmallVector offsets`,
# `SmallVector inBounds`, `.cast()`, `rewriter.create(`,
# `replaceOpWithNewOp(`. As shown, the patch will neither apply nor compile.
# The originals are presumably:
#   - `SmallVector<int64_t>` for `offsets`/`strides`
#   - `SmallVector<bool>` for `inBounds`
#   - `cast<BoolAttr>()`
#   - `create<memref::SubViewOp>`
#   - `create<vector::TransferReadOp>`
#   - `replaceOpWithNewOp<vector::ShapeCastOp>`
# TODO confirm against upstream before applying.
# NOTE(review): the first CHECK line of the new test is split mid-line (after
# `%[[SRC:.+]]: `). In a real patch the continuation would lack its `+` prefix
# and corrupt the hunk; the split is likely an extraction artifact — confirm.
# NOTE(review): `in_boundsAttr().size() - dimsToDrop` is unsigned arithmetic.
# It assumes dimsToDrop <= the read's rank, which is not visible here
# (confirm it is guaranteed by the enclosing pattern).
# NOTE(review): the loop could be expressed as mapping
# `readOp.in_boundsAttr().getValue().drop_back(dimsToDrop)` to bools.
# NOTE(review): the FileCheck captures `%[[II:.+]]` / `%[[J:.+]]` bind to the
# params `%i` / `%ii` and are never used afterwards. Consider clearer names
# or plain `{{.+}}` patterns.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3573,12 +3573,20 @@ auto loc = readOp.getLoc(); SmallVector offsets(srcType.getRank(), 0); SmallVector strides(srcType.getRank(), 1); + + SmallVector inBounds = {}; + if (readOp.in_bounds()) { + for (size_t i = 0; i < readOp.in_boundsAttr().size() - dimsToDrop; ++i) { + auto attr = readOp.in_boundsAttr()[i].cast(); + inBounds.push_back(attr.getValue()); + } + } Value rankedReducedView = rewriter.create( loc, resultMemrefType, readOp.source(), offsets, srcType.getShape(), strides); Value result = rewriter.create( loc, resultTargetVecType, rankedReducedView, - readOp.indices().drop_back(dimsToDrop)); + readOp.indices().drop_back(dimsToDrop), inBounds); rewriter.replaceOpWithNewOp(readOp, targetType, result); diff --git a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-collapse-inner-most-dims.mlir @@ -31,3 +31,21 @@ // CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]] // CHECK: %[[RESULT]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32> // CHECK: return %[[RESULT]] + + + +// ----- + +func @contiguous_inner_most_dim_bounds(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %0 = memref.subview %A[%i, 0] [40, 1] [1, 1] : memref<1000x1xf32> to memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>> + %1 = vector.transfer_read %0[%ii, %c0], %cst {in_bounds = [true, true]} : memref<40x1xf32, affine_map<(d0, d1)[s0] -> (d0 + s0 + d1)>>, vector<4x1xf32> + return %1 : vector<4x1xf32> +} +// CHECK: func @contiguous_inner_most_dim_bounds(%[[SRC:.+]]: 
memref<1000x1xf32>, %[[II:.+]]: index, %[[J:.+]]: index) -> vector<4x1xf32> +// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]] +// CHECK: %[[SRC_1:.+]] = memref.subview %[[SRC_0]] +// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_1]] +// CHECK-SAME: {in_bounds = [true]} +// CHECK-SAME: vector<4xf32>