diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -229,7 +229,7 @@
                                                  mlir::Location loc,
                                                  Value input) {
   MemRefType inputType = input.getType().cast<MemRefType>();
-  assert(inputType.hasStaticShape());
+  if (!inputType.hasStaticShape()) return input;
   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
   ArrayRef<int64_t> subViewSizes = inputType.getShape();
@@ -358,6 +358,40 @@
       loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
 }
 
+/// Returns how many elements are statically known to be contiguous in the
+/// layout of this `memrefType`, in the dimensions for which `indices` are
+/// constant 0.
+///
+/// For example, for `memref<2x3x4xf32>` and `indices=[%x, %c0, %c0]` where
+/// `%c0 = arith.constant 0 : index` and %x is anything else, the return value
+/// should be 12 (= 3*4).
+///
+/// Returns 0 in error cases.
+static int64_t
+getContiguousStaticRowMajorElementsAtDimsWithZeroIndex(MemRefType memrefType,
+                                                       ValueRange indices) {
+  auto shape = memrefType.getShape();
+  if (shape.size() != indices.size())
+    return 0;
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    return 0;
+  int64_t contiguousSize = 1;
+  for (int i = static_cast<int>(strides.size()) - 1; i >= 0; --i) {
+    if (shape[i] == ShapedType::kDynamicSize)
+      break;
+    if (strides[i] != contiguousSize)
+      break;
+    arith::ConstantIndexOp cstIndex =
+        indices[i].getDefiningOp<arith::ConstantIndexOp>();
+    if (!cstIndex || cstIndex.value() != 0)
+      break;
+    contiguousSize *= shape[i];
+  }
+  return contiguousSize;
+}
+
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
 /// memref.collapse_shape on the source so that the resulting
 /// vector.transfer_read has a 1D source. Requires the source shape to be
@@ -379,12 +413,9 @@
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
      return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (getContiguousStaticRowMajorElementsAtDimsWithZeroIndex(
+            sourceType, transferReadOp.getIndices()) <
+        vectorType.getNumElements())
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -393,17 +424,15 @@
       return failure();
     if (transferReadOp.getMask())
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
-      return failure();
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
+    VectorType vectorType1d = VectorType::get({vectorType.getNumElements()},
+                                              vectorType.getElementType());
     Value source1d = collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value read1d = rewriter.create<vector::TransferReadOp>(
+    vector::TransferReadOp read1d = rewriter.create<vector::TransferReadOp>(
         loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    read1d.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
         transferReadOp, vector.getType().cast<VectorType>(), read1d);
     return success();
   }
@@ -431,12 +460,9 @@
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
      return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (getContiguousStaticRowMajorElementsAtDimsWithZeroIndex(
+            sourceType, transferWriteOp.getIndices()) <
+        vectorType.getNumElements())
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +471,17 @@
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
-      return failure();
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
+    VectorType vectorType1d = VectorType::get({vectorType.getNumElements()},
+                                              vectorType.getElementType());
     Value source1d = collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
     Value vector1d = rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
-    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
-                                             ValueRange{c0}, identityMap1D);
+    vector::TransferWriteOp write1d = rewriter.create<vector::TransferWriteOp>(
+        loc, vector1d, source1d, ValueRange{c0}, identityMap1D);
+    write1d.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -59,3 +59,40 @@
 // CHECK:       %[[CST:.+]] = arith.constant 0 : i8
 // CHECK:       %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>, vector<i8>
 // CHECK:       return %[[READ]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME:    %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]]
+// CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME:      {in_bounds = [true]}
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK:         return %[[VEC2D]] : vector<8x4xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME:    %[[VEC2D:.+]]: vector<8x4xi8>
+// CHECK-SAME:    %[[DST:.+]]: memref<?x?x8x4xi8, {{.+}}>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[DST]]
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC2D]] : vector<8x4xi8> to vector<32xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME:      {in_bounds = [true]}
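
For reviewers skimming the diff, the net effect of the two patterns is: collapse the statically-known-contiguous source dims, issue a single in-bounds 1-D transfer, and shape_cast back to the original vector type. Below is a hand-written before/after sketch of the simplest static case; it is illustrative only (the names @before/@after are made up, and this is not one of the patch's tests):

func.func @before(%m : memref<8x4xi8>) -> vector<8x4xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  // 2-D transfer over a contiguous row-major memref.
  %v = vector.transfer_read %m[%c0, %c0], %pad {in_bounds = [true, true]}
      : memref<8x4xi8>, vector<8x4xi8>
  return %v : vector<8x4xi8>
}

func.func @after(%m : memref<8x4xi8>) -> vector<8x4xi8> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0 : i8
  // The contiguous 8x4 block becomes one 32-element 1-D transfer.
  %flat = memref.collapse_shape %m [[0, 1]] : memref<8x4xi8> into memref<32xi8>
  %v1d = vector.transfer_read %flat[%c0], %pad {in_bounds = [true]}
      : memref<32xi8>, vector<32xi8>
  %v = vector.shape_cast %v1d : vector<32xi8> to vector<8x4xi8>
  return %v : vector<8x4xi8>
}

The pre-existing patterns already performed this rewrite for fully static, rank-reduced, all-zero-index sources; what this patch adds is matching sources with dynamic leading dims and non-zero leading indices, whenever getContiguousStaticRowMajorElementsAtDimsWithZeroIndex proves the trailing dims covered by the vector are statically contiguous.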