[mlir][vector] Canonicalize full-size vector.extract/insert to vector.shape_cast

Turn on canonicalization for vector.extract and vector.insert and add:
- ExtractToShapeCast: a vector.extract whose result is a vector with as many
  elements as its source can only be dropping leading unit dimensions, so it
  is rewritten to a vector.shape_cast of the source.
- InsertToShapeCast: a vector.insert whose source is a vector with as many
  elements as its destination overwrites the whole destination, so it is
  rewritten to a vector.shape_cast of the source.

Tests: @fold_extract_shapecast gains a full-size extract that now folds back
to its input, and @insert_extract_to_shapecast covers both new patterns.

NOTE(review): this patch was recovered from a copy whose line breaks and
template arguments had been stripped; both were restored by hand. The
indentation of unchanged context lines is a best guess, so if it does not
apply cleanly try `git apply --ignore-whitespace`.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -519,6 +519,7 @@
       return vector().getType().cast<VectorType>();
     }
   }];
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
@@ -763,6 +764,7 @@
       return dest().getType().cast<VectorType>();
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Vector_InsertSlicesOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1142,6 +1142,33 @@
   return OpFoldResult();
 }
 
+namespace {
+
+// If extractOp only drops leading unit dimensions (same element count as its
+// source vector), it can be transformed to a shapecast.
+class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern<ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
+    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
+                           dstVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
+                                             extractOp.vector());
+    return success();
+  }
+};
+
+} // namespace
+
+void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ExtractToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractSlicesOp
 //===----------------------------------------------------------------------===//
@@ -1536,6 +1563,33 @@
   return success();
 }
 
+namespace {
+
+// If insertOp overwrites its whole destination (same element count as its
+// source vector), it can be transformed to a shapecast of the source.
+class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern<InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        insertOp, insertOp.getDestVectorType(), insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                           MLIRContext *context) {
+  results.add<InsertToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSlicesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -504,16 +504,18 @@
 // CHECK: %[[R0:.*]] = vector.extract %[[A0]][1, 0, 1, 1] : vector<5x1x3x2xf32>
 // CHECK: %[[R1:.*]] = vector.extract %[[A0]][1, 0, 2] : vector<5x1x3x2xf32>
 // CHECK: %[[R2:.*]] = vector.extract %[[A1]][7] : vector<8x4x2xf32>
-// CHECK: return %[[R0]], %[[R1]], %[[R2]] : f32, vector<2xf32>, vector<4x2xf32>
+// CHECK: return %[[R0]], %[[R1]], %[[R2]], %[[A1]] : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
 func @fold_extract_shapecast(%arg0 : vector<5x1x3x2xf32>,
                              %arg1 : vector<8x4x2xf32>)
-  -> (f32, vector<2xf32>, vector<4x2xf32>) {
+  -> (f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>) {
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xf32>
   %1 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<4x2x4x2xf32>
+  %2 = vector.shape_cast %arg1 : vector<8x4x2xf32> to vector<1x8x4x2xf32>
   %r1 = vector.extract %0[4, 1] : vector<15x2xf32>
   %r2 = vector.extract %0[5] : vector<15x2xf32>
   %r3 = vector.extract %1[3, 1] : vector<4x2x4x2xf32>
-  return %r1, %r2, %r3 : f32, vector<2xf32>, vector<4x2xf32>
+  %r4 = vector.extract %2[0] : vector<1x8x4x2xf32>
+  return %r1, %r2, %r3, %r4 : f32, vector<2xf32>, vector<4x2xf32>, vector<8x4x2xf32>
 }
 
 // -----
@@ -932,3 +934,17 @@
     vector<1x4xf32>, tensor<4x4xf32>
   return %w2 : tensor<4x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_shapecast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
+// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
+                                  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}