diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -763,6 +763,7 @@
       return dest().getType().cast<VectorType>();
     }
   }];
+  let hasFolder = 1;
 }
 
 def Vector_InsertSlicesOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -850,8 +850,6 @@
 
 static LogicalResult verify(vector::ExtractOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   if (positionAttr.size() > static_cast<unsigned>(op.getVectorType().getRank()))
     return op.emitOpError(
         "expected position attribute of rank smaller than vector rank");
@@ -1128,6 +1126,17 @@
   return extractOp.getResult();
 }
 
+// Fold extractOp if position is empty, i.e., the entire vector is extracted.
+static Value foldExtractEntireVector(ExtractOp extractOp) {
+  if (extractOp.position().empty())
+    return extractOp.vector();
+  return Value();
+}
+
 OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
+  // Check the whole-vector case first: folders such as the extract-chain one
+  // below rewrite this op in place, which is unnecessary for an identity.
+  if (auto val = foldExtractEntireVector(*this))
+    return val;
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
@@ -1508,8 +1517,6 @@
 
 static LogicalResult verify(InsertOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   auto destVectorType = op.getDestVectorType();
   if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
     return op.emitOpError(
@@ -1536,6 +1543,15 @@
   return success();
 }
 
+// Eliminates insert operations that produce values identical to their source
+// value. This happens when the position is empty, i.e. the source overwrites
+// the entire destination vector.
+OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute>) {
+  if (position().empty())
+    return source();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSlicesOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -80,13 +80,6 @@
 
 // -----
 
-func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.extract %arg0[] : vector<4x8x16xf32>
-}
-
-// -----
-
 func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than vector}}
   %1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32>
@@ -138,13 +131,6 @@
 
 // -----
 
-func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
-}
-
-// -----
-
 func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -154,14 +154,16 @@
 }
 
 // CHECK-LABEL: @extract
-func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
+func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
+  // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
+  %0 = vector.extract %arg0[] : vector<4x8x16xf32>
-  // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
+  // CHECK-NEXT: vector.extract {{.*}}[3] : vector<4x8x16xf32>
   %1 = vector.extract %arg0[3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32>
   %2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32>
   %3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32>
-  return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
+  return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
 // CHECK-LABEL: @insert_element
@@ -181,7 +183,9 @@
   %2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32>
   %3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32>
-  return %3 : vector<4x8x16xf32>
+  // CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  %4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  return %4 : vector<4x8x16xf32>
 }
 
 // CHECK-LABEL: @outerproduct