diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -302,6 +302,16 @@
   assert(fields.size() == lastField);
 }
 
+/// Creates a straightforward counting for-loop.
+static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count) {
+  Type indexType = builder.getIndexType();
+  Value zero = constantZero(builder, loc, indexType);
+  Value one = constantOne(builder, loc, indexType);
+  scf::ForOp forOp = builder.create<scf::ForOp>(loc, zero, count, one);
+  builder.setInsertionPointToStart(forOp.getBody());
+  return forOp;
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -518,12 +528,12 @@
       auto memTp = MemRefType::get({ShapedType::kDynamicSize}, t);
       return rewriter.create<memref::AllocOp>(loc, memTp, ValueRange{*sz});
     };
-    // Allocate temporary buffers for values, filled-switch, and indices.
+    // Allocate temporary buffers for values/filled-switch and added.
     // We do not use stack buffers for this, since the expanded size may
     // be rather large (as it envelops a single expanded dense dimension).
     Value values = genAlloc(eltType);
     Value filled = genAlloc(boolType);
-    Value indices = genAlloc(idxType);
+    Value added = genAlloc(idxType);
     Value zero = constantZero(rewriter, loc, idxType);
     // Reset the values/filled-switch to all-zero/false. Note that this
     // introduces an O(N) operation into the computation, but this reset
@@ -543,6 +553,66 @@
   }
 };
 
+/// Sparse codegen rule for the compress operator.
+class SparseCompressConverter : public OpConversionPattern<CompressOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(CompressOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
+    Type eltType = srcType.getElementType();
+    Value values = adaptor.getValues();
+    Value filled = adaptor.getFilled();
+    Value added = adaptor.getAdded();
+    Value count = adaptor.getCount();
+
+    //
+    // TODO: need to implement "std::sort(added, added + count);" for ordered
+    //
+
+    // While performing the insertions, we also need to reset the elements
+    // of the values/filled-switch by only iterating over the set elements,
+    // to ensure that the runtime complexity remains proportional to the
+    // sparsity of the expanded access pattern.
+    //
+    // Generate
+    //    for (i = 0; i < count; i++) {
+    //      index = added[i];
+    //      value = values[index];
+    //
+    //      TODO: insert prev_indices, index, value
+    //
+    //      values[index] = 0;
+    //      filled[index] = false;
+    //    }
+    Value i = createFor(rewriter, loc, count).getInductionVar();
+    Value index = rewriter.create<memref::LoadOp>(loc, added, i);
+    rewriter.create<memref::LoadOp>(loc, values, index);
+    // TODO: insert
+    rewriter.create<memref::StoreOp>(loc, constantZero(rewriter, loc, eltType),
+                                     values, index);
+    rewriter.create<memref::StoreOp>(loc, constantI1(rewriter, loc, false),
+                                     filled, index);
+
+    // Deallocate the buffers on exit of the full loop nest.
+    Operation *parent = op;
+    for (; isa<scf::ForOp>(parent->getParentOp()) ||
+           isa<scf::WhileOp>(parent->getParentOp()) ||
+           isa<scf::ParallelOp>(parent->getParentOp()) ||
+           isa<scf::IfOp>(parent->getParentOp());
+         parent = parent->getParentOp())
+      ;
+    rewriter.setInsertionPointAfter(parent);
+    rewriter.create<memref::DeallocOp>(loc, values);
+    rewriter.create<memref::DeallocOp>(loc, filled);
+    rewriter.create<memref::DeallocOp>(loc, added);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 /// Base class for getter-like operations, e.g., to_indices, to_pointers.
 template <typename SourceOp, typename Base>
 class SparseGetterOpConverter : public OpConversionPattern<SourceOp> {
@@ -626,7 +696,7 @@
   patterns.add<SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
                SparseCastConverter, SparseTensorAllocConverter,
                SparseTensorDeallocConverter,
-               SparseExpandConverter, SparseToPointersConverter,
-               SparseToIndicesConverter, SparseToValuesConverter>(
-      typeConverter, patterns.getContext());
+               SparseExpandConverter, SparseCompressConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -347,3 +347,40 @@
     : tensor<8x8xf64, #CSR> to memref<?xf64>, memref<?xi1>, memref<?xindex>, index
   return %added : memref<?xindex>
 }
+
+// CHECK-LABEL: func @sparse_compression(
+// CHECK-SAME: %[[A0:.*0]]: memref<2xindex>,
+// CHECK-SAME: %[[A1:.*1]]: memref<3xindex>,
+// CHECK-SAME: %[[A2:.*2]]: memref<?xindex>,
+// CHECK-SAME: %[[A3:.*3]]: memref<?xindex>,
+// CHECK-SAME: %[[A4:.*4]]: memref<?xf64>,
+// CHECK-SAME: %[[A5:.*5]]: memref<?xindex>,
+// CHECK-SAME: %[[A6:.*6]]: memref<?xf64>,
+// CHECK-SAME: %[[A7:.*7]]: memref<?xi1>,
+// CHECK-SAME: %[[A8:.*8]]: memref<?xindex>,
+// CHECK-SAME: %[[A9:.*9]]: index)
+// CHECK-DAG: %[[B0:.*]] = arith.constant false
+// CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// TODO: sort
+// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[A9]] step %[[C1]] {
+// CHECK-NEXT: %[[INDEX:.*]] = memref.load %[[A8]][%[[I]]] : memref<?xindex>
+// TODO: insert
+// CHECK-DAG: memref.store %[[F0]], %[[A6]][%[[INDEX]]] : memref<?xf64>
+// CHECK-DAG: memref.store %[[B0]], %[[A7]][%[[INDEX]]] : memref<?xi1>
+// CHECK-NEXT: }
+// CHECK-DAG: memref.dealloc %[[A6]] : memref<?xf64>
+// CHECK-DAG: memref.dealloc %[[A7]] : memref<?xi1>
+// CHECK-DAG: memref.dealloc %[[A8]] : memref<?xindex>
+// CHECK: return
+func.func @sparse_compression(%arg0: tensor<8x8xf64, #CSR>,
+                              %arg1: memref<?xindex>,
+                              %arg2: memref<?xf64>,
+                              %arg3: memref<?xi1>,
+                              %arg4: memref<?xindex>,
+                              %arg5: index) {
+  sparse_tensor.compress %arg0, %arg1, %arg2, %arg3, %arg4, %arg5
+    : tensor<8x8xf64, #CSR>, memref<?xindex>, memref<?xf64>,
+      memref<?xi1>, memref<?xindex>, index
+  return
+}