diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -577,12 +577,12 @@ PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, DeclareOpInterfaceMethods<InferTypeOpInterface>]>, - Arguments<(ins AnyVector:$vector, I64ArrayAttr:$position)>, + Arguments<(ins AnyVectorOfAnyRank:$vector, I64ArrayAttr:$position)>, Results<(outs AnyType)> { let summary = "extract operation"; let description = [{ Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at - the proper position. Degenerates to an element type in the 0-D case. + the proper position. Degenerates to an element type if n-k is zero. Example: @@ -694,13 +694,13 @@ PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "res"]>]>, - Arguments<(ins AnyType:$source, AnyVector:$dest, I64ArrayAttr:$position)>, - Results<(outs AnyVector:$res)> { + Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest, I64ArrayAttr:$position)>, + Results<(outs AnyVectorOfAnyRank:$res)> { let summary = "insert operation"; let description = [{ Takes an n-D source vector, an (n+k)-D destination vector and a k-D position and inserts the n-D source into the (n+k)-D destination at the proper - position. Degenerates to a scalar source type when n = 0. + position. Can degenerate to a scalar source type when n = 0.
Example: diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -219,6 +219,13 @@ return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32 } +// CHECK-LABEL: @extract_0d +func.func @extract_0d(%a: vector<f32>) -> f32 { + // CHECK-NEXT: vector.extract %{{.*}}[] : vector<f32> + %0 = vector.extract %a[] : vector<f32> + return %0 : f32 +} + // CHECK-LABEL: @insert_element_0d func.func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> { // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32> @@ -248,6 +255,13 @@ return %4 : vector<4x8x16xf32> } +// CHECK-LABEL: @insert_0d +func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> { + // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32> + %1 = vector.insert %a, %b[] : f32 into vector<f32> + return %1 : vector<f32> +} + // CHECK-LABEL: @outerproduct func.func @outerproduct(%arg0: vector<4xf32>, %arg1: vector<8xf32>, %arg2: vector<4x8xf32>) -> vector<4x8xf32> { // CHECK: vector.outerproduct {{.*}} : vector<4xf32>, vector<8xf32>