diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -56,6 +56,9 @@ void populateVectorToVectorCanonicalizationPatterns( RewritePatternSet &patterns); +/// Collect a set of vector.shape_cast folding patterns. +void populateShapeCastFoldingPatterns(RewritePatternSet &patterns); + /// Collect a set of leading one dimension removal patterns. /// /// These patterns insert vector.shape_cast to remove leading one dimensions diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1215,29 +1215,12 @@ namespace { -// If extractOp is only removing unit dimensions it can be transformed to a -// shapecast. -class ExtractToShapeCast final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ExtractOp extractOp, - PatternRewriter &rewriter) const override { - auto dstVecType = extractOp.getResult().getType().dyn_cast(); - if (!dstVecType || extractOp.getVectorType().getNumElements() != - dstVecType.getNumElements()) - return failure(); - rewriter.replaceOpWithNewOp(extractOp, dstVecType, - extractOp.vector()); - return success(); - } -}; - } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + // ExtractToShapeCast is not a default canonicalization, it is opt-in by + // calling `populateCastAwayVectorLeadingOneDimPatterns` } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -1401,27 +1384,6 @@ namespace { -// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In -// the degenerated case where the broadcast only adds dimensions of size 1 it -// can be replaced by a ShapeCastOp. 
This canonicalization checks if the total -// number of elements is the same before and after the broadcast to detect if -// the only change in the vector type are new dimensions of size 1. -class BroadcastToShapeCast final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BroadcastOp broadcastOp, - PatternRewriter &rewriter) const override { - auto srcVecType = broadcastOp.getSourceType().dyn_cast(); - if (!srcVecType || broadcastOp.getVectorType().getNumElements() != - srcVecType.getNumElements()) - return failure(); - rewriter.replaceOpWithNewOp( - broadcastOp, broadcastOp.getVectorType(), broadcastOp.source()); - return success(); - } -}; - // Fold broadcast1(broadcast2(x)) into broadcast1(x). struct BroadcastFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1440,7 +1402,9 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + // BroadcastToShapeCast is not a default canonicalization, it is opt-in by + // calling `populateCastAwayVectorLeadingOneDimPatterns` + results.add(context); } //===----------------------------------------------------------------------===// @@ -1605,31 +1569,10 @@ return success(); } -namespace { - -// If insertOp is only inserting unit dimensions it can be transformed to a -// shapecast. 
-class InsertToShapeCast final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(InsertOp insertOp, - PatternRewriter &rewriter) const override { - auto srcVecType = insertOp.getSourceType().dyn_cast(); - if (!srcVecType || insertOp.getDestVectorType().getNumElements() != - srcVecType.getNumElements()) - return failure(); - rewriter.replaceOpWithNewOp( - insertOp, insertOp.getDestVectorType(), insertOp.source()); - return success(); - } -}; - -} // namespace - void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + // InsertToShapeCast is not a default canonicalization, it is opt-in by + // calling `populateCastAwayVectorLeadingOneDimPatterns` } // Eliminates insert operations that produce values identical to their source diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1113,11 +1113,18 @@ Location loc = op.getLoc(); auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); - // Intended 2D/1D lowerings with better implementations. + + // Special case 2D/1D lowerings with better implementations. + // TODO: make it ND/1D to allow generic ND->1D->MD. int64_t srcRank = sourceVectorType.getRank(); int64_t resRank = resultVectorType.getRank(); if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2)) return failure(); + + // Generic ShapeCast lowering path goes all the way down to unrolled scalar + // extract/insert chains. + // TODO: consider evolving the semantics to only allow 1D source or dest and + // drop this potentially very expensive lowering. // Compute number of elements involved in the reshape. 
int64_t numElts = 1; for (int64_t r = 0; r < srcRank; r++) @@ -3177,6 +3184,63 @@ } }; +// If extractOp is only removing unit dimensions it can be transformed to a +// shapecast. +class ExtractToShapeCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractOp extractOp, + PatternRewriter &rewriter) const override { + auto dstVecType = extractOp.getResult().getType().dyn_cast(); + if (!dstVecType || extractOp.getVectorType().getNumElements() != + dstVecType.getNumElements()) + return failure(); + rewriter.replaceOpWithNewOp(extractOp, dstVecType, + extractOp.vector()); + return success(); + } +}; + +// If insertOp is only inserting unit dimensions it can be transformed to a +// shapecast. +class InsertToShapeCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(InsertOp insertOp, + PatternRewriter &rewriter) const override { + auto srcVecType = insertOp.getSourceType().dyn_cast(); + if (!srcVecType || insertOp.getDestVectorType().getNumElements() != + srcVecType.getNumElements()) + return failure(); + rewriter.replaceOpWithNewOp( + insertOp, insertOp.getDestVectorType(), insertOp.source()); + return success(); + } +}; + +// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In +// the degenerate case where the broadcast only adds dimensions of size 1 it +// can be replaced by a ShapeCastOp. This pattern checks if the total +// number of elements is the same before and after the broadcast to detect if +// the only changes in the vector type are new dimensions of size 1. 
+class BroadcastToShapeCast final : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + auto srcVecType = broadcastOp.getSourceType().dyn_cast(); + if (!srcVecType || broadcastOp.getVectorType().getNumElements() != + srcVecType.getNumElements()) + return failure(); + rewriter.replaceOpWithNewOp( + broadcastOp, broadcastOp.getVectorType(), broadcastOp.source()); + return success(); + } +}; + // Returns the values in `arrayAttr` as an integer vector. static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( @@ -3651,16 +3715,21 @@ patterns.getContext()); } +void mlir::vector::populateShapeCastFoldingPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( RewritePatternSet &patterns) { - patterns.add, - CastAwayBroadcastLeadingOneDim, - CastAwayElementwiseLeadingOneDim, ShapeCastOpFolder>( - patterns.getContext()); + patterns.add< + BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim, + CastAwayInsertStridedSliceLeadingOneDim, + CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim, + CastAwayBroadcastLeadingOneDim, + CastAwayBroadcastLeadingOneDim, CastAwayElementwiseLeadingOneDim, + ExtractToShapeCast, InsertToShapeCast>(patterns.getContext()); + populateShapeCastFoldingPatterns(patterns); } void mlir::vector::populateBubbleVectorBitCastOpPatterns( diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -717,16 +717,6 @@ // ----- -// CHECK-LABEL: broadcast_to_shapecast -// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16> -// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16> -func @broadcast_to_shapecast(%arg0: 
vector<4x4xf16>) -> vector<1x4x4xf16> { - %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16> - return %0 : vector<1x4x4xf16> -} - -// ----- - // CHECK-LABEL: func @dead_transfer_op // CHECK-NOT: vector.transfer_read // CHECK-NOT: vector.transfer_write @@ -971,20 +961,6 @@ // ----- -// CHECK-LABEL: func @insert_extract_to_shapecast -// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) -// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> -// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> -// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> -func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>, - %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { - %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32> - %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> - return %0, %1 : vector<4xf32>, vector<1x1x4xf32> -} - -// ----- - // CHECK-LABEL: func @transfer_read_of_extract_slice( // CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s + +// CHECK-LABEL: broadcast_to_shapecast +// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16> +// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16> +func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> { + %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16> + return %0 : vector<1x4x4xf16> +} + +// ----- + +// CHECK-LABEL: func @insert_extract_to_shapecast +// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>) +// 
CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32> +// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32> +// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32> +func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>, + %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) { + %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32> + %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32> + return %0, %1 : vector<4xf32>, vector<1x1x4xf32> +}