diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp
@@ -337,6 +337,60 @@
 namespace {

+/// Sparse rewriting rule for the push_back operator.
+struct PushBackRewriter : OpRewritePattern<PushBackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PushBackOp op,
+                                PatternRewriter &rewriter) const override {
+    // Rewrite push_back(buffer, value) to:
+    // if (size(buffer) >= capacity(buffer))
+    //    new_capacity = capacity(buffer)*2
+    //    new_buffer = realloc(buffer, new_capacity)
+    //    buffer = new_buffer
+    // store(buffer, value)
+    // size(buffer)++
+    Location loc = op->getLoc();
+    Value c0 = constantIndex(rewriter, loc, 0);
+    Value buffer = op.getInBuffer();
+    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
+    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
+    Value bufferSizes = op.getBufferSizes();
+    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
+    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
+                                                size, capacity);
+    Value value = op.getValue();
+    auto bufferType =
+        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
+    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
+                                                /*else=*/true);
+    // True branch.
+    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+    Value c2 = constantIndex(rewriter, loc, 2);
+    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
+    Value newBuffer =
+        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
+    rewriter.create<scf::YieldOp>(loc, newBuffer);
+
+    // False branch.
+    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+    rewriter.create<scf::YieldOp>(loc, buffer);
+
+    // Add the value to the end of the buffer.
+    rewriter.setInsertionPointAfter(ifOp);
+    buffer = ifOp.getResult(0);
+    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
+
+    // Increment the size of the buffer by 1.
+    Value c1 = constantIndex(rewriter, loc, 1);
+    size = rewriter.create<arith::AddIOp>(loc, size, c1);
+    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
+
+    rewriter.replaceOp(op, buffer);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for the sort operator.
 struct SortRewriter : public OpRewritePattern<SortOp> {
 public:
@@ -378,5 +432,5 @@
 //===---------------------------------------------------------------------===//

 void mlir::populateSparseBufferRewriting(RewritePatternSet &patterns) {
-  patterns.add<SortRewriter>(patterns.getContext());
+  patterns.add<PushBackRewriter, SortRewriter>(patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -564,61 +564,6 @@
   }
 };

-/// Sparse codegen rule for the push_back operator.
-class SparsePushBackConverter : public OpConversionPattern<PushBackOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(PushBackOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // Lower push_back(buffer, value) to:
-    // if (size(buffer) >= capacity(buffer))
-    //    new_capacity = capacity(buffer)*2
-    //    new_buffer = realloc(buffer, new_capacity)
-    //    buffer = new_buffer
-    // store(buffer, value)
-    // size(buffer)++
-    Location loc = op->getLoc();
-    Value c0 = constantIndex(rewriter, loc, 0);
-    Value buffer = adaptor.getInBuffer();
-    Value capacity = rewriter.create<memref::DimOp>(loc, buffer, c0);
-    Value idx = constantIndex(rewriter, loc, op.getIdx().getZExtValue());
-    Value bufferSizes = adaptor.getBufferSizes();
-    Value size = rewriter.create<memref::LoadOp>(loc, bufferSizes, idx);
-    Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
-                                                size, capacity);
-    Value value = adaptor.getValue();
-    auto bufferType =
-        MemRefType::get({ShapedType::kDynamicSize}, value.getType());
-    scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, bufferType, cond,
-                                                /*else=*/true);
-    // True branch.
-    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
-    Value c2 = constantIndex(rewriter, loc, 2);
-    capacity = rewriter.create<arith::MulIOp>(loc, capacity, c2);
-    Value newBuffer =
-        rewriter.create<memref::ReallocOp>(loc, bufferType, buffer, capacity);
-    rewriter.create<scf::YieldOp>(loc, newBuffer);
-
-    // False branch.
-    rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-    rewriter.create<scf::YieldOp>(loc, buffer);
-
-    // Add the value to the end of the buffer.
-    rewriter.setInsertionPointAfter(ifOp);
-    buffer = ifOp.getResult(0);
-    rewriter.create<memref::StoreOp>(loc, value, buffer, size);
-
-    // Increment the size of the buffer by 1.
-    Value c1 = constantIndex(rewriter, loc, 1);
-    size = rewriter.create<arith::AddIOp>(loc, size, c1);
-    rewriter.create<memref::StoreOp>(loc, size, bufferSizes, idx);
-
-    rewriter.replaceOp(op, buffer);
-    return success();
-  }
-};
-
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SubClass, typename SourceOp>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -703,7 +648,6 @@
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter, SparseTensorLoadConverter,
                SparseExpandConverter, SparseCompressConverter,
-               SparsePushBackConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -160,6 +160,7 @@
     // Most ops in the sparse dialect must go!
     target.addIllegalDialect<SparseTensorDialect>();
     target.addLegalOp<SortOp>();
+    target.addLegalOp<PushBackOp>();
     // All dynamic rules below accept new function, call, return, and various
     // tensor and bufferization operations as legal output of the rewriting
     // provided that all sparse tensor types have been fully rewritten.
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -1,5 +1,31 @@
 // RUN: mlir-opt %s --sparse-buffer-rewrite --canonicalize --cse | FileCheck %s

+// CHECK-LABEL: func @sparse_push_back(
+// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
+// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
+// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
+// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
+// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
+// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
+// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
+// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
+// CHECK: scf.yield %[[M2]] : memref<?xf64>
+// CHECK: } else {
+// CHECK: scf.yield %[[B]] : memref<?xf64>
+// CHECK: }
+// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
+// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
+// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
+// CHECK: return %[[M]] : memref<?xf64>
+func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
+  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
 // CHECK-LABEL: func.func private @_sparse_less_than_1_i8(
 // CHECK-SAME: %[[I:arg0]]: index,
 // CHECK-SAME: %[[J:.*]]: index,
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -425,29 +425,3 @@
     : memref<?xindex>, memref<?xf64>, memref<?xindex>, tensor<8x8xf64, #UCSR>
   return
 }
-
-// CHECK-LABEL: func @sparse_push_back(
-// CHECK-SAME: %[[A:.*]]: memref<?xindex>,
-// CHECK-SAME: %[[B:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[C:.*]]: f64) -> memref<?xf64> {
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[S:.*]] = memref.dim %[[B]], %[[C0]]
-// CHECK: %[[P:.*]] = memref.load %[[A]]{{\[}}%[[C2]]]
-// CHECK: %[[T:.*]] = arith.cmpi uge, %[[P]], %[[S]]
-// CHECK: %[[M:.*]] = scf.if %[[T]] -> (memref<?xf64>) {
-// CHECK: %[[P1:.*]] = arith.muli %[[S]], %[[C2]]
-// CHECK: %[[M2:.*]] = memref.realloc %[[B]](%[[P1]])
-// CHECK: scf.yield %[[M2]] : memref<?xf64>
-// CHECK: } else {
-// CHECK: scf.yield %[[B]] : memref<?xf64>
-// CHECK: }
-// CHECK: memref.store %[[C]], %[[M]]{{\[}}%[[P]]]
-// CHECK: %[[P2:.*]] = arith.addi %[[P]], %[[C1]]
-// CHECK: memref.store %[[P2]], %[[A]]{{\[}}%[[C2]]]
-// CHECK: return %[[M]] : memref<?xf64>
-func.func @sparse_push_back(%arg0: memref<?xindex>, %arg1: memref<?xf64>, %arg2: f64) -> memref<?xf64> {
-  %0 = sparse_tensor.push_back %arg0, %arg1, %arg2 {idx = 2 : index} : memref<?xindex>, memref<?xf64>, f64 to memref<?xf64>
-  return %0 : memref<?xf64>
-}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir
rename from mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_push_back.mlir
rename to mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_rewrite_push_back.mlir