diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -67,13 +67,21 @@
 /// pairs or forward write-read pairs.
 void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns);
 
-/// Collect a set of leading one dimension removal patterns.
+/// Collect a set of one dimension removal patterns.
 ///
 /// These patterns insert rank-reducing memref.subview ops to remove one
 /// dimensions. With them, there are more chances that we can avoid
 /// potentially exensive vector.shape_cast operations.
 void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns);
 
+/// Collect a set of patterns to flatten n-D vector transfers on contiguous
+/// memref.
+///
+/// These patterns insert memref.collapse_shape + vector.shape_cast patterns
+/// to transform multiple small n-D transfers into a larger 1-D transfer where
+/// the memref contiguity properties allow it.
+void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns);
+
 /// Collect a set of patterns that bubble up/down bitcast ops.
 ///
 /// These patterns move vector.bitcast ops to be before insert ops or after
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -531,6 +531,11 @@
 /// Return null if the layout is not compatible with a strided layout.
 AffineMap getStridedLinearLayoutMap(MemRefType t);
 
+/// Helper determining if a memref is static-shape and contiguous-row-major
+/// layout, while still allowing for an arbitrary offset (any static or
+/// dynamic value).
+bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
+
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -227,7 +227,8 @@
   MemRefType inputType = input.getType().cast<MemRefType>();
   assert(inputType.hasStaticShape());
   MemRefType resultType = dropUnitDims(inputType);
-  if (resultType == inputType)
+  if (canonicalizeStridedLayout(resultType) ==
+      canonicalizeStridedLayout(inputType))
     return input;
   SmallVector<int64_t> subviewOffsets(inputType.getRank(), 0);
   SmallVector<int64_t> subviewStrides(inputType.getRank(), 1);
@@ -333,6 +334,130 @@
   }
 };
 
+/// Creates a memref.collapse_shape collapsing all of the dimensions of the
+/// input into a 1D shape.
+// TODO: move helper function
+static Value collapseContiguousRowMajorMemRefTo1D(PatternRewriter &rewriter,
+                                                  mlir::Location loc,
+                                                  Value input) {
+  Value rankReducedInput =
+      rankReducingSubviewDroppingUnitDims(rewriter, loc, input);
+  ShapedType rankReducedInputType =
+      rankReducedInput.getType().cast<ShapedType>();
+  // Already 1-D after unit-dim removal: no collapse needed.
+  if (rankReducedInputType.getRank() == 1)
+    return rankReducedInput;
+  // Collapse all dimensions into a single reassociation group [0, 1, ..., r-1].
+  ReassociationIndices indices;
+  for (int i = 0; i < rankReducedInputType.getRank(); ++i)
+    indices.push_back(i);
+  return rewriter.create<memref::CollapseShapeOp>(
+      loc, rankReducedInput, std::array<ReassociationIndices, 1>{indices});
+}
+
+/// Rewrites contiguous row-major vector.transfer_read ops by inserting
+/// memref.collapse_shape on the source so that the resulting
+/// vector.transfer_read has a 1D source. Requires the source shape to be
+/// already reduced i.e. without unit dims.
+class FlattenContiguousRowMajorTransferReadPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferReadOp.getLoc();
+    Value vector = transferReadOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferReadOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // The contiguity check below is only valid on memrefs; bail out on
+    // tensor sources.
+    if (!sourceType)
+      return failure();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferReadOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferReadOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferReadOp.mask())
+      return failure();
+    if (llvm::any_of(transferReadOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value read1d = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType1d, source1d, ValueRange{c0}, identityMap1D);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        transferReadOp, vector.getType().cast<VectorType>(), read1d);
+    return success();
+  }
+};
+
+/// Rewrites contiguous row-major vector.transfer_write ops by inserting
+/// memref.collapse_shape on the source so that the resulting
+/// vector.transfer_write has a 1D source. Requires the source shape to be
+/// already reduced i.e. without unit dims.
+class FlattenContiguousRowMajorTransferWritePattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = transferWriteOp.getLoc();
+    Value vector = transferWriteOp.vector();
+    VectorType vectorType = vector.getType().cast<VectorType>();
+    Value source = transferWriteOp.source();
+    MemRefType sourceType = source.getType().dyn_cast<MemRefType>();
+    // The contiguity check below is only valid on memrefs; bail out on
+    // tensor sources.
+    if (!sourceType)
+      return failure();
+    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+      // Already 1D, nothing to do.
+      return failure();
+    if (!isStaticShapeAndContiguousRowMajor(sourceType))
+      return failure();
+    if (getReducedRank(sourceType.getShape()) != sourceType.getRank())
+      // This pattern requires the source to already be rank-reduced.
+      return failure();
+    if (sourceType.getNumElements() != vectorType.getNumElements())
+      return failure();
+    // TODO: generalize this pattern, relax the requirements here.
+    if (transferWriteOp.hasOutOfBoundsDim())
+      return failure();
+    if (!transferWriteOp.permutation_map().isMinorIdentity())
+      return failure();
+    if (transferWriteOp.mask())
+      return failure();
+    if (llvm::any_of(transferWriteOp.indices(),
+                     [](Value v) { return !isZero(v); }))
+      return failure();
+    Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap1D = rewriter.getMultiDimIdentityMap(1);
+    VectorType vectorType1d = VectorType::get({sourceType.getNumElements()},
+                                              sourceType.getElementType());
+    Value source1d =
+        collapseContiguousRowMajorMemRefTo1D(rewriter, loc, source);
+    Value vector1d =
+        rewriter.create<vector::ShapeCastOp>(loc, vectorType1d, vector);
+    rewriter.create<vector::TransferWriteOp>(loc, vector1d, source1d,
+                                             ValueRange{c0}, identityMap1D);
+    rewriter.eraseOp(transferWriteOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -358,3 +483,11 @@
                                                        patterns.getContext());
   populateShapeCastFoldingPatterns(patterns);
 }
+
+void mlir::vector::populateFlattenVectorTransferPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
+               FlattenContiguousRowMajorTransferWritePattern>(
+      patterns.getContext());
+  populateShapeCastFoldingPatterns(patterns);
+}
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -1168,3 +1168,40 @@
     return AffineMap();
   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
 }
+
+/// Return the AffineExpr representation of the offset, assuming `memRefType`
+/// is a strided memref.
+static AffineExpr getOffsetExpr(MemRefType memrefType) {
+  SmallVector<AffineExpr> strides;
+  AffineExpr offset;
+  if (failed(getStridesAndOffset(memrefType, strides, offset)))
+    assert(false && "expected strided memref");
+  return offset;
+}
+
+/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
+/// `offset` AffineExpr.
+static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
+                                                   ArrayRef<int64_t> shape,
+                                                   Type elementType,
+                                                   AffineExpr offset) {
+  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
+  AffineExpr contiguousRowMajor = canonical + offset;
+  AffineMap contiguousRowMajorMap =
+      AffineMap::inferFromExprList({contiguousRowMajor})[0];
+  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
+}
+
+/// Helper determining if a memref is static-shape and contiguous-row-major
+/// layout, while still allowing for an arbitrary offset (any static or
+/// dynamic value).
+bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
+  if (!memrefType.hasStaticShape())
+    return false;
+  AffineExpr offset = getOffsetExpr(memrefType);
+  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
+      memrefType.getContext(), memrefType.getShape(),
+      memrefType.getElementType(), offset);
+  return canonicalizeStridedLayout(memrefType) ==
+         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
+}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
+
+func @transfer_read_flattenable_with_offset(
+      %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>) -> vector<5x4x3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+      memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, vector<5x4x3x2xi8>
+    return %v : vector<5x4x3x2xi8>
+}
+
+// CHECK-LABEL: func @transfer_read_flattenable_with_offset
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
+// CHECK:         %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
+// CHECK:         %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
+// CHECK:         return %[[VEC2D]]
+
+// -----
+
+func @transfer_write_flattenable_with_offset(
+      %arg : memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>, %vec : vector<5x4x3x2xi8>) {
+    %c0 = arith.constant 0 : index
+    vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+      vector<5x4x3x2xi8>, memref<5x4x3x2xi8, offset:?, strides:[24, 6, 2, 1]>
+    return
+}
+
+// CHECK-LABEL: func @transfer_write_flattenable_with_offset
+// CHECK-SAME:    %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
+// CHECK-SAME:    %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
+// CHECK-DAG:     %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
+// CHECK-DAG:     %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
+// CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
+
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -598,6 +598,25 @@
   }
 };
 
+struct TestFlattenVectorTransferPatterns
+    : public PassWrapper<TestFlattenVectorTransferPatterns, FunctionPass> {
+  StringRef getArgument() const final {
+    return "test-vector-transfer-flatten-patterns";
+  }
+  StringRef getDescription() const final {
+    return "Test patterns to rewrite contiguous row-major N-dimensional "
+           "vector.transfer_{read,write} ops into 1D transfers";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<memref::MemRefDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateFlattenVectorTransferPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 } // namespace
 
 namespace mlir {
@@ -630,6 +649,8 @@
   PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
 
   PassRegistration<TestVectorTransferLoweringPatterns>();
+
+  PassRegistration<TestFlattenVectorTransferPatterns>();
 }
 } // namespace test
 } // namespace mlir