diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -35,6 +35,11 @@
 void populateVectorToVectorTransformationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of patterns that cast away leading one dimensions from
+/// vector ops by inserting vector.shape_cast ops.
+void populateCastAwayVectorLeadingOneDimPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context);
+
 /// Collect a set of vector slices transformation patterns:
 /// ExtractSlicesOpLowering, InsertSlicesOpLowering
 /// Useful for clients that want to express all vector "slices"
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2607,6 +2607,210 @@
   }
 };
 
+// Trims leading one dimensions from `oldType` and returns the result type.
+// Returns `vector<1xT>` if `oldType` only has one element.
+static VectorType trimLeadingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape =
+      oldShape.drop_while([](int64_t dim) { return dim == 1; });
+  // Make sure we have at least 1 dimension per vector type requirements.
+  if (newShape.empty())
+    newShape = oldShape.take_back();
+  return VectorType::get(newShape, oldType.getElementType());
+}
+
+// Casts away leading one dimensions in vector.extract_strided_slice's vector
+// input by inserting vector.shape_cast.
+struct CastAwayExtractStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // vector.extract_strided_slice requires the input and output vector to have
+    // the same rank. Here we drop leading one dimensions from the input vector
+    // type to make sure we don't cause mismatch.
+    VectorType oldSrcType = extractOp.getVectorType();
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+
+    if (newSrcType.getRank() == oldSrcType.getRank())
+      return failure();
+
+    int64_t dropCount = oldSrcType.getRank() - newSrcType.getRank();
+
+    VectorType oldDstType = extractOp.getType();
+    VectorType newDstType =
+        VectorType::get(oldDstType.getShape().drop_front(dropCount),
+                        oldDstType.getElementType());
+
+    Location loc = extractOp.getLoc();
+
+    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcType, extractOp.vector());
+
+    // The offsets/sizes/strides attributes can have fewer elements than the
+    // input vector's rank: they are meant for the leading dimensions. They may
+    // even cover fewer dimensions than we drop, so only drop the entries that
+    // exist (ArrayRef::drop_front asserts when asked to drop too many).
+    auto dropLeading = [&](ArrayAttr attr) {
+      ArrayRef<Attribute> values = attr.getValue();
+      return rewriter.getArrayAttr(
+          values.drop_front(std::min<size_t>(dropCount, values.size())));
+    };
+    auto newOffsets = dropLeading(extractOp.offsets());
+    auto newSizes = dropLeading(extractOp.sizes());
+    auto newStrides = dropLeading(extractOp.strides());
+
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+                                                     newExtractOp);
+
+    return success();
+  }
+};
+
+// Casts away leading one dimensions in vector.insert_strided_slice's vector
+// inputs by inserting vector.shape_cast.
+struct CastAwayInsertStridedSliceLeadingOneDim
+    : public OpRewritePattern<vector::InsertStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType oldSrcType = insertOp.getSourceVectorType();
+    VectorType newSrcType = trimLeadingOneDims(oldSrcType);
+    VectorType oldDstType = insertOp.getDestVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+    if (newSrcType.getRank() == oldSrcType.getRank() &&
+        newDstType.getRank() == oldDstType.getRank())
+      return failure();
+
+    // Trim leading one dimensions from both operands.
+    Location loc = insertOp.getLoc();
+
+    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
+        loc, newSrcType, insertOp.source());
+    Value newDstVector =
+        rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+
+    // `offsets` is indexed by destination dimension and `strides` by source
+    // dimension; keep only the entries of the trailing, surviving dimensions.
+    auto newOffsets = rewriter.getArrayAttr(
+        insertOp.offsets().getValue().take_back(newDstType.getRank()));
+    auto newStrides = rewriter.getArrayAttr(
+        insertOp.strides().getValue().take_back(newSrcType.getRank()));
+
+    auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
+        loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
+
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+                                                     newInsertOp);
+
+    return success();
+  }
+};
+
+// Turns vector.transfer_read on vector with leading 1 dimensions into
+// vector.transfer_read on vector without leading 1 dimensions followed by
+// vector.shape_cast.
+struct CastAwayTransferReadLeadingOneDim
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp read,
+                                PatternRewriter &rewriter) const override {
+    auto memrefType = read.source().getType().dyn_cast<MemRefType>();
+    if (!memrefType ||
+        memrefType.getElementType() != read.getVectorType().getElementType())
+      return failure();
+
+    VectorType oldType = read.getVectorType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    // The `masked` entries of the dropped leading dimensions are discarded
+    // below, so only proceed when all of them are unmasked (in bounds);
+    // otherwise their out-of-bounds handling would be silently lost.
+    int64_t dropCount = oldType.getRank() - newType.getRank();
+    for (int64_t dim = 0; dim < dropCount; ++dim)
+      if (read.isMaskedDim(dim))
+        return failure();
+
+    AffineMap oldMap = read.permutation_map();
+    ArrayRef<AffineExpr> newResults =
+        oldMap.getResults().take_back(newType.getRank());
+    AffineMap newMap =
+        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+                       rewriter.getContext());
+
+    ArrayAttr mask;
+    if (read.masked())
+      mask = rewriter.getArrayAttr(
+          read.maskedAttr().getValue().take_back(newType.getRank()));
+
+    auto newRead = rewriter.create<vector::TransferReadOp>(
+        read.getLoc(), newType, read.source(), read.indices(), newMap,
+        read.padding(), mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+
+    return success();
+  }
+};
+
+// Turns vector.transfer_write on vector with leading 1 dimensions into
+// vector.shape_cast followed by vector.transfer_write on vector without leading
+// 1 dimensions.
+struct CastAwayTransferWriteLeadingOneDim
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto memrefType = write.source().getType().dyn_cast<MemRefType>();
+    if (!memrefType ||
+        memrefType.getElementType() != write.getVectorType().getElementType())
+      return failure();
+
+    VectorType oldType = write.getVectorType();
+    VectorType newType = trimLeadingOneDims(oldType);
+
+    if (newType == oldType)
+      return failure();
+
+    // The `masked` entries of the dropped leading dimensions are discarded
+    // below, so only proceed when all of them are unmasked (in bounds);
+    // otherwise their out-of-bounds handling would be silently lost.
+    int64_t dropCount = oldType.getRank() - newType.getRank();
+    for (int64_t dim = 0; dim < dropCount; ++dim)
+      if (write.isMaskedDim(dim))
+        return failure();
+
+    AffineMap oldMap = write.permutation_map();
+    ArrayRef<AffineExpr> newResults =
+        oldMap.getResults().take_back(newType.getRank());
+    AffineMap newMap =
+        AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), newResults,
+                       rewriter.getContext());
+
+    ArrayAttr mask;
+    if (write.masked())
+      mask = rewriter.getArrayAttr(
+          write.maskedAttr().getValue().take_back(newType.getRank()));
+
+    auto newVector = rewriter.create<vector::ShapeCastOp>(
+        write.getLoc(), newType, write.vector());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        write, newVector, write.source(), write.indices(), newMap, mask);
+
+    return success();
+  }
+};
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2622,6 +2826,14 @@
   // clang-format on
 }
 
+void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
+                  CastAwayInsertStridedSliceLeadingOneDim,
+                  CastAwayTransferReadLeadingOneDim,
+                  CastAwayTransferWriteLeadingOneDim>(context);
+}
+
 void mlir::vector::populateVectorSlicesLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -601,3 +601,67 @@
     : vector<4x4xf32>, tensor<4x4xf32>
   return %r : tensor<4x4xf32>
 }
+
+// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
+func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
+  // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
+  %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
+  // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x1x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
+func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
+  // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
+  // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
+  %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
+  // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK: return %[[RET]]
+  return %0: vector<1x8x8xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims
+func @cast_away_transfer_read_leading_one_dims(%arg0: memref<1x4x8x16xf16>) -> vector<1x4xf16> {
+  // CHECK: %[[C0:.+]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK: %[[F0:.+]] = constant 0.000000e+00 : f16
+  %f0 = constant 0. : f16
+  // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {masked = [false]} : memref<1x4x8x16xf16>, vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {masked = [false, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<1x4xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims
+func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
+  // CHECK: %[[C0:.+]] = constant 0 : index
+  %c0 = constant 0 : index
+  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+  // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {masked = [false]} : vector<4xf16>, memref<1x4x8x16xf16>
+
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+}
+
+// CHECK-LABEL: func @cast_away_transfer_read_leading_one_dims_masked
+func @cast_away_transfer_read_leading_one_dims_masked(%arg0: memref<1x4x8x16xf16>, %i: index) -> vector<1x4xf16> {
+  %c0 = constant 0 : index
+  %f0 = constant 0. : f16
+  // The dropped leading dimension is masked, so the pattern must not apply.
+  // CHECK: %[[READ:.+]] = vector.transfer_read {{.*}} {masked = [true, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  %0 = vector.transfer_read %arg0[%c0, %c0, %i, %c0], %f0 {masked = [true, false]} : memref<1x4x8x16xf16>, vector<1x4xf16>
+  // CHECK: return %[[READ]]
+  return %0: vector<1x4xf16>
+}
+
+// CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_masked
+func @cast_away_transfer_write_leading_one_dims_masked(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>, %i: index) {
+  %c0 = constant 0 : index
+  // The dropped leading dimension is masked, so the pattern must not apply.
+  // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[{{.*}}] {masked = [true, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  vector.transfer_write %arg1, %arg0[%c0, %c0, %i, %c0] {masked = [true, false]} : vector<1x4xf16>, memref<1x4x8x16xf16>
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -45,6 +45,7 @@
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
+    populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
 