diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -240,7 +240,7 @@
 def SparseTensor_PushBackOp : SparseTensor_Op<"push_back", []>,
     Arguments<(ins StridedMemRefRankOf<[Index], [1]>:$bufferSizes,
                StridedMemRefRankOf<[AnyType], [1]>:$inBuffer,
-               AnyType:$value, IndexAttr:$idx)>,
+               AnyType:$value, IndexAttr:$idx, UnitAttr:$inbounds)>,
     Results<(outs StridedMemRefRankOf<[AnyType], [1]>:$outBuffer)>  {
   string summary = "Pushes a value to the back of a given buffer";
   string description = [{
@@ -250,6 +250,11 @@
     current buffer is full, then `inBuffer.realloc` is called before pushing
     the data to the buffer. This is similar to std::vector push_back.
 
+    The `inbounds` attribute tells the compiler that the insertion won't go
+    beyond the current storage buffer, so no code for the capacity check and
+    reallocation needs to be generated. The typical usage is for "dynamic"
+    sparse tensors for which a capacity can be set beforehand.
+
     The operation returns an SSA value for the memref. Referencing the memref
     through the old SSA value after this operation is undefined behavior.
 
@@ -259,9 +264,14 @@
     %r = sparse_tensor.push_back %bufferSizes, %buffer, %val
       {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
     ```
+
+    ```mlir
+    %r = sparse_tensor.push_back inbounds %bufferSizes, %buffer, %val
+      {idx = 0 : index} : memref<?xindex>, memref<?xf64>, f64 -> memref<?xf64>
+    ```
   }];
-  let assemblyFormat = "$bufferSizes `,` $inBuffer `,` $value"
-                       " attr-dict `:` type($bufferSizes) `,`"
+  let assemblyFormat = "(`inbounds` $inbounds^)? $bufferSizes `,` $inBuffer"
+                       " `,` $value attr-dict `:` type($bufferSizes) `,`"
                        " type($inBuffer) `,` type($value) `to`"
                        " type($outBuffer)";
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -350,6 +350,8 @@
     //    buffer = new_buffer
     // store(buffer, value)
     // size(buffer)++
+    //
+    // The capacity check is skipped when the attribute inbounds is present.
     Location loc = op->getLoc();
     Value c0 = constantIndex(rewriter, loc, 0);
     Value buffer = op.getInBuffer();
@@ -357,28 +359,34 @@
     Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
     Value bufferSizes = op.getBufferSizes();
     Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
-    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
-                                                size, capacity);
     Value value = op.getValue();
-    auto bufferType =
-        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
-    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
-                                                /*else=*/true);
-    // True branch.
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    Value c2 = constantIndex(rewriter, loc, 2);
-    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
-    Value newBuffer =
-        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
-    rewriter.create<scf::YieldOp>(loc, newBuffer);
-
-    // False branch.
-    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    if (!op.getInbounds()) {
+      Value cond = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::uge, size, capacity);
+
+      auto bufferType =
+          MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                  /*else=*/true);
+      // True branch.
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      Value c2 = constantIndex(rewriter, loc, 2);
+      capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+      Value newBuffer =
+          rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+      rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+      // False branch.
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      rewriter.create<scf::YieldOp>(loc, buffer);
+
+      // Prepare for adding the value to the end of the buffer.
+      rewriter.setInsertionPointAfter(ifOp);
+      buffer = ifOp.getResult(0);
+    }
 
     // Add the value to the end of the buffer.
-    rewriter.setInsertionPointAfter(ifOp);
-    buffer = ifOp.getResult(0);
     rewriter.create<memref::StoreOp>(loc, value, buffer, size);
 
     // Increment the size of the buffer by 1.
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -26,6 +26,22 @@
   return %0 : memref<?xf64>
 }
 
+// CHECK-LABEL: func @sparse_push_back_inbound(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: memref.store %[[C]], %[[B]]{{\[}}%[[P]]]
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[B]] : memref<?xf64>
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
 // CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME: %[[I:arg0]]: index,
 // CHECK-SAME: %[[J:.*]]: index,
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -145,6 +145,19 @@
 
 // -----
 
+// CHECK-LABEL: func @sparse_push_back_inbound(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK: %[[D:.*]] = sparse_tensor.push_back inbounds %[[A]], %[[B]], %[[C]] {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+// CHECK: return %[[D]]
+func.func @sparse_push_back_inbound(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back inbounds %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
 // CHECK-LABEL: func @sparse_expansion(
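For context, here is a rough sketch (not taken from the patch or its tests; the value names `%sizes`, `%buf`, `%val`, `%out` and the element types are illustrative) of the IR that the rewriting in SparseBufferRewriting.cpp produces for a `push_back` without `inbounds`, following the pseudo-code comment above:

```mlir
// Hypothetical lowering of
//   %r = sparse_tensor.push_back %sizes, %buf, %val {idx = 0 : index}
//       : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
// Current capacity and size of the buffer.
%cap = memref.dim %buf, %c0 : memref<?xf64>
%size = memref.load %sizes[%c0] : memref<?xindex>
// Grow the buffer (doubling its capacity) only when it is full.
%full = arith.cmpi uge, %size, %cap : index
%out = scf.if %full -> (memref<?xf64>) {
  %newCap = arith.muli %cap, %c2 : index
  %new = memref.realloc %buf(%newCap) : memref<?xf64> to memref<?xf64>
  scf.yield %new : memref<?xf64>
} else {
  scf.yield %buf : memref<?xf64>
}
// Append the value and bump the stored size.
memref.store %val, %out[%size] : memref<?xf64>
%newSize = arith.addi %size, %c1 : index
memref.store %newSize, %sizes[%c0] : memref<?xindex>
```

With `inbounds`, the `memref.dim`, `arith.cmpi`, and `scf.if` are omitted and only the final load/store/addi sequence remains, which is what the new `sparse_push_back_inbound` test in buffer_rewriting.mlir checks.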