diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2028,11 +2028,38 @@
   }
 };
 
+/// Pattern to rewrite an InsertOp(SplatOp(x), SplatOp(x)) to SplatOp(x).
+///
+/// Inserting a splat of `x` into a splat of the same `x` produces a vector
+/// whose every element is `x`, so the whole insert is a splat of `x` with the
+/// insert's result type.
+class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    // Both the inserted value and the destination must be produced by splats.
+    auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
+    auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
+
+    if (!srcSplat || !dstSplat)
+      return failure();
+
+    // The two splats must broadcast the very same SSA value.
+    if (srcSplat.getInput() != dstSplat.getInput())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+    return success();
+  }
+};
+
 } // namespace
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1669,3 +1669,27 @@
 
   return %shuffle : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_splat
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+// CHECK-NEXT: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<2x4x3xi32>
+// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
+func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %x : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_splat_different_inputs
+// CHECK-SAME: (%[[X:.*]]: i32, %[[Y:.*]]: i32)
+// CHECK: vector.insert
+func.func @insert_splat_different_inputs(%x : i32, %y : i32) -> vector<2x4x3xi32> {
+  %v0 = vector.splat %x : vector<4x3xi32>
+  %v1 = vector.splat %y : vector<2x4x3xi32>
+  %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
+  return %insert : vector<2x4x3xi32>
+}