# Review notes on this patch (free text outside any hunk).
#
# Summary: lets vector.extract and vector.insert take an empty position
# attribute (`[]`). The patch
#   - adds `let hasFolder = 1;` beside `let hasCanonicalizer = 1;` in the op
#     definition that precedes Vector_InsertSlicesOp in VectorOps.td
#     (presumably Vector_InsertOp, matching the new vector::InsertOp::fold;
#     confirm),
#   - short-circuits the extract/insert LLVM lowerings when the position is
#     empty,
#   - drops the "expected non-empty position attribute" verifier errors and
#     adds empty-position folds in VectorOps.cpp,
#   - updates the invalid.mlir and ops.mlir tests to match.
#
# NOTE(review): this copy of the patch is mangled. Line breaks and leading
# whitespace were collapsed, so the line layout and indentation below are a
# reconstruction (TODO confirm against the target files). C++ template
# argument lists also appear to have been stripped, e.g. `cast()`, `isa()`,
# `static_cast(...)`, `rewriter.create(`, `fold(ArrayRef)` and
# `results.add(context)`; context lines containing them will not match the
# real sources, so restore them from the original commit before applying.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -764,7 +764,9 @@
       return dest().getType().cast();
     }
   }];
+
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }

 def Vector_InsertSlicesOp :
# ConvertVectorToLLVM.cpp: in both the extract and the insert lowering, an
# empty position now replaces the op directly with the converted operand
# (`adaptor.vector()` for extract, `adaptor.source()` for insert) instead of
# emitting an extractvalue/insertvalue. Per the added comments this is only
# a safety net; the new folders are expected to remove such ops first.
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -656,6 +656,12 @@
     if (!llvmResultType)
       return failure();

+    // Extract entire vector. Should be handled by folder, but just to be safe.
+    if (positionArrayAttr.empty()) {
+      rewriter.replaceOp(extractOp, adaptor.vector());
+      return success();
+    }
+
     // One-shot extraction of vector from array (only requires extractvalue).
     if (resultType.isa()) {
       Value extracted = rewriter.create(
@@ -762,6 +768,13 @@
     if (!llvmResultType)
       return failure();

+    // Overwrite entire vector with value. Should be handled by folder, but
+    // just to be safe.
+    if (positionArrayAttr.empty()) {
+      rewriter.replaceOp(insertOp, adaptor.source());
+      return success();
+    }
+
     // One-shot insertion of a vector into an array (only requires insertvalue).
     if (sourceType.isa()) {
       Value inserted = rewriter.create(
# VectorOps.cpp: removes the non-empty-position checks from
# verify(vector::ExtractOp) and verify(InsertOp); the remaining checks still
# reject a position longer than the vector rank. ExtractOp::fold now returns
# its vector operand for an empty position, before trying the existing
# extract-chain/transpose folds, and the new InsertOp::fold returns its
# source operand for an empty position.
#
# NOTE(review): the comment added above InsertOp::fold says the fold applies
# when "the source and destination vectors have identical sizes", but the
# code only folds when position() is empty. Consider rewording it to state
# that condition directly (an empty position means the whole destination is
# overwritten by the source).
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -872,8 +872,6 @@

 static LogicalResult verify(vector::ExtractOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   if (positionAttr.size() > static_cast(op.getVectorType().getRank()))
     return op.emitOpError(
         "expected position attribute of rank smaller than vector rank");
@@ -1151,6 +1149,8 @@
 }

 OpFoldResult ExtractOp::fold(ArrayRef) {
+  if (position().empty())
+    return vector();
   if (succeeded(foldExtractOpFromExtractChain(*this)))
     return getResult();
   if (succeeded(foldExtractOpFromTranspose(*this)))
@@ -1557,8 +1557,6 @@

 static LogicalResult verify(InsertOp op) {
   auto positionAttr = op.position().getValue();
-  if (positionAttr.empty())
-    return op.emitOpError("expected non-empty position attribute");
   auto destVectorType = op.getDestVectorType();
   if (positionAttr.size() > static_cast(destVectorType.getRank()))
     return op.emitOpError(
@@ -1612,6 +1610,15 @@
   results.add(context);
 }

+// Eliminates insert operations that produce values identical to their source
+// value. This happens when the source and destination vectors have identical
+// sizes.
+OpFoldResult vector::InsertOp::fold(ArrayRef operands) {
+  if (position().empty())
+    return source();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSlicesOp
 //===----------------------------------------------------------------------===//
# invalid.mlir: deletes the two cases that expected the now-removed
# "expected non-empty position attribute" error (vector.extract with `[]`
# and vector.insert with `[]`). The rank-overflow cases are kept.
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -80,13 +80,6 @@

 // -----

-func @extract_position_empty(%arg0: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.extract %arg0[] : vector<4x8x16xf32>
-}
-
-// -----
-
 func @extract_position_rank_overflow(%arg0: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than vector}}
   %1 = vector.extract %arg0[0, 0, 0, 0] : vector<4x8x16xf32>
@@ -138,13 +131,6 @@

 // -----

-func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected non-empty position attribute}}
-  %1 = vector.insert %a, %b[] : f32 into vector<4x8x16xf32>
-}
-
-// -----
-
 func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) {
   // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}}
   %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32>
# ops.mlir: adds round-trip cases for `vector.extract %arg0[]` (result type
# equals the source vector type) and `vector.insert %3, %3[]`.
#
# NOTE(review): these only check that the empty-position forms parse and
# print. No test for the new folds or for the LLVM-lowering fallbacks is
# visible in this patch; consider adding canonicalization and vector-to-llvm
# cases (TODO confirm whether they exist elsewhere).
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -154,14 +154,16 @@
 }

 // CHECK-LABEL: @extract
-func @extract(%arg0: vector<4x8x16xf32>) -> (vector<8x16xf32>, vector<16xf32>, f32) {
+func @extract(%arg0: vector<4x8x16xf32>) -> (vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32) {
+  // CHECK: vector.extract {{.*}}[] : vector<4x8x16xf32>
+  %0 = vector.extract %arg0[] : vector<4x8x16xf32>
   // CHECK: vector.extract {{.*}}[3] : vector<4x8x16xf32>
   %1 = vector.extract %arg0[3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3] : vector<4x8x16xf32>
   %2 = vector.extract %arg0[3, 3] : vector<4x8x16xf32>
   // CHECK-NEXT: vector.extract {{.*}}[3, 3, 3] : vector<4x8x16xf32>
   %3 = vector.extract %arg0[3, 3, 3] : vector<4x8x16xf32>
-  return %1, %2, %3 : vector<8x16xf32>, vector<16xf32>, f32
+  return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }

 // CHECK-LABEL: @insert_element
@@ -181,7 +183,9 @@
   %2 = vector.insert %b, %res[3, 3] : vector<16xf32> into vector<4x8x16xf32>
   // CHECK: vector.insert %{{.*}}, %{{.*}}[3, 3, 3] : f32 into vector<4x8x16xf32>
   %3 = vector.insert %a, %res[3, 3, 3] : f32 into vector<4x8x16xf32>
-  return %3 : vector<4x8x16xf32>
+  // CHECK: vector.insert %{{.*}}, %{{.*}}[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  %4 = vector.insert %3, %3[] : vector<4x8x16xf32> into vector<4x8x16xf32>
+  return %4 : vector<4x8x16xf32>
 }

 // CHECK-LABEL: @outerproduct