diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -270,7 +270,9 @@
                  AllTypesMatch<["inBuffer", "outBuffer"]>]>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                    StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-                   AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
+                   AnyType:$value, IndexAttr:$idx, Optional<Index>:$n,
+
+                   UnitAttr:$inbounds)>,
     Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
@@ -280,6 +282,11 @@
     current buffer is full, then `inBuffer.realloc` is called before pushing the
     data to the buffer. This is similar to std::vector push_back.
 
+    The optional input `n` specifies the number of times to repeatedly push
+    the value to the back of the tensor. When `n` is a compile-time constant,
+    its value can't be less than 1. If `n` is a runtime value that is less
+    than 1, the behavior is undefined.
+
     The `inbounds` attribute tells the compiler that the insertion won't go
     beyond the current storage buffer. This allows the compiler to not generate
     the code for capacity check and reallocation. The typical usage will be for
@@ -300,10 +307,24 @@
     %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
       {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
     ```
+
+    ```mlir
+    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val, %n
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
+    ```
   }];
   let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
-                       " `,` $value attr-dict `:` type($bufferSizes) `,`"
-                       " type($inBuffer) `,` type($value)";
+                       " `,` $value (`,` $n^ )? attr-dict `:`"
+                       " type($bufferSizes) `,` type($inBuffer) `,`"
+                       " type($value) (`,` type($n)^ )?";
+
+  let builders = [
+    // Build an op without input `n`.
+    OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer,
+                   "Value":$value, "APInt":$idx)>
+  ];
+
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -538,6 +538,22 @@
   return success();
 }
 
+void PushBackOp::build(OpBuilder &builder, OperationState &result,
+                       Type outBuffer, Value bufferSizes, Value inBuffer,
+                       Value value, APInt idx) {
+  // Delegate to the generated builder, leaving the optional `n` operand unset.
+  build(builder, result, outBuffer, bufferSizes, inBuffer, value, idx, Value());
+}
+
+LogicalResult PushBackOp::verify() {
+  Value n = getN();
+  if (n) {
+    // Only a compile-time constant `n` can be checked; a runtime `n` < 1 is UB.
+    auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp());
+    if (nValue && nValue.value() < 1)
+      return emitOpError("n must be not less than 1");
+  }
+  return success();
+}
+
 LogicalResult CompressOp::verify() {
   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
   if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -132,6 +132,15 @@
 
 // -----
 
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf32>, %arg2: f32) -> memref<?xf32> {
+  %c0 = arith.constant 0: index
+  // expected-error@+1 {{'sparse_tensor.push_back' op n must be not less than 1}}
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %c0 {idx = 2 : index} : memref<?xindex>, memref<?xf32>, f32, index
+  return %0 : memref<?xf32>
+}
+
+// -----
+
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
   // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -171,6 +171,20 @@
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_n(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64,
+//  CHECK-SAME: %[[D:.*]]: index) -> memref<?xf64> {
+//       CHECK: %[[E:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]], %[[D]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+//       CHECK: return %[[E]]
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64, %arg3: index) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2, %arg3 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64, index
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
 // CHECK-LABEL: func @sparse_expansion(