diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -143,6 +143,37 @@
     // ones when packing into a COO format.
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Unknown;
+  }
+};
+
+struct UnpackOpInterface
+    : public BufferizableOpInterface::ExternalModel<UnpackOpInterface,
+                                                    sparse_tensor::UnpackOp> {
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    // Similar to InsertOp, reallocation is not considered to allocate a new
+    // piece of memory.
+    return false;
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  AliasingOpResultList getAliasingOpResults(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    // Conceptually, UnpackOp is equivalent to a list of toIndices/toValues ops.
+    return {};
+  }
 };
 
 struct InsertOpInterface
@@ -285,6 +316,8 @@
   sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
   sparse_tensor::NumberOfEntriesOp::attachInterface<
       NumberOfEntriesOpInterface>(*ctx);
+  sparse_tensor::PackOp::attachInterface<PackOpInterface>(*ctx);
+  sparse_tensor::UnpackOp::attachInterface<UnpackOpInterface>(*ctx);
   sparse_tensor::ToIndicesBufferOp::attachInterface<
       ToIndicesBufferOpInterface>(*ctx);
   sparse_tensor::ToIndicesOp::attachInterface<ToIndicesOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1154,6 +1154,53 @@
   }
 };
 
+struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+    Location loc = op.getLoc();
+    int64_t rank = op.getTensor().getType().getRank();
+
+    assert(isUniqueCOOType(op.getTensor().getType()) &&
+           desc.getFields().size() == 4);
+
+    Value flatBuf = rank == 1 ? desc.getIdxMemRefOrView(rewriter, loc, 0)
+                              : desc.getAOSMemRef();
+    Value dataBuf = desc.getValMemRef();
+
+    // If the frontend requests a static buffer, reallocate the data/indices
+    // buffers to match the requested shape.
+    TensorType dataTp = op.getData().getType();
+    if (dataTp.hasStaticShape()) {
+      dataBuf = rewriter.create<memref::ReallocOp>(
+          loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
+          dataBuf);
+    }
+
+    TensorType indicesTp = op.getIndices().getType();
+    if (indicesTp.hasStaticShape()) {
+      auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
+      flatBuf = rewriter.create<memref::ReallocOp>(
+          loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+    }
+
+    Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
+        loc, MemRefType::get(indicesTp.getShape(), indicesTp.getElementType()),
+        flatBuf, ArrayRef<ReassociationIndices>{ReassociationIndices{0, 1}});
+
+    // Converts the MemRefs back to Tensors.
+    Value data = rewriter.create<bufferization::ToTensorOp>(loc, dataBuf);
+    Value indices = rewriter.create<bufferization::ToTensorOp>(loc, idxBuf);
+    Value nnz = toType(rewriter, loc, desc.getValMemSize(rewriter, loc),
+                       op.getNnz().getType());
+
+    rewriter.replaceOp(op, {data, indices, nnz});
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1165,15 +1212,16 @@
 void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseReturnConverter,
-               SparseCallConverter, SparseDimOpConverter, SparseCastConverter,
-               SparseExtractSliceCoverter, SparseTensorDeallocConverter,
-               SparseTensorLoadConverter, SparseExpandConverter,
-               SparseCompressConverter, SparseInsertConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToIndicesBufferConverter, SparseToValuesConverter,
-               SparseConvertConverter, SparseNewOpConverter, SparseNumberOfEntriesConverter>(
-      typeConverter, patterns.getContext());
+  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
+               SparseReturnConverter, SparseCallConverter, SparseDimOpConverter,
+               SparseCastConverter, SparseExtractSliceCoverter,
+               SparseTensorDeallocConverter, SparseTensorLoadConverter,
+               SparseExpandConverter, SparseCompressConverter, SparseInsertConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToIndicesBufferConverter, SparseToValuesConverter,
+               SparseConvertConverter, SparseNewOpConverter,
+               SparseNumberOfEntriesConverter>(typeConverter,
+                                               patterns.getContext());
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -37,3 +37,23 @@
      to tensor<100x100xf64, #COO>
   return %0 : tensor<100x100xf64, #COO>
 }
+
+// CHECK-LABEL: func.func @sparse_unpack(
+// CHECK-SAME:  %[[VAL_0:.*]]: memref<?xindex>,
+// CHECK-SAME:  %[[VAL_1:.*]]: memref<?xi32>,
+// CHECK-SAME:  %[[VAL_2:.*]]: memref<?xf64>,
+// CHECK-SAME:  %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
+// CHECK:         %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK:         %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK:         %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK:         %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
+// CHECK:         %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
+// CHECK:         %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
+// CHECK:         %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
+// CHECK:         return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:       }
+func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
+  %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
+                                         to tensor<6xf64>, tensor<6x2xi32>, index
+  return %d, %i, %nnz : tensor<6xf64>, tensor<6x2xi32>, index
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -36,6 +36,9 @@
   // Main driver.
   //
   func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f64
+    %i0 = arith.constant 0 : i32
     //
     // Initialize a 3-dim dense tensor.
     //
@@ -95,6 +98,20 @@
       vector.print %v: f64
     }
 
+    %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
+                                         to tensor<3xf64>, tensor<3x2xi32>, i32
+
+    // CHECK-NEXT: ( 1, 2, 3 )
+    %vd = vector.transfer_read %d[%c0], %f0 : tensor<3xf64>, vector<3xf64>
+    vector.print %vd : vector<3xf64>
+
+    // CHECK-NEXT: ( ( 1, 2 ), ( 5, 6 ), ( 7, 8 ) )
+    %vi = vector.transfer_read %i[%c0, %c0], %i0 : tensor<3x2xi32>, vector<3x2xi32>
+    vector.print %vi : vector<3x2xi32>
+
+    // CHECK-NEXT: 3
+    vector.print %n : i32
+
     return
   }
 }
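
For reference, the new lowering composes with the existing sparse_tensor.pack codegen into a round trip from raw COO buffers to a sparse tensor value and back. The snippet below is a minimal sketch, not part of the patch: the #COO encoding spelling (dimLevelType with "compressed-nu"/"singleton") is an assumption based on the conventions of the era, and the op syntax mirrors the test cases above.

  #COO = #sparse_tensor.encoding<{
    dimLevelType = [ "compressed-nu", "singleton" ]
  }>

  func.func @roundtrip(%data: tensor<6xf64>, %indices: tensor<6x2xi32>)
      -> (tensor<6xf64>, tensor<6x2xi32>, index) {
    // Assemble the COO data/indices buffers into a sparse tensor value.
    %sp = sparse_tensor.pack %data, %indices : tensor<6xf64>, tensor<6x2xi32>
       to tensor<100x100xf64, #COO>
    // Disassemble it again; the static result shapes exercise the
    // memref.realloc path added by SparseUnpackOpConverter.
    %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
       to tensor<6xf64>, tensor<6x2xi32>, index
    return %d, %i, %nnz : tensor<6xf64>, tensor<6x2xi32>, index
  }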