diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,6 +67,8 @@
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3391,6 +3391,131 @@
 }
 };
 
+/// Creates a memref.collapse_shape collapsing all dimensions of `input` into
+/// a single dimension, e.g. memref<4x3x2x1xi8> into memref<24xi8>.
+static Value collapseTo1D(PatternRewriter &rewriter, mlir::Location loc,
+                          Value input) {
+  auto inputType = input.getType().cast<MemRefType>();
+  // Build a single reassociation group containing all of the dimensions.
+  ReassociationIndices indices;
+  for (int i = 0; i < inputType.getRank(); ++i) {
+    indices.push_back(i);
+  }
+  return rewriter.create<memref::CollapseShapeOp>(
+      loc, input, std::array{indices});
+}
+
+// Helper determining whether a memref has a static shape and a contiguous
+// row-major layout, while still allowing an arbitrary offset (unlike some
+// existing similar functions).
+static bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (!memrefType.hasStaticShape()) {
+    return false;
+  }
+  if (failed(getStridesAndOffset(memrefType, strides, offset))) {
+    return false;
+  }
+  // In a contiguous row-major layout, each stride equals the product of all
+  // sizes inner to it, e.g. shape 4x3x2x1 implies strides [6, 2, 1, 1].
+  int64_t productOfInnerMostSizes = 1;
+  for (int i = memrefType.getRank() - 1; i >= 0; --i) {
+    int64_t dimsize = memrefType.getDimSize(i);
+    // The dimsize != 1 condition here means that we ignore the strides of
+    // unit dims, as they don't make a practical difference.
+    if (dimsize != 1 && strides[i] != productOfInnerMostSizes) {
+      return false;
+    }
+    // This simple arithmetic is correct thanks to having ensured above that
+    // we have a static shape.
+    productOfInnerMostSizes *= dimsize;
+  }
+  return true;
+}
+
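+/// Rewrites a contiguous row-major vector.transfer_read into a 1-D read of
+/// the collapsed source, cast back to the original vector type. Sketch of
+/// the rewrite, mirroring the tests below (SSA names are illustrative):
+///
+///   %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst
+///       : memref<4x3x2x1xi8>, vector<4x3x2x1xi8>
+/// becomes
+///   %collapsed = memref.collapse_shape %arg [[0, 1, 2, 3]]
+///       : memref<4x3x2x1xi8> into memref<24xi8>
+///   %read1d = vector.transfer_read %collapsed[%c0], %cst
+///       : memref<24xi8>, vector<24xi8>
+///   %v = vector.shape_cast %read1d : vector<24xi8> to vector<4x3x2x1xi8>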
+class FlattenTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1) {
+      // Already 1D, nothing to do.
+      return failure();
+    }
+    if (!isStaticShapeAndContiguousRowMajor(sourceType)) {
+      return failure();
+    }
+    if (sourceType.getNumElements() != vectorType.getNumElements()) {
+      return failure();
+    }
+    if (transferReadOp.hasOutOfBoundsDim()) {
+      return failure();
+    }
+    if (!transferReadOp.permutation_map().isMinorIdentity()) {
+      return failure();
+    }
+    if (transferReadOp.mask()) {
+      return failure();
+    }
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d = collapseTo1D(rewriter, loc, source);
+    Value read1d = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transferReadOp,
+                                                     vectorType, read1d);
+    return success();
+  }
+};
+
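+/// Rewrites a contiguous row-major vector.transfer_write analogously to
+/// FlattenTransferReadPattern above: the vector is shape_cast to 1-D and
+/// written into the collapsed memref. Sketch, mirroring the tests below
+/// (SSA names are illustrative):
+///
+///   vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0]
+///       : vector<4x3x2x1xi8>, memref<4x3x2x1xi8>
+/// becomes
+///   %collapsed = memref.collapse_shape %arg [[0, 1, 2, 3]]
+///       : memref<4x3x2x1xi8> into memref<24xi8>
+///   %vec1d = vector.shape_cast %vec : vector<4x3x2x1xi8> to vector<24xi8>
+///   vector.transfer_write %vec1d, %collapsed[%c0]
+///       : vector<24xi8>, memref<24xi8>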
+class FlattenTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1) {
+      // Already 1D, nothing to do.
+      return failure();
+    }
+    if (!isStaticShapeAndContiguousRowMajor(sourceType)) {
+      return failure();
+    }
+    if (sourceType.getNumElements() != vectorType.getNumElements()) {
+      return failure();
+    }
+    if (transferWriteOp.hasOutOfBoundsDim()) {
+      return failure();
+    }
+    if (!transferWriteOp.permutation_map().isMinorIdentity()) {
+      return failure();
+    }
+    if (transferWriteOp.mask()) {
+      return failure();
+    }
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d = collapseTo1D(rewriter, loc, source);
+    Value vector1d =
+        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
+    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
+                                             ValueRange{c0}, identityMap1D);
+    rewriter.eraseOp(transferWriteOp);
+    return success();
+  }
+};
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool indexOptimizations) {
   patterns.add<...>(patterns.getContext());
 }
 
+void mlir::vector::populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FlattenTransferReadPattern, FlattenTransferWritePattern>(
+      patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns) {
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+
+func @transfer_read_flattenable(%arg : memref<4x3x2x1xi8>) -> vector<4x3x2x1xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<4x3x2x1xi8>, vector<4x3x2x1xi8>
+  return %v : vector<4x3x2x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<4x3x2x1xi8> into memref<24xi8>
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[READ1D]] : vector<24xi8> to vector<4x3x2x1xi8>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 2 + d2 + d3 + s0)>>) -> vector<4x3x2x1xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 2 + d2 + d3 + s0)>>, vector<4x3x2x1xi8>
+  return %v : vector<4x3x2x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC:.+]] = vector.shape_cast %[[READ1D]] : vector<24xi8> to vector<4x3x2x1xi8>
+// CHECK: return %[[VEC]]
+
+// -----
+
+func @transfer_read_nonflattenable_out_of_bounds(%arg : memref<4x3x2x1xi8>, %i : index) -> vector<4x3x2x1xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%i, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} : memref<4x3x2x1xi8>, vector<4x3x2x1xi8>
+  return %v : vector<4x3x2x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8>,
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][%[[I]]
+// CHECK: return %[[READ]]
+
+// -----
+
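+// For shape 4x3x2x1, a contiguous row-major layout has strides [6, 2, 1, 1].
+// The layout maps in the nonflattenable tests below deviate from this (an
+// outermost stride of 8 in one case, a column-major ordering in the other),
+// so the flattening patterns must leave these transfers unchanged.
+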
+func @transfer_read_nonflattenable_non_contiguous(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 8 + d1 * 2 + d2 + d3)>>) -> vector<4x3x2x1xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 8 + d1 * 2 + d2 + d3)>>, vector<4x3x2x1xi8>
+  return %v : vector<4x3x2x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_read_nonflattenable_non_row_major(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>) -> vector<4x3x2x1xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>, vector<4x3x2x1xi8>
+  return %v : vector<4x3x2x1xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_write_flattenable(%arg : memref<4x3x2x1xi8>, %vec : vector<4x3x2x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<4x3x2x1xi8>, memref<4x3x2x1xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<4x3x2x1xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<4x3x2x1xi8> into memref<24xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<4x3x2x1xi8> to vector<24xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 2 + d2 + d3 + s0)>>, %vec : vector<4x3x2x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<4x3x2x1xi8>, memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 2 + d2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<4x3x2x1xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<4x3x2x1xi8, {{.+}}> into memref<24xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<4x3x2x1xi8> to vector<24xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_nonflattenable_out_of_bounds(%arg : memref<4x3x2x1xi8>, %vec : vector<4x3x2x1xi8>, %i : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%i, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} : vector<4x3x2x1xi8>, memref<4x3x2x1xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<4x3x2x1xi8>
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
+func @transfer_write_nonflattenable_non_contiguous(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 8 + d1 * 2 + d2 + d3)>>, %vec : vector<4x3x2x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<4x3x2x1xi8>, memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 8 + d1 * 2 + d2 + d3)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<4x3x2x1xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
+func @transfer_write_nonflattenable_non_row_major(%arg : memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>, %vec : vector<4x3x2x1xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<4x3x2x1xi8>, memref<4x3x2x1xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<4x3x2x1xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<4x3x2x1xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -583,6 +583,25 @@
 }
 };
 
+struct TestFlattenVectorTransferPatterns
+    : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transfer-flatten-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to rewrite contiguous row-major N-dimensional "
+           "vector.transfer_{read,write} ops into 1D transfers";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The patterns materialize memref.collapse_shape ops.
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateFlattenVectorTransferPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -613,6 +632,8 @@
   PassRegistration<...>();
   PassRegistration<...>();
+
+  PassRegistration<TestFlattenVectorTransferPatterns>();
 }
 } // namespace test
 } // namespace mlir