[mlir][VectorOps] Canonicalize vector.tuple_get of vector.tuple

Adds a rewrite pattern (TupleGetFolder) that turns

    %t = vector.tuple .., %e_i, ..
    %x = vector.tuple_get %t, i

into a direct use of %e_i. The pattern is exposed as the TupleGetOp
canonicalizer (`let hasCanonicalizer = 1;` plus
TupleGetOp::getCanonicalizationPatterns). The pattern list in
populateVectorToVectorCanonicalizationPatterns is also changed,
presumably to include TupleGetFolder. A FileCheck test (@tuple_get)
is added.

NOTE(review): the template argument lists on the `patterns.insert`
lines of the -1814 hunk below were lost when this patch was extracted.
They cannot be recovered from this copy and must be restored from the
upstream change before the patch will apply or compile.

diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1115,6 +1115,7 @@
     }
     static StringRef getIndexAttrName() { return "index"; }
   }];
+  let hasCanonicalizer = 1;
 }
 
 def Vector_PrintOp :
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1681,6 +1681,36 @@
   return success();
 }
 
+namespace {
+
+class TupleGetFolder : public OpRewritePattern<TupleGetOp> {
+public:
+  using OpRewritePattern<TupleGetOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(TupleGetOp op,
+                                     PatternRewriter &rewriter) const override {
+    // Rewrite:
+    //   %t = vector.tuple .., %e_i, ..
+    //   %x = vector.tuple_get %t, i
+    // into:
+    //   %t = vector.tuple .., %e_i, ..  // one less use
+    //   %x = %e_i
+    if (auto tupleOp =
+            dyn_cast_or_null<TupleOp>(op.getOperand().getDefiningOp())) {
+      rewriter.replaceOp(op, tupleOp.getOperand(op.getIndex()));
+      return matchSuccess();
+    }
+    return matchFailure();
+  }
+};
+
+} // end anonymous namespace
+
+void TupleGetOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                             MLIRContext *context) {
+  results.insert<TupleGetFolder>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
@@ -1814,7 +1844,9 @@
 
 void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
-  patterns.insert(context);
+  patterns
+      .insert(
+          context);
 }
 
 namespace mlir {
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -302,3 +302,13 @@
   }
   return
 }
+
+// CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
+// CHECK: return %arg1
+
+func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
+  return %1 : vector<8xf32>
+}
+