diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -229,7 +229,8 @@
                                                  mlir::Location loc,
                                                  Value input) {
   MemRefType inputType = input.getType().cast<MemRefType>();
-  assert(inputType.hasStaticShape());
+  if (!inputType.hasStaticShape())
+    return input;
   SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
   SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
   ArrayRef<int64_t> subViewSizes = inputType.getShape();
@@ -339,23 +340,69 @@
   }
 };
 
+/// Returns the position of the first inner dimension that has contiguous layout
+/// with at least `requiredContiguousSize` contiguous elements.
+/// When such a dimension is found, the return value satisfies:
+/// 0 <= return_value <= memrefType.getRank() - 1.
+/// When no such dimension is found, the return value is memrefType.getRank().
+static int64_t getContiguousInnerDim(MemRefType memrefType,
+                                     int64_t requiredContiguousSize) {
+  auto shape = memrefType.getShape();
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  int64_t innerDim = shape.size();
+  if (succeeded(getStridesAndOffset(memrefType, strides, offset))) {
+    int64_t innerSize = 1;
+    while (true) {
+      if (innerDim == 0)
+        break;
+      const int64_t nextDim = innerDim - 1;
+      if (shape[nextDim] == ShapedType::kDynamicSize)
+        break;
+      if (strides[nextDim] != innerSize)
+        break;
+      innerSize *= shape[nextDim];
+      innerDim = nextDim;
+      if (innerSize >= requiredContiguousSize)
+        break;
+    }
+  }
+  return innerDim;
+}
+
 /// Creates a memref.collapse_shape collapsing all of the dimensions of the
 /// input into a 1D shape.
-// TODO: move helper function
-static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
-                                                  mlir::Location loc,
-                                                  Value input) {
-  Value rankReducedInput =
-      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
-  ShapedType rankReducedInputType =
-      rankReducedInput.getType().cast<ShapedType>();
-  if (rankReducedInputType.getRank() == 1)
-    return rankReducedInput;
-  ReassociationIndices indices;
-  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
-    indices.push_back(i);
-  return rewriter.create<memref::CollapseShapeOp>(
-      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
+static Value collapseInnerDims(PatternRewriter &rewriter, mlir::Location loc,
+                               Value input, int64_t firstDimToCollapse) {
+  ShapedType inputType = input.getType().cast<ShapedType>();
+  if (inputType.getRank() == 1)
+    return input;
+  SmallVector<ReassociationIndices> reassociation;
+  for (int64_t i = 0; i < firstDimToCollapse; ++i) {
+    reassociation.push_back(ReassociationIndices{i});
+  }
+  ReassociationIndices collapsedIndices;
+  for (int64_t i = firstDimToCollapse; i < inputType.getRank(); ++i)
+    collapsedIndices.push_back(i);
+  reassociation.push_back(collapsedIndices);
+  return rewriter.create<memref::CollapseShapeOp>(loc, input, reassociation);
+}
+
+static LogicalResult
+checkAndCollapseInnerZeroIndices(ValueRange indices, int64_t firstDimToCollapse,
+                                 SmallVector<Value> &outIndices) {
+  int64_t rank = indices.size();
+  if (firstDimToCollapse >= rank)
+    return failure();
+  for (int64_t i = firstDimToCollapse; i < rank; ++i) {
+    arith::ConstantIndexOp cst =
+        indices[i].getDefiningOp<arith::ConstantIndexOp>();
+    if (!cst || cst.value() != 0)
+      return failure();
+  }
+  outIndices = indices;
+  outIndices.resize(firstDimToCollapse + 1);
+  return success();
 }
 
 /// Rewrites contiguous row-major vector.transfer_read ops by inserting
@@ -379,12 +426,9 @@
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -393,19 +437,28 @@
       return failure();
     if (transferReadOp.getMask())
       return failure();
-    if (llvm::any_of(transferReadOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferReadOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value read1d = rewriter.create<vector::TransferReadOp>(
-        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().dyn_cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    vector::TransferReadOp flatRead = rewriter.create<vector::TransferReadOp>(
+        loc, flatVectorType, collapsedSource, collapsedIndices, collapsedMap);
+    flatRead.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+        transferReadOp, vector.getType().cast<VectorType>(), flatRead);
     return success();
   }
 };
@@ -431,12 +484,9 @@
     if (vectorType.getRank() <= 1)
       // Already 0D/1D, nothing to do.
       return failure();
-    if (!isStaticShapeAndContiguousRowMajor(sourceType))
-      return failure();
-    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
-      // This pattern requires the source to already be rank-reduced.
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    int64_t firstContiguousInnerDim =
+        getContiguousInnerDim(sourceType, vectorType.getNumElements());
+    if (firstContiguousInnerDim >= sourceType.getRank() - 1)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -445,19 +495,29 @@
       return failure();
     if (transferWriteOp.getMask())
       return failure();
-    if (llvm::any_of(transferWriteOp.getIndices(),
-                     [](Value v) { return !isZero(v); }))
+    SmallVector<Value> collapsedIndices;
+    if (failed(checkAndCollapseInnerZeroIndices(transferWriteOp.getIndices(),
+                                                firstContiguousInnerDim,
+                                                collapsedIndices)))
       return failure();
-    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
-    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
-                                              sourceType.getElementType());
-    Value source1d =
-        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
-    Value vector1d =
-        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
-    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
-                                             ValueRange{c0}, identityMap1D);
+    Value collapsedSource =
+        collapseInnerDims(rewriter, loc, source, firstContiguousInnerDim);
+    MemRefType collapsedSourceType =
+        collapsedSource.getType().dyn_cast<MemRefType>();
+    int64_t collapsedRank = collapsedSourceType.getRank();
+    assert(collapsedRank == firstContiguousInnerDim + 1);
+    SmallVector<AffineExpr, 1> dimExprs{
+        getAffineDimExpr(firstContiguousInnerDim, rewriter.getContext())};
+    auto collapsedMap =
+        AffineMap::get(collapsedRank, 0, dimExprs, rewriter.getContext());
+    VectorType flatVectorType = VectorType::get({vectorType.getNumElements()},
+                                                vectorType.getElementType());
+    Value flatVector =
+        rewriter.create<vector::ShapeCastOp>(loc, flatVectorType, vector);
+    vector::TransferWriteOp flatWrite =
+        rewriter.create<vector::TransferWriteOp>(
+            loc, flatVector, collapsedSource, collapsedIndices, collapsedMap);
+    flatWrite.setInBoundsAttr(rewriter.getBoolArrayAttr({true}));
     rewriter.eraseOp(transferWriteOp);
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -59,3 +59,48 @@
 // CHECK: %[[CST:.+]] = arith.constant 0 : i8
 // CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][], %[[CST]] : memref<i8>, vector<i8>
 // CHECK: return %[[READ]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
+  %c0_i8 = arith.constant 0 : i8
+  %c0 = arith.constant 0 : index
+  %result = vector.transfer_read %arg0[%arg1, %arg2, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x8x4xi8, #map0>, vector<8x4xi8>
+  return %result : vector<8x4xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index
+// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG1]], %[[ARG2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
+// CHECK: return %[[VEC2D]] : vector<8x4xi8>
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3)[s0, s1] -> (d0 * s1 + s0 + d1 * 32 + d2 * 4 + d3)>
+
+func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, #map0>, %arg1 : index, %arg2 : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dst[%arg1, %arg2, %c0, %c0] {in_bounds = [true, true]} : vector<8x4xi8>, memref<?x?x8x4xi8, #map0>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices
+// CHECK-SAME: %[[ARG0:.+]]: vector<8x4xi8>, %[[ARG1:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[ARG0]] : vector<8x4xi8> to vector<32xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: {in_bounds = [true]}
+// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>