diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -487,6 +487,7 @@
   }];
   let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
   let hasVerifier = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ExtractElementOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1882,6 +1882,43 @@
   return DenseElementsAttr::get(getVectorType(), results);
 }
 
+namespace {
+
+/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
+///
+/// When both shuffle operands are splats of the same scalar, every lane of the
+/// result is a copy of that scalar regardless of the shuffle mask, so the
+/// shuffle can be replaced by a single splat of the result type. The operand
+/// splats may have different vector types.
+class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
+    auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
+
+    if (!v1Splat || !v2Splat)
+      return failure();
+
+    // Splats of different values could give the result non-uniform lanes, so
+    // require both operands to broadcast the same SSA value.
+    if (v1Splat.getInput() != v2Splat.getInput())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+    return success();
+  }
+};
+
+} // namespace
+
+void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<ShuffleSplat>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertElementOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1655,3 +1655,28 @@
     : vector<2x4xf32> into vector<8x16xf32>
   return %1 : vector<8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat
+//  CHECK-SAME:   (%[[ARG:.*]]: i32)
+//  CHECK-NEXT:   %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<4xi32>
+//  CHECK-NEXT:   return %[[SPLAT]] : vector<4xi32>
+func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %x : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_splat_different_inputs
+//       CHECK:   vector.shuffle
+func.func @shuffle_splat_different_inputs(%x : i32, %y : i32) -> vector<4xi32> {
+  %v0 = vector.splat %x : vector<4xi32>
+  %v1 = vector.splat %y : vector<2xi32>
+  %shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
+  return %shuffle : vector<4xi32>
+}
+