diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -229,7 +229,6 @@ mlir::Location loc, Value input) { MemRefType inputType = input.getType().cast<MemRefType>(); - assert(inputType.hasStaticShape()); SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0); SmallVector<int64_t> subViewStrides(inputType.getRank(), 1); ArrayRef<int64_t> subViewSizes = inputType.getShape(); @@ -358,6 +357,45 @@ loc, rankReducedInput, std::array{indices}); } +/// Returns how many elements are statically known to be contiguous in the +/// layout of this `memrefType`, at the position given by `indices`. +/// For example, for `memref<8xf32>` and `indices=[%c2]` where +/// `%c2 = arith.constant 2 : index`, the return value should be 6 = 8 - 2. +/// Returns 0 in error cases. +static int64_t +getContiguousStaticRowMajorElementsAtPosition(MemRefType memrefType, + ValueRange indices) { + auto shape = memrefType.getShape(); + if (shape.size() != indices.size()) + return 0; + SmallVector<int64_t> strides; + int64_t offset; + if (failed(getStridesAndOffset(memrefType, strides, offset))) + return 0; + // contiguousSize will be the number of contiguous elements in memrefType in + // the dimensions traversed by the for loop below. 
+ int64_t indicesOffset = 0; + for (int i = static_cast<int>(strides.size()) - 1; i >= 0; --i) { + if (shape[i] == ShapedType::kDynamicSize) + break; + if (strides[i] != contiguousSize) + break; + arith::ConstantIndexOp cstIndex = + indices[i].getDefiningOp<arith::ConstantIndexOp>(); + if (!cstIndex) + break; + int64_t indexValue = cstIndex.value(); + if (indexValue < 0 || indexValue >= shape[i]) + return 0; + indicesOffset += indexValue * contiguousSize; + contiguousSize *= shape[i]; + } + return contiguousSize - indicesOffset; +} + /// Rewrites contiguous row-major vector.transfer_read ops by inserting /// memref.collapse_shape on the source so that the resulting /// vector.transfer_read has a 1D source. Requires the source shape to be @@ -379,12 +417,9 @@ if (vectorType.getRank() <= 1) // Already 0D/1D, nothing to do. return failure(); - if (!isStaticShapeAndContiguousRowMajor(sourceType)) - return failure(); - if (getReducedRank(sourceType.getShape()) != sourceType.getRank()) - // This pattern requires the source to already be rank-reduced. - return failure(); - if (sourceType.getNumElements() != vectorType.getNumElements()) + if (getContiguousStaticRowMajorElementsAtPosition( + sourceType, transferReadOp.getIndices()) < + vectorType.getNumElements()) return failure(); // TODO: generalize this pattern, relax the requirements here. 
if (transferReadOp.hasOutOfBoundsDim()) @@ -393,17 +428,15 @@ return failure(); if (transferReadOp.getMask()) return failure(); - if (llvm::any_of(transferReadOp.getIndices(), - [](Value v) { return !isZero(v); })) - return failure(); Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); auto identityMap1D = rewriter.getMultiDimIdentityMap(1); - VectorType vectorType1d = VectorType::get({sourceType.getNumElements()}, - sourceType.getElementType()); + VectorType vectorType1d = VectorType::get({vectorType.getNumElements()}, + vectorType.getElementType()); Value source1d = collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source); - Value read1d = rewriter.create<vector::TransferReadOp>( + vector::TransferReadOp read1d = rewriter.create<vector::TransferReadOp>( loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D); + read1d.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); rewriter.replaceOpWithNewOp<vector::ShapeCastOp>( transferReadOp, vector.getType().cast<VectorType>(), read1d); return success(); @@ -431,12 +464,9 @@ if (vectorType.getRank() <= 1) // Already 0D/1D, nothing to do. return failure(); - if (!isStaticShapeAndContiguousRowMajor(sourceType)) - return failure(); - if (getReducedRank(sourceType.getShape()) != sourceType.getRank()) - // This pattern requires the source to already be rank-reduced. - return failure(); - if (sourceType.getNumElements() != vectorType.getNumElements()) + if (getContiguousStaticRowMajorElementsAtPosition( + sourceType, transferWriteOp.getIndices()) < + vectorType.getNumElements()) return failure(); // TODO: generalize this pattern, relax the requirements here. 
if (transferWriteOp.hasOutOfBoundsDim()) @@ -445,19 +475,17 @@ return failure(); if (transferWriteOp.getMask()) return failure(); - if (llvm::any_of(transferWriteOp.getIndices(), - [](Value v) { return !isZero(v); })) - return failure(); Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); auto identityMap1D = rewriter.getMultiDimIdentityMap(1); - VectorType vectorType1d = VectorType::get({sourceType.getNumElements()}, - sourceType.getElementType()); + VectorType vectorType1d = VectorType::get({vectorType.getNumElements()}, + vectorType.getElementType()); Value source1d = collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source); Value vector1d = rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector); - rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d, - ValueRange{c0}, identityMap1D); + vector::TransferWriteOp write1d = rewriter.create<vector::TransferWriteOp>( + loc, vector1d, source1d, ValueRange{c0}, identityMap1D); + write1d.setInBoundsAttr(rewriter.getBoolArrayAttr({true})); rewriter.eraseOp(transferWriteOp); return success(); } diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir @@ -59,3 +59,77 @@ // CHECK: %[[CST:.+]] = arith.constant 0 : i8 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>, vector<i8> // CHECK: return %[[READ]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)> + +func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> { + %c0_i8 = arith.constant 0 : i8 + %c0 = arith.constant 0 : index + %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8> + return %result : vector<8x4xi8> +} + +// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices +// CHECK-SAME: %[[ARG0:.+]]: memref to 
vector<8x4xi8> +// CHECK: return %[[VEC2D]] : vector<8x4xi8> + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)> + +func.func @transfer_read_flattenable_with_dynamic_dims_and_nonzero_static_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<5x4xi8> { + %c0_i8 = arith.constant 0 : i8 + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + %result = vector.transfer_read %arg0[%arg1, %arg2, %c3, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<5x4xi8> + return %result : vector<5x4xi8> +} + +// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_nonzero_static_indices +// CHECK-SAME: %[[ARG0:.+]]: memref to vector<5x4xi8> +// CHECK: return %[[VEC2D]] : vector<5x4xi8> + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)> + +func.func @transfer_read_nonflattenable_due_to_too_large_nonzero_static_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<5x4xi8> { + %c0_i8 = arith.constant 0 : i8 + %c0 = arith.constant 0 : index + %c3 = arith.constant 4 : index + %result = vector.transfer_read %arg0[%arg1, %arg2, %c3, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<5x4xi8> + return %result : vector<5x4xi8> +} + +// CHECK-LABEL: func @transfer_read_nonflattenable_due_to_too_large_nonzero_static_indices +// CHECK-SAME: %[[ARG0:.+]]: memref + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)> + +func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0> + return +} + +// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices +// CHECK-SAME: %[[VEC2D:.+]]: vector<8x4xi8> +// CHECK-SAME: %[[DST:.+]]: memref to vector<32xi8> +// 
CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]] +// CHECK-SAME: {in_bounds = [true]}