diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,6 +67,8 @@
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3395,6 +3395,160 @@
   }
 };
 
+/// Creates a rank-reducing memref.subview that drops the unit dims of
+/// `input`, or returns `input` unchanged if it has no unit dims.
+static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
+                                                 mlir::Location loc,
+                                                 Value input) {
+  MemRefType inputType = input.getType().cast<MemRefType>();
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  LogicalResult res = getStridesAndOffset(inputType, strides, offset);
+  assert(succeeded(res));
+  (void)res;
+  ArrayRef<int64_t> shape = inputType.getShape();
+  SmallVector<int64_t> reducedShape;
+  SmallVector<int64_t> reducedStrides;
+  for (unsigned i = 0; i < inputType.getRank(); ++i) {
+    if (shape[i] != 1) {
+      reducedShape.push_back(shape[i]);
+      reducedStrides.push_back(strides[i]);
+    }
+  }
+  if (reducedShape.size() == shape.size())
+    return input;
+  AffineMap reducedStridesMap = makeStridedLinearLayoutMap(
+      reducedStrides, offset, inputType.getContext());
+  MemRefType resultType = MemRefType::Builder(inputType)
+                              .setShape(reducedShape)
+                              .setLayout(AffineMapAttr::get(reducedStridesMap));
+  SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
+  SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
+  return rewriter.create<memref::SubViewOp>(loc, resultType, input,
+                                            subviewOffsets, shape,
+                                            subviewStrides);
+}
+
+/// Collapses `input` into a 1-D memref, first dropping unit dims so that the
+/// memref.collapse_shape has a single reassociation group.
+static Value collapseTo1D(PatternRewriter &rewriter, mlir::Location loc,
+                          Value input) {
+  Value rankReducedInput =
+      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
+  ShapedType rankReducedInputType =
+      rankReducedInput.getType().cast<ShapedType>();
+  if (rankReducedInputType.getRank() == 1)
+    return rankReducedInput;
+  ReassociationIndices indices;
+  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
+    indices.push_back(i);
+  return rewriter.create<memref::CollapseShapeOp>(
+      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
+}
+
+// Helper determining whether a memref has a static shape and a contiguous
+// row-major layout, while still allowing an arbitrary offset (unlike some
+// existing similar helpers).
+static bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
+  if (!memrefType.hasStaticShape())
+    return false;
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  LogicalResult res = getStridesAndOffset(memrefType, strides, offset);
+  assert(succeeded(res));
+  (void)res;
+  int64_t productOfInnerMostSizes = 1;
+  for (int i = memrefType.getRank() - 1; i >= 0; --i) {
+    int64_t dimsize = memrefType.getDimSize(i);
+    // The `dimsize != 1` condition ignores the strides of unit dims, as they
+    // make no practical difference.
+    if (dimsize != 1 && strides[i] != productOfInnerMostSizes)
+      return false;
+    // This simple arithmetic is correct thanks to having ensured above that
+    // we have a static shape.
+    productOfInnerMostSizes *= dimsize;
+  }
+  return true;
+}
+
+/// Rewrites a contiguous row-major vector.transfer_read into a 1-D transfer
+/// read of a collapsed source, followed by a vector.shape_cast back to the
+/// original vector type.
+class FlattenTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1) {
+      // Already 1D, nothing to do.
+      return failure();
+    }
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferReadOp.mask())
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d = collapseTo1D(rewriter, loc, source);
+    Value read1d = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transferReadOp,
+                                                     vectorType, read1d);
+    return success();
+  }
+};
+
+/// Rewrites a contiguous row-major vector.transfer_write as a
+/// vector.shape_cast to a 1-D vector followed by a 1-D transfer write into
+/// the collapsed destination.
+class FlattenTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().cast<MemRefType>();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1) {
+      // Already 1D, nothing to do.
+      return failure();
+    }
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferWriteOp.mask())
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d = collapseTo1D(rewriter, loc, source);
+    Value vector1d =
+        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
+    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
+                                             ValueRange{c0}, identityMap1D);
+    rewriter.eraseOp(transferWriteOp);
+    return success();
+  }
+};
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool indexOptimizations) {
   patterns.add(patterns.getContext());
 }
 
+void mlir::vector::populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FlattenTransferReadPattern, FlattenTransferWritePattern>(
+      patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}
+
 void mlir::vector::populateBubbleVectorBitCastOpPatterns(
     RewritePatternSet &patterns) {
   patterns.add<BubbleDownVectorBitCastForExtract,
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+
+func @transfer_read_flattenable(%arg : memref<5x4x3x2xi8>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset_with_rank_reducing_subview(%arg : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>) -> vector<3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>, vector<3x2xi8>
+  return %v : vector<3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{.}}[0, 1]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<6xi8> to vector<3x2xi8>
+// CHECK: return %[[VEC2D]]
+
+// -----
+
+func @transfer_read_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse(%arg : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>) -> vector<2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>, vector<2xi8>
+  return %v : vector<2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x1x2xi8
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 1, 2] [1, 1, 1, 1]
+// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[SUBVIEW]]
+// CHECK: return %[[READ1D]]
+
+// -----
+
+func @transfer_read_nonflattenable_out_of_bounds(%arg : memref<5x4x3x2xi8>, %i : index) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%i, %c0, %c0, %c0], %cst {in_bounds = [false, true, true, true]} : memref<5x4x3x2xi8>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]][%[[I]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_read_nonflattenable_non_contiguous(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_read_nonflattenable_non_row_major(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>) -> vector<5x4x3x2xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0 : i8
+  %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 5 + d2 * 20 + d3 * 60)>>, vector<5x4x3x2xi8>
+  return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ARG]]
+// CHECK: return %[[READ]]
+
+// -----
+
+func @transfer_write_flattenable(%arg : memref<5x4x3x2xi8>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8> into memref<120xi8>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 24 + d1 * 6 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset_with_rank_reducing_subview(%arg : memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>, %vec : vector<3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<3x2xi8>, memref<1x1x3x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 6 + d1 * 6 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset_with_rank_reducing_subview
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[SUBVIEW]] {{.}}[0, 1]{{.}} : memref<3x2xi8, {{.+}}> into memref<6xi8, {{.+}}>
+// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<3x2xi8> to vector<6xi8>
+// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse(%arg : memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>, %vec : vector<2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<2xi8>, memref<1x1x1x2xi8, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 2 + d1 * 2 + d2 * 2 + d3 + s0)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset_with_rank_reducing_subview_and_no_collapse
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x1x2xi8, {{.+}}>,
+// CHECK-SAME: %[[VEC:.+]]: vector<2xi8>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 1, 2] [1, 1, 1, 1]
+// CHECK: vector.transfer_write %[[VEC]], %[[SUBVIEW]]
+
+// -----
+
+func @transfer_write_nonflattenable_out_of_bounds(%arg : memref<5x4x3x2xi8>, %vec : vector<5x4x3x2xi8>, %i : index) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%i, %c0, %c0, %c0] {in_bounds = [false, true, true, true]} : vector<5x4x3x2xi8>, memref<5x4x3x2xi8>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_out_of_bounds
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8>,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK-SAME: %[[I:.+]]: index
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
+func @transfer_write_nonflattenable_non_contiguous(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 * 25 + d1 * 6 + d2 * 2 + d3)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_contiguous
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
+
+// -----
+
+func @transfer_write_nonflattenable_non_row_major(%arg : memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>, %vec : vector<5x4x3x2xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0] : vector<5x4x3x2xi8>, memref<5x4x3x2xi8, affine_map<(d0, d1, d2, d3) -> (d0 + d1 * 4 + d2 * 12 + d3 * 24)>>
+  return
+}
+
+// CHECK-LABEL: func @transfer_write_nonflattenable_non_row_major
+// CHECK-SAME: %[[ARG:.+]]: memref<5x4x3x2xi8,
+// CHECK-SAME: %[[VEC:.+]]: vector<5x4x3x2xi8>
+// CHECK: vector.transfer_write %[[VEC]], %[[ARG]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -583,6 +583,25 @@
   }
 };
 
+struct TestFlattenVectorTransferPatterns
+    : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transfer-flatten-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to rewrite contiguous row-major N-dimensional "
+           "vector.transfer_{read,write} ops into 1D transfers";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateFlattenVectorTransferPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // end anonymous namespace
 
 namespace mlir {
@@ -613,6 +632,8 @@
   PassRegistration();
 
   PassRegistration();
+
+  PassRegistration<TestFlattenVectorTransferPatterns>();
 }
 } // namespace test
 } // namespace mlir
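
Usage note (not part of the diff): a downstream pipeline can drive the new
populateFlattenVectorTransferPatterns entry point from its own pass, in the
same way the TestFlattenVectorTransferPatterns test pass above does. Below is
a minimal sketch, assuming the same FunctionPass-era APIs this patch uses; the
pass class name FlattenVectorTransfersPass is hypothetical, not part of the
change.

// Hypothetical downstream pass (illustrative only), mirroring the test pass.
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct FlattenVectorTransfersPass
    : public mlir::PassWrapper<FlattenVectorTransfersPass, mlir::FunctionPass> {
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // The rewrites may create memref.subview and memref.collapse_shape ops.
    registry.insert<mlir::memref::MemRefDialect>();
  }
  void runOnFunction() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Adds FlattenTransferReadPattern / FlattenTransferWritePattern plus the
    // shape_cast folding patterns they rely on.
    mlir::vector::populateFlattenVectorTransferPatterns(patterns);
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(patterns));
  }
};
} // namespace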