diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
--- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
+++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td
@@ -1115,6 +1115,7 @@
     }
     static StringRef getIndexAttrName() { return "index"; }
   }];
+  let hasFolder = 1;
 }
 
 def Vector_PrintOp :
diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp
@@ -1681,6 +1681,20 @@
   return success();
 }
 
+/// Folds `vector.tuple_get` of a `vector.tuple` result to the chosen element.
+OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
+  // Rewrite:
+  //    %t = vector.tuple .., %e_i, ..
+  //    %x = vector.tuple_get %t, i
+  // into:
+  //    %t = vector.tuple .., %e_i, ..  // one less use
+  //    %x = %e_i
+  // Any other producer (including a block argument) is left unfolded.
+  if (auto tupleOp = dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp()))
+    return tupleOp.getOperand(getIndex());
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ConstantMaskOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -302,3 +302,12 @@
   }
   return
 }
+
+// CHECK-LABEL: func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>)
+// CHECK: return %arg1
+
+func @tuple_get(%arg0: vector<4xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
+  %0 = vector.tuple %arg0, %arg1 : vector<4xf32>, vector<8xf32>
+  %1 = vector.tuple_get %0, 1 : tuple<vector<4xf32>, vector<8xf32>>
+  return %1 : vector<8xf32>
+}