diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -270,7 +270,10 @@
      AllTypesMatch<["inBuffer", "outBuffer"]>]>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-               AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
+               AnyType:$value, IndexAttr:$idx,
+               OptionalAttr<ConfinedAttr<IndexAttr,
+                                         [IntMinValue<2>]>>:$n,
+               UnitAttr:$inbounds)>,
     Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)> {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
@@ -280,6 +283,9 @@
     current buffer is full, then `inBuffer.realloc` is called before pushing
     the data to the buffer. This is similar to std::vector push_back.
 
+    The optional attribute `n` specifies the number of times to repeatedly push
+    the value to the back of the buffer.
+
     The `inbounds` attribute tells the compiler that the insertion won't go
     beyond the current storage buffer. This allows the compiler to not generate
     the code for capacity check and reallocation. The typical usage will be for
@@ -300,10 +306,20 @@
     %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
        {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64
     ```
+
+    ```mlir
+    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
+       {idx = 0 : index, n = 1000 : index} : memref<?xindex>, memref<?xf64>, f64
+    ```
   }];
   let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
                        " `,` $value attr-dict `:` type($bufferSizes) `,`"
                        " type($inBuffer) `,` type($value)";
+
+  let builders = [
+    // Build an op without `n`.
+    OpBuilder<(ins "Type":$outBuffer, "Value":$bufferSizes, "Value":$inBuffer, "Value":$value, "APInt":$idx)>,
+  ];
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -538,6 +538,13 @@
   return success();
 }
 
+void PushBackOp::build(OpBuilder &builder, OperationState &result,
+                       Type outBuffer, Value bufferSizes, Value inBuffer,
+                       Value value, APInt idx) {
+  build(builder, result, outBuffer, bufferSizes, inBuffer, value, idx,
+        /*n=*/IntegerAttr());
+}
+
 LogicalResult CompressOp::verify() {
   RankedTensorType ttp = getTensor().getType().cast<RankedTensorType>();
   if (ttp.getRank() != 1 + static_cast<int64_t>(getIndices().size()))
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -132,6 +132,14 @@
 
 // -----
 
+func.func @sparse_push_back_n_invalid(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  // expected-error@+1 {{'sparse_tensor.push_back' op attribute 'n' failed to satisfy constraint: index attribute whose minimum value is 2}}
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 0 : index, n = 1 : index}: memref<?xindex>, memref<?xf64>, f64
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 func.func @sparse_unannotated_expansion(%arg0: tensor<128xf64>) {
   // expected-error@+1 {{'sparse_tensor.expand' op operand #0 must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
   %values, %filled, %added, %count = sparse_tensor.expand %arg0
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -171,6 +171,19 @@
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_n(
+//  CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+//  CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+//  CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+//       CHECK: %[[D:.*]] = sparse_tensor.push_back %[[A]], %[[B]], %[[C]] {idx = 2 : index, n = 100 : index} : memref<?xindex>, memref<?xf64>, f64
+//       CHECK: return %[[D]]
+func.func @sparse_push_back_n(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index, n = 100 : index}: memref<?xindex>, memref<?xf64>, f64
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
 // CHECK-LABEL: func @sparse_expansion(
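For anyone trying the two builder paths from C++, a minimal usage sketch follows. It is illustrative only: `growBuffer` is a hypothetical helper, `rewriter` and `loc` are assumed to come from an enclosing rewrite pattern, and the parameter list of the ODS-generated builder used for the `n` variant (attribute-valued `idx`, `n`, and `inbounds`) is an assumption, not something this patch pins down.

```c++
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APInt.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper: pushes `val` once, then `repeat` more times via the
// optional `n` attribute (which must be >= 2 to pass verification).
static Value growBuffer(PatternRewriter &rewriter, Location loc,
                        Value bufferSizes, Value buffer, Value val,
                        int64_t repeat) {
  // Single push: the convenience builder added in this patch leaves the
  // optional `n` attribute unset.
  Value once = rewriter.create<PushBackOp>(
      loc, buffer.getType(), bufferSizes, buffer, val, /*idx=*/APInt(64, 0));
  // Repeated push: assumed ODS-generated builder taking attribute values for
  // `idx`, `n`, and `inbounds`; a null UnitAttr means "not inbounds".
  return rewriter.create<PushBackOp>(
      loc, buffer.getType(), bufferSizes, once, val,
      /*idx=*/rewriter.getIndexAttr(0),
      /*n=*/rewriter.getIndexAttr(repeat),
      /*inbounds=*/UnitAttr());
}
```

Feeding the first op's result buffer into the second call also exercises the `AllTypesMatch` constraint between `inBuffer` and `outBuffer`, while `bufferSizes` is shared and updated in place by both ops.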