diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -183,6 +183,57 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+def Tensor_InsertOp : Tensor_Op<"insert",
+    [NoSideEffect,
+     TypesMatchWith<"result type matches type of dest",
+                    "dest", "result",
+                    "$_self.cast<TensorType>()">,
+     TypesMatchWith<"scalar type matches element type of dest",
+                    "dest", "scalar",
+                    "$_self.cast<TensorType>().getElementType()">]> {
+  let summary = "element insertion operation";
+  let description = [{
+    The `tensor.insert` op writes a scalar into a tensor `dest` as specified by
+    the operation's indices.
+
+    It returns a copy of `dest` with the element at the given indices replaced
+    by the value of `scalar`.
+
+    The arity of indices must match the rank of the tensor `dest` (i.e., if a
+    tensor is of rank 3, then 3 indices are required for the insert). The
+    indices should all be of `index` type.
+
+    Example:
+
+    ```mlir
+    %4 = tensor.insert %s into %dest[%1, %2] : tensor<4x4xi32>
+    %5 = tensor.insert %s into %dest[%1, %2] : tensor<?x?xi32>
+    %6 = tensor.insert %s into %dest[%1, %2] : tensor<*xi32>
+    ```
+  }];
+
+  let arguments = (ins AnyType:$scalar,
+                       AnyTensor:$dest,
+                       Variadic<Index>:$indices);
+  let results = (outs AnyTensor:$result);
+  let assemblyFormat = [{
+    $scalar `into` $dest `[` $indices `]` attr-dict `:` type($dest)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$scalar, "Value":$dest,
+      CArg<"ValueRange", "{}">:$indices), [{
+      auto resType = dest.getType();
+      build($_builder, $_state, resType, scalar, dest, indices);
+    }]>];
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -286,6 +286,32 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// InsertOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(InsertOp op) {
+  // The number of indices can only be checked against a ranked destination;
+  // for any other destination type the check is skipped.
+  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
+    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
+      return op.emitOpError("incorrect number of indices");
+  return success();
+}
+
+/// Folds an insert whose `scalar` and `dest` operands are both constant when
+/// `dest` is a splat of that same scalar: the insertion cannot change the
+/// tensor, so the result is `dest` itself. Returns a null result otherwise.
+OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
+  Attribute scalar = operands[0];
+  Attribute dest = operands[1];
+  if (scalar && dest)
+    if (auto splatDest = dest.dyn_cast<SplatElementsAttr>())
+      if (scalar == splatDest.getSplatValue())
+        return dest;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // GenerateOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -96,6 +96,19 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_insert
+func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
+  // Fold an insert into a splat.
+  // CHECK-DAG: %[[C4:.+]] = constant dense<4.{{0*}}e+00> : tensor<4xf32>
+  %0 = constant dense<4.0> : tensor<4xf32>
+  %1 = constant 4.0 : f32
+  %ins_1 = tensor.insert %1 into %0[%arg0] : tensor<4xf32>
+  // CHECK-NEXT: return %[[C4]]
+  return %ins_1 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @extract_from_tensor.cast
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>
 func @extract_from_tensor.cast(%tensor: tensor<*xf32>) -> f32 {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -16,6 +16,14 @@
 
 // -----
 
+func @insert_too_few_indices(%arg0: f32, %arg1: tensor<?xf32>) {
+  // expected-error@+1 {{incorrect number of indices}}
+  %0 = tensor.insert %arg0 into %arg1[] : tensor<?xf32>
+  return
+}
+
+// -----
+
 func @tensor.from_elements_wrong_result_type() {
   // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}}
   %c0 = constant 0 : i32
diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir
--- a/mlir/test/Dialect/Tensor/ops.mlir
+++ b/mlir/test/Dialect/Tensor/ops.mlir
@@ -22,6 +22,19 @@
   return
 }
 
+// CHECK-LABEL: func @insert(
+// CHECK-SAME:                  %[[SCALAR:.*]]: f32
+// CHECK-SAME:                  %[[INDEX:.*]]: index
+// CHECK-SAME:                  %[[DEST1:.*]]: tensor<?x?x?xf32>
+// CHECK-SAME:                  %[[DEST2:.*]]: tensor<*xf32>
+func @insert(%arg0: f32, %arg1: index, %arg2: tensor<?x?x?xf32>, %arg3: tensor<*xf32>) {
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST1]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<?x?x?xf32>
+  %0 = tensor.insert %arg0 into %arg2[%arg1, %arg1, %arg1] : tensor<?x?x?xf32>
+  // CHECK: tensor.insert %[[SCALAR]] into %[[DEST2]][%[[INDEX]], %[[INDEX]], %[[INDEX]]] : tensor<*xf32>
+  %1 = tensor.insert %arg0 into %arg3[%arg1, %arg1, %arg1] : tensor<*xf32>
+  return
+}
+
 // CHECK-LABEL: func @tensor.from_elements() {
 func @tensor.from_elements() {
   %c0 = "std.constant"() {value = 0: index} : () -> index