[mlir][Vector] Support 0-d vectors in vector.extract and vector.insert

Widen the vector operand of vector.extract and the dest/result of
vector.insert from AnyVector to AnyVectorOfAnyRank so that 0-d vectors
are accepted, and allow an empty position `[]` on them.

* The return-type inference hunk in VectorOps.cpp (presumably
  ExtractOp::inferReturnTypes) now clamps the number of dropped leading
  dimensions to the source rank instead of rank - 1; for a rank-0
  source, rank - 1 is -1.
* The InsertOp rank check (`positionAttr.size() > rank`) now reports
  "no greater than" instead of "smaller than", matching what it tests.
* New tests cover 0-d extract/insert in ops.mlir (round-trip) and
  invalid.mlir (verifier errors).
* The 0-d ExtractOp description example extracts from a separate value
  `%3`; `%1` is defined two lines earlier as the result of
  `vector.extract %0[3]` on a vector<4x8x16xf32>, so it is not 0-d.

NOTE(review): the new @extract_0d test in invalid.mlir still expects
"expected position attribute of rank smaller than vector rank", although
an extract position whose rank equals the vector rank is valid (see the
`%2 = vector.extract %0[3, 3, 3]` example). Consider rewording the
ExtractOp verifier message to "no greater than", as done for InsertOp;
the ExtractOp verifier itself is not part of this patch.

NOTE(review): this copy of the patch appears to have lost its newlines
and tag-like `<...>` text (e.g. `DeclareOpInterfaceMethods]>`,
`static_cast(`, `std::min(` and bare `vector` types), so it will not
apply as-is.

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -577,18 +577,19 @@ PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, DeclareOpInterfaceMethods]>, - Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, Results<(outs AnyType)> { let summary = "extract operation"; let description = [{ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at - the proper position. Degenerates to an element type in the 0-D case. + the proper position. Degenerates to an element type if n-k is zero. Example: ```mlir %1 = vector.extract %0[3]: vector<4x8x16xf32> %2 = vector.extract %0[3, 3, 3]: vector<4x8x16xf32> + %4 = vector.extract %3[]: vector<f32> ``` }]; let builders = [ @@ -694,19 +695,21 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>, - Results<(outs AnyVector:$res)> { + Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>, + Results<(outs AnyVectorOfAnyRank:$res)> { let summary = "insert operation"; let description = [{ Takes an n-D source vector, an (n+k)-D destination vector and a k-D position and inserts the n-D source into the (n+k)-D destination at the proper - position. Degenerates to a scalar source type when n = 0. + position. Degenerates to a scalar or a 0-d vector source type when n = 0.
Example: ```mlir %2 = vector.insert %0, %1[3] : vector<8x16xf32> into vector<4x8x16xf32> %5 = vector.insert %3, %4[3, 3, 3] : f32 into vector<4x8x16xf32> + %8 = vector.insert %6, %7[] : f32 into vector + %11 = vector.insert %9, %10[3, 3, 3] : vector into vector<4x8x16xf32> ``` }]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1163,8 +1163,7 @@ if (static_cast(op.getPosition().size()) == vectorType.getRank()) { inferredReturnTypes.push_back(vectorType.getElementType()); } else { - auto n = - std::min(op.getPosition().size(), vectorType.getRank() - 1); + auto n = std::min(op.getPosition().size(), vectorType.getRank()); inferredReturnTypes.push_back(VectorType::get( vectorType.getShape().drop_front(n), vectorType.getElementType())); } @@ -2328,7 +2327,7 @@ auto destVectorType = getDestVectorType(); if (positionAttr.size() > static_cast(destVectorType.getRank())) return emitOpError( - "expected position attribute of rank smaller than dest vector rank"); + "expected position attribute of rank no greater than dest vector rank"); auto srcVectorType = llvm::dyn_cast(getSourceType()); if (srcVectorType && (static_cast(srcVectorType.getRank()) + positionAttr.size() != diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -152,6 +152,13 @@ // ----- +func.func @extract_0d(%arg0: vector) { + // expected-error@+1 {{expected position attribute of rank smaller than vector rank}} + %1 = vector.extract %arg0[0] : vector +} + +// ----- + func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) { // expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}} %1 = vector.extract %arg0[0, 0, -1] : vector<4x8x16xf32>
@@ -192,7 +199,7 @@ // ----- func.func @insert_vector_type(%a: f32, %b: vector<4x8x16xf32>) { - // expected-error@+1 {{expected position attribute of rank smaller than dest vector rank}} + // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} %1 = vector.insert %a, %b[3, 3, 3, 3, 3, 3] : f32 into vector<4x8x16xf32> } @@ -226,6 +233,20 @@ // ----- +func.func @insert_0d(%a: vector, %b: vector<4x8x16xf32>) { + // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}} + %1 = vector.insert %a, %b[2, 6] : vector into vector<4x8x16xf32> +} + +// ----- + +func.func @insert_0d(%a: f32, %b: vector) { + // expected-error@+1 {{expected position attribute of rank no greater than dest vector rank}} + %1 = vector.insert %a, %b[0] : f32 into vector +} + +// ----- + func.func @outerproduct_num_operands(%arg0: f32) { // expected-error@+1 {{expected at least 2 operands}} %1 = vector.outerproduct %arg0 : f32, f32 diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -219,6 +219,13 @@ return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: @extract_0d +func.func @extract_0d(%a: vector) -> f32 { + // CHECK-NEXT: vector.extract %{{.*}}[] : vector + %0 = vector.extract %a[] : vector + return %0 : f32 +} + // CHECK-LABEL: @insert_element_0d func.func @insert_element_0d(%a: f32, %b: vector) -> vector { // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector @@ -248,6 +255,15 @@ return %4 : vector<4x8x16xf32> } +// CHECK-LABEL: @insert_0d +func.func @insert_0d(%a: f32, %b: vector, %c: vector<2x3xf32>) -> (vector, vector<2x3xf32>) { + // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector + %1 = vector.insert %a, %b[] : f32 into vector + // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector into vector<2x3xf32> + %2 = vector.insert 
%b, %c[0, 1] : vector into vector<2x3xf32> + return %1, %2 : vector, vector<2x3xf32> +} + // CHECK-LABEL: @outerproduct func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>