diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,6 +67,10 @@
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to flatten contiguous row-major N-dimensional
+/// vector.transfer_{read,write} ops into 1D transfers.
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -10,12 +10,14 @@
 // transfer_write ops.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dominance.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
@@ -209,6 +211,299 @@
   opToErase.push_back(read.getOperation());
 }
 
+/// Drops unit dimensions from the input MemRefType.
+static MemRefType dropUnitDims(MemRefType inputType) {
+  ArrayRef<int64_t> none{};
+  Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
+      0, inputType, none, none, none);
+  return canonicalizeStridedLayout(rankReducedType.cast<MemRefType>());
+}
+
+/// Creates a rank-reducing memref.subview op that drops unit dims from its
+/// input, or just returns the input if it already has no unit dims.
+static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
+                                                 mlir::Location loc,
+                                                 Value input) {
+  MemRefType inputType = input.getType().cast<MemRefType>();
+  assert(inputType.hasStaticShape());
+  MemRefType resultType = dropUnitDims(inputType);
+  if (resultType == inputType)
+    return input;
+  SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
+  SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
+  return rewriter.create<memref::SubViewOp>(
+      loc, resultType, input, subviewOffsets, inputType.getShape(),
+      subviewStrides);
+}
+
+/// Returns the number of dims that aren't unit dims.
+static int getReducedRank(ArrayRef<int64_t> shape) {
+  return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
+}
+
+/// Returns true if all values are `arith.constant 0 : index`.
+static bool allZeroConstantIndexValues(ValueRange values) {
+  for (Value value : values) {
+    auto cst = value.getDefiningOp<arith::ConstantIndexOp>();
+    if (!cst)
+      return false;
+    if (cst.value() != 0)
+      return false;
+  }
+  return true;
+}
+
+/// Rewrites vector.transfer_read ops where the source has unit dims, by
+/// inserting a memref.subview dropping those unit dims.
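+///
+/// For example (illustrative IR, mirroring the tests added below):
+///
+///   %v = vector.transfer_read %src[%c0, %c0, %c0], %cst
+///       : memref<1x8x1xf32>, vector<8xf32>
+///
+/// is rewritten to read from a rank-reduced subview instead:
+///
+///   %sv = memref.subview %src[0, 0, 0] [1, 8, 1] [1, 1, 1]
+///       : memref<1x8x1xf32> to memref<8xf32>
+///   %v = vector.transfer_read %sv[%c0], %cst : memref<8xf32>, vector<8xf32>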
+class TransferReadDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (!sourceType.hasStaticShape())
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector shape to match the
+                        // reduced source shape.
+    if (!allZeroConstantIndexValues(transferReadOp.indices()))
+      return failure();
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0);
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        transferReadOp, vectorType, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
+/// Rewrites vector.transfer_write ops where the "source" (i.e., the
+/// destination) has unit dims, by inserting a memref.subview that drops
+/// those unit dims.
+class TransferWriteDropUnitDimsPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (!sourceType.hasStaticShape())
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    int reducedRank = getReducedRank(sourceType.getShape());
+    if (reducedRank == sourceType.getRank())
+      return failure(); // The source shape can't be further reduced.
+    if (reducedRank != vectorType.getRank())
+      return failure(); // This pattern requires the vector shape to match the
+                        // reduced source shape.
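+    // Only all-zero start indices are supported for now; a nonzero index
+    // would have to be folded into the subview offsets.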
+    if (!allZeroConstantIndexValues(transferWriteOp.indices()))
+      return failure();
+    Value reducedShapeSource =
+        rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> zeros(reducedRank, c0);
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        transferWriteOp, vector, reducedShapeSource, zeros, identityMap);
+    return success();
+  }
+};
+
+/// Returns the offset expression of `memrefType`'s strided layout.
+static AffineExpr getOffsetExpr(MemRefType memrefType) {
+  SmallVector<AffineExpr> strides;
+  AffineExpr offset;
+  LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+  assert(succeeded(res));
+  (void)res;
+  (void)strides;
+  return offset;
+}
+
+/// Creates a contiguous row-major MemRefType with the given shape and element
+/// type, using the given layout offset.
+static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
+                                                   ArrayRef<int64_t> shape,
+                                                   Type elementType,
+                                                   AffineExpr offset) {
+  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
+  AffineExpr contiguousRowMajor = canonical + offset;
+  AffineMap contiguousRowMajorMap =
+      AffineMap::inferFromExprList({contiguousRowMajor})[0];
+  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
+}
+
+/// Returns true if the memref has a static shape and a contiguous row-major
+/// layout, while still allowing an arbitrary offset (unlike some existing
+/// similar helpers, which require a zero offset).
+static bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
+  if (!memrefType.hasStaticShape())
+    return false;
+  AffineExpr offset = getOffsetExpr(memrefType);
+  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
+      memrefType.getContext(), memrefType.getShape(),
+      memrefType.getElementType(), offset);
+  return canonicalizeStridedLayout(memrefType) ==
+         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
+}
+
+/// Creates a memref.collapse_shape collapsing all of the dimensions of the
+/// input into a 1D shape.
+static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
+                                                  mlir::Location loc,
+                                                  Value input) {
+  Value rankReducedInput =
+      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
+  ShapedType rankReducedInputType =
+      rankReducedInput.getType().cast<ShapedType>();
+  if (rankReducedInputType.getRank() == 1)
+    return rankReducedInput;
+  ReassociationIndices indices;
+  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
+    indices.push_back(i);
+  return rewriter.create<memref::CollapseShapeOp>(
+      loc, rankReducedInput, std::array{indices});
+}
+
+/// Rewrites contiguous row-major vector.transfer_read ops by inserting
+/// memref.collapse_shape on the source so that the resulting
+/// vector.transfer_read has a 1D source. Requires the source shape to be
+/// already reduced, i.e., without unit dims.
+class FlattenContiguousRowMajorTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
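+      // (TransferReadDropUnitDimsPattern drops the unit dims first.)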
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferReadOp.mask())
+      return failure();
+    if (!allZeroConstantIndexValues(transferReadOp.indices()))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value read1d = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+    return success();
+  }
+};
+
+/// Rewrites contiguous row-major vector.transfer_write ops by inserting
+/// memref.collapse_shape on the source so that the resulting
+/// vector.transfer_write has a 1D source. Requires the source shape to be
+/// already reduced, i.e., without unit dims.
+class FlattenContiguousRowMajorTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
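+    // Requirements mirror the flattening pattern for transfer_read above.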
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferWriteOp.mask())
+      return failure();
+    if (!allZeroConstantIndexValues(transferWriteOp.indices()))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value vector1d =
+        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
+    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
+                                             ValueRange{c0}, identityMap1D);
+    rewriter.eraseOp(transferWriteOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -226,3 +521,13 @@
   });
   opt.removeDeadOp();
 }
+
+void mlir::vector::populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns
+      .add<TransferReadDropUnitDimsPattern, TransferWriteDropUnitDimsPattern,
+           FlattenContiguousRowMajorTransferReadPattern,
+           FlattenContiguousRowMajorTransferWritePattern>(
+          patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -0,0 +1,208 @@
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+
+func @transfer_read_flattenable(%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset_with_rank_reducing_subview(%arg : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>) -> vector<3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>, vector<3x2xi8>
+  return %v : vector<3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{.}}[0, 1]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<6xi8> to vector<3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse(%arg : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>) -> vector<2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>, vector<2xi8>
+  return %v : vector<2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x1x2xi8
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 1, 2] [1, 1, 1, 1]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[SUBVIEW]]
+// CHECK: return %[[READ1D]]
+
+// -----
+
+func @transfer_read_nonflattenable_out_of_bounds(%arg : memref<5x4x3x2xi8>, %i : index) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%i, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} : memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][%[[I]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_read_nonflattenable_non_contiguous(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_read_nonflattenable_non_row_major(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_write_flattenable(%arg : memref<5x4x3x2xi8>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset_with_rank_reducing_subview(%arg : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>, %vec : vector<3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<3x2xi8>, memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{.}}[0, 1]{{.}} : memref<3x2xi8, {{.+}}> into memref<6xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<3x2xi8> to vector<6xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse(%arg : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>, %vec : vector<2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : vector<2xi8>, memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x1x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<2xi8>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 1, 2] [1, 1, 1, 1]
+// CHECK: vector.transfer_write %[[VEC]], %[[SUBVIEW]]
+
+// -----
+
+func @transfer_write_nonflattenable_out_of_bounds(%arg : memref<5x4x3x2xi8>, %vec : vector<5x4x3x2xi8>, %i : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg [%i, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} : vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
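+// The destination layout is not contiguous row-major: the outermost stride
+// is 25, whereas a contiguous 5x4x3x2 memref has strides [24, 6, 2, 1].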
+func @transfer_write_nonflattenable_non_contiguous(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
+func @transfer_write_nonflattenable_non_row_major(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -583,6 +583,25 @@
   }
 };
 
+struct TestFlattenVectorTransferPatterns
+    : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transfer-flatten-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to rewrite contiguous row-major N-dimensional "
+           "vector.transfer_{read,write} ops into 1D transfers";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateFlattenVectorTransferPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -613,6 +632,8 @@
   PassRegistration<TestVectorMultiReductionLoweringPatterns>();
 
   PassRegistration<TestVectorReduceToContractPatternsPatterns>();
+
+  PassRegistration<TestFlattenVectorTransferPatterns>();
 }
 } // namespace test
 } // namespace mlir