diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -613,6 +613,61 @@
   }
 };
 
+/// Sparse codegen rule for the push_back operator.
+class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Lower push_back(buffer, value) to:
+    // if (size(buffer) >= capacity(buffer))
+    //    new_capacity = capacity(buffer)*2
+    //    new_buffer = realloc(buffer, new_capacity)
+    //    buffer = new_buffer
+    // store(buffer, value)
+    // size(buffer)++
+    Location loc = op->getLoc();
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value buffer = adaptor.getInBuffer();
+    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+    Value bufferSizes = adaptor.getBufferSizes();
+    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+                                                size, capacity);
+    Value value = adaptor.getValue();
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                /*else=*/true);
+    // True branch.
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value c2 = constantIndex(rewriter, loc, 2);
+    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+    Value newBuffer =
+        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+    rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+    // False branch.
+    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    // Add the value to the end of the buffer.
+    rewriter.setInsertionPointAfter(ifOp);
+    buffer = ifOp.getResult(0);
+    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+    // Increment the size of the buffer by 1.
+    Value c1 = constantIndex(rewriter, loc, 1);
+    size = rewriter.create<arith::AddIOp>(loc, size, c1);
+    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SourceOp, typename Base>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -697,6 +752,7 @@
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+               SparsePushBackConverter, SparseToPointersConverter,
+               SparseToIndicesConverter, SparseToValuesConverter>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -384,3 +384,29 @@
     : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return
 }
+
+// CHECK-LABEL: func @sparse_push_back(
+// CHECK-SAME:  %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME:  %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME:  %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:       %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK:       %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK:       %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+// CHECK:       %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+// CHECK:         %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+// CHECK:         %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+// CHECK:         scf.yield %[[M2]] : memref<?xf64>
+// CHECK:       } else {
+// CHECK:         scf.yield %[[B]] : memref<?xf64>
+// CHECK:       }
+//
+// CHECK:       memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+// CHECK:       %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK:       memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK:       return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s --sparse-compiler=enable-runtime-library=false | \
+// RUN: mlir-cpu-runner \
+// RUN:  -e entry -entry-point-result=void \
+// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+module {
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = arith.constant 0.0 : f32
+    %d1 = arith.constant 1.0 : f32
+    %d2 = arith.constant 2.0 : f32
+
+    %bufferSizes = memref.alloc(%c1) : memref<?xindex>
+    %buffer = memref.alloc(%c1) : memref<?xf32>
+
+    memref.store %c0, %bufferSizes[%c0] : memref<?xindex>
+    %buffer2 = sparse_tensor.push_back %bufferSizes, %buffer, %d2 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+    %buffer3 = sparse_tensor.push_back %bufferSizes, %buffer2, %d1 {idx=0 : index} : memref<?xindex>, memref<?xf32>, f32 to memref<?xf32>
+
+    // CHECK: ( 2 )
+    %sizeValue = vector.transfer_read %bufferSizes[%c0], %c0: memref<?xindex>, vector<1xindex>
+    vector.print %sizeValue : vector<1xindex>
+
+    // CHECK: ( 2, 1 )
+    %bufferValue = vector.transfer_read %buffer3[%c0], %d0: memref<?xf32>, vector<2xf32>
+    vector.print %bufferValue : vector<2xf32>
+
+    // Release the buffers.
+    memref.dealloc %bufferSizes : memref<?xindex>
+    memref.dealloc %buffer3 : memref<?xf32>
+    return
+  }
+}