[mlir][Vector] Add 0-D vector support to vector.insertelement

vector.insertelement now accepts a 0-D destination vector, in which case the
position operand must be omitted (`%dest[]`). A 1-D destination still requires
a position, and the verifier rejects destinations of rank > 1. The LLVM
lowering converts the 0-D vector to vector<1xT> and inserts at the constant
index 0.

diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -666,16 +666,18 @@
                     "result", "source",
                     "$_self.cast<VectorType>().getElementType()">,
      AllTypesMatch<["dest", "result"]>]>,
-     Arguments<(ins AnyType:$source, AnyVector:$dest,
-                AnySignlessIntegerOrIndex:$position)>,
-     Results<(outs AnyVector:$result)> {
+     Arguments<(ins AnyType:$source, AnyVectorOfAnyRank:$dest,
+                Optional<AnySignlessIntegerOrIndex>:$position)>,
+     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "insertelement operation";
   let description = [{
-    Takes a scalar source, an 1-D destination vector and a dynamic index
-    position and inserts the source into the destination at the proper
-    position. Note that this instruction resembles vector.insert, but
-    is restricted to 1-D vectors and relaxed to dynamic indices. It is
-    meant to be closer to LLVM's version:
+    Takes a scalar source, a 0-D or 1-D destination vector and a dynamic index
+    position (which must be omitted for 0-D vectors) and inserts the source into
+    the destination at the proper position.
+
+    Note that this instruction resembles vector.insert, but is restricted to 0-D
+    and 1-D vectors and relaxed to dynamic indices. It is meant to be closer to
+    LLVM's version:
     https://llvm.org/docs/LangRef.html#insertelement-instruction
 
     Example:
@@ -684,14 +686,18 @@
     %c = arith.constant 15 : i32
     %f = arith.constant 0.0f : f32
     %1 = vector.insertelement %f, %0[%c : i32]: vector<16xf32>
+    %2 = vector.insertelement %f, %z[]: vector<f32>
     ```
   }];
   let assemblyFormat = [{
-    $source `,` $dest `[` $position `:` type($position) `]` attr-dict `:`
+    $source `,` $dest `[` ($position^ `:` type($position))? `]` attr-dict `:`
     type($result)
   }];
 
   let builders = [
+    // 0-D builder.
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
+    // 1-D + position builder.
     OpBuilder<(ins "Value":$source, "Value":$dest, "Value":$position)>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -663,6 +663,17 @@
     if (!llvmType)
       return failure();
 
+    if (vectorType.getRank() == 0) {
+      Location loc = insertEltOp.getLoc();
+      auto idxType = rewriter.getIndexType();
+      auto zero = rewriter.create<LLVM::ConstantOp>(
+          loc, typeConverter->convertType(idxType),
+          rewriter.getIntegerAttr(idxType, 0));
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          insertEltOp, llvmType, adaptor.dest(), adaptor.source(), zero);
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
         insertEltOp, llvmType, adaptor.dest(), adaptor.source(),
         adaptor.position());
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1553,6 +1553,12 @@
 // InsertElementOp
 //===----------------------------------------------------------------------===//
 
+void InsertElementOp::build(OpBuilder &builder, OperationState &result,
+                            Value source, Value dest) {
+  result.addOperands({source, dest});
+  result.addTypes(dest.getType());
+}
+
 void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                             Value source, Value dest, Value position) {
   result.addOperands({source, dest, position});
@@ -1561,8 +1567,15 @@
 
 static LogicalResult verify(InsertElementOp op) {
   auto dstVectorType = op.getDestVectorType();
+  if (dstVectorType.getRank() == 0) {
+    if (op.position())
+      return op.emitOpError("expected position to be empty with 0-D vector");
+    return success();
+  }
   if (dstVectorType.getRank() != 1)
-    return op.emitOpError("expected 1-D vector");
+    return op.emitOpError("unexpected >1 vector rank");
+  if (!op.position())
+    return op.emitOpError("expected position for 1-D vector");
   return success();
 }
 
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -512,6 +512,19 @@
 
 // -----
 
+// CHECK-LABEL: @insert_element_0d
+// CHECK-SAME: %[[A:.*]]: f32,
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK: %[[B:.*]] = builtin.unrealized_conversion_cast %{{.*}} :
+  // CHECK-SAME: vector<f32> to vector<1xf32>
+  // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
+  // CHECK: %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[C0]] : {{.*}}] : vector<1xf32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   %0 = arith.constant 3 : i32
   %1 = vector.insertelement %arg0, %arg1[%0 : i32] : vector<4xf32>
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -138,9 +138,25 @@
 
 // -----
 
+func @insert_element(%arg0: f32, %arg1: vector<f32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position to be empty with 0-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<f32>
+}
+
+// -----
+
+func @insert_element(%arg0: f32, %arg1: vector<4xf32>) {
+  %c = arith.constant 3 : i32
+  // expected-error@+1 {{expected position for 1-D vector}}
+  %0 = vector.insertelement %arg0, %arg1[] : vector<4xf32>
+}
+
+// -----
+
 func @insert_element(%arg0: f32, %arg1: vector<4x4xf32>) {
   %c = arith.constant 3 : i32
-  // expected-error@+1 {{'vector.insertelement' op expected 1-D vector}}
+  // expected-error@+1 {{unexpected >1 vector rank}}
   %0 = vector.insertelement %arg0, %arg1[%c : i32] : vector<4x4xf32>
 }
 
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -192,6 +192,13 @@
   return %0, %1, %2, %3 : vector<4x8x16xf32>, vector<8x16xf32>, vector<16xf32>, f32
 }
 
+// CHECK-LABEL: @insert_element_0d
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
+  // CHECK-NEXT: vector.insertelement %{{.*}}, %{{.*}}[] : vector<f32>
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1 : vector<f32>
+}
+
 // CHECK-LABEL: @insert_element
 func @insert_element(%a: f32, %b: vector<16xf32>) -> vector<16xf32> {
   // CHECK: %[[C15:.*]] = arith.constant 15 : i32
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-0-d-vectors.mlir
@@ -10,8 +10,15 @@
   return
 }
 
+func @insert_element_0d(%a: f32, %b: vector<f32>) -> (vector<f32>) {
+  %1 = vector.insertelement %a, %b[] : vector<f32>
+  return %1: vector<f32>
+}
+
 func @entry() {
-  %1 = arith.constant dense<42.0> : vector<f32>
-  call @extract_element_0d(%1) : (vector<f32>) -> ()
+  %0 = arith.constant 42.0 : f32
+  %1 = arith.constant dense<0.0> : vector<f32>
+  %2 = call @insert_element_0d(%0, %1) : (f32, vector<f32>) -> (vector<f32>)
+  call @extract_element_0d(%2) : (vector<f32>) -> ()
   return
 }