diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -886,6 +886,7 @@
 
   let hasFolder = 1;
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_OuterProductOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2179,6 +2179,41 @@
   return success();
 }
 
+namespace {
+/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
+/// SplatOp(X):dst_type) to SplatOp(X):dst_type. Every element of the result
+/// equals X, so the op is replaced by its existing `dest` splat instead of
+/// materializing a new, identical SplatOp.
+class FoldInsertStridedSliceSplat final
+    : public OpRewritePattern<InsertStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcSplatOp =
+        insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
+    auto destSplatOp =
+        insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
+
+    if (!srcSplatOp || !destSplatOp)
+      return failure();
+
+    // Both splats must broadcast the very same SSA value.
+    if (srcSplatOp.getInput() != destSplatOp.getInput())
+      return failure();
+
+    rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+    return success();
+  }
+};
+} // namespace
+
+void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<FoldInsertStridedSliceSplat>(context);
+}
+
 OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
   if (getSourceVectorType() == getDestVectorType())
     return getSource();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1627,3 +1627,30 @@
   %1 = vector.bitcast %0 : vector<4x8xi32> to vector<4x16xi16>
   return %1 : vector<4x16xi16>
 }
+
+// -----
+
+// CHECK-LABEL: @insert_strided_slice_splat
+// CHECK-SAME: (%[[ARG:.*]]: f32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8x16xf32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
+func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.splat %x : vector<4x4xf32>
+  %splat1 = vector.splat %x : vector<8x16xf32>
+  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}
+
+// -----
+
+// Splats of different values must not be folded.
+// CHECK-LABEL: @insert_strided_slice_splat_mismatch
+// CHECK: vector.insert_strided_slice
+func.func @insert_strided_slice_splat_mismatch(%x: f32, %y: f32) -> (vector<8x16xf32>) {
+  %splat0 = vector.splat %x : vector<4x4xf32>
+  %splat1 = vector.splat %y : vector<8x16xf32>
+  %0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
+    : vector<4x4xf32> into vector<8x16xf32>
+  return %0 : vector<8x16xf32>
+}