diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -123,7 +123,7 @@
   let hasVerifier = 1;
 }
 
-def SparseTensor_UnpackOp : SparseTensor_Op<"unpack", [Pure]>,
+def SparseTensor_UnpackOp : SparseTensor_Op<"unpack">,
     Arguments<(ins AnySparseTensor:$tensor)>,
     Results<(outs 1DTensorOf<[AnyType]>:$values,
                   2DTensorOf<[AnySignlessIntegerOrIndex]>:$coordinates,
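Removing [Pure] is the semantic heart of the patch: after this change, lowering an unpack emits a memref.dealloc for the position buffer (see the codegen change below), so the op has a memory side effect and must not be CSE'd or dead-code-eliminated even when its results are unused. A minimal sketch of the op form this affects, reusing the #SortedCOO encoding and shapes from the integration test at the end of the patch (the %coo value name is illustrative):

  // Were unpack still Pure, an unpack with unused results could be dropped,
  // and the dealloc created during lowering would silently vanish with it.
  %d, %i, %n = sparse_tensor.unpack %coo : tensor<10x10xf64, #SortedCOO>
                                        to tensor<3xf64>, tensor<3x2xindex>, index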
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -827,7 +827,7 @@
   }
 
 private:
-  bool createDeallocs;
+  const bool createDeallocs;
 };
 
 /// Sparse codegen rule for tensor rematerialization.
@@ -1343,29 +1343,23 @@
       break;
     case SparseTensorFieldKind::PosMemRef: {
       // TACO-style COO starts with a PosBuffer
-      // By creating a constant value for it, we avoid the complexity of
-      // memory management.
      const auto posTp = stt.getPosType();
      if (isCompressedDLT(dlt)) {
-        RankedTensorType tensorType;
-        SmallVector<Attribute> posAttr;
-        tensorType = RankedTensorType::get({batchedCount + 1}, posTp);
-        posAttr.push_back(IntegerAttr::get(posTp, 0));
-        for (unsigned i = 0; i < batchedCount; i++) {
+        auto memrefType = MemRefType::get({batchedCount + 1}, posTp);
+        field = rewriter.create<memref::AllocOp>(loc, memrefType);
+        Value c0 = constantIndex(rewriter, loc, 0);
+        genStore(rewriter, loc, c0, field, c0);
+        for (unsigned i = 1; i <= batchedCount; i++) {
          // The position memref will have values as
          // [0, nse, 2 * nse, ..., batchedCount * nse]
-          posAttr.push_back(IntegerAttr::get(posTp, nse * (i + 1)));
+          Value idx = constantIndex(rewriter, loc, i);
+          Value val = constantIndex(rewriter, loc, nse * i);
+          genStore(rewriter, loc, val, field, idx);
        }
-        MemRefType memrefType = MemRefType::get(
-            tensorType.getShape(), tensorType.getElementType());
-        auto cstPtr = rewriter.create<arith::ConstantOp>(
-            loc, tensorType, DenseElementsAttr::get(tensorType, posAttr));
-        field = rewriter.create<bufferization::ToMemrefOp>(
-            loc, memrefType, cstPtr);
      } else {
        assert(isCompressedWithHiDLT(dlt) && !batchDimSzs.empty());
        MemRefType posMemTp = MemRefType::get({batchedCount * 2}, posTp);
-        field = rewriter.create<memref::AllocaOp>(loc, posMemTp);
+        field = rewriter.create<memref::AllocOp>(loc, posMemTp);
        populateCompressedWithHiPosArray(rewriter, loc, batchDimSzs, field,
                                         nse, op);
      }
@@ -1430,6 +1424,11 @@
 
 struct SparseUnpackOpConverter : public OpConversionPattern<UnpackOp> {
   using OpConversionPattern::OpConversionPattern;
+  SparseUnpackOpConverter(TypeConverter &typeConverter, MLIRContext *context,
+                          bool createDeallocs)
+      : OpConversionPattern(typeConverter, context),
+        createDeallocs(createDeallocs) {}
+
   LogicalResult
   matchAndRewrite(UnpackOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -1443,6 +1442,13 @@
     Value flatBuf = lvlRank == 1 ? desc.getCrdMemRefOrView(rewriter, loc, 0)
                                  : desc.getAOSMemRef();
     Value valuesBuf = desc.getValMemRef();
+    Value posBuf = desc.getPosMemRef(0);
+    if (createDeallocs) {
+      // Unpack ends the lifetime of the sparse tensor. While the value array
+      // and coordinate array are unpacked and returned, the position array
+      // becomes useless and needs to be freed (if the user requests it).
+      rewriter.create<memref::DeallocOp>(loc, posBuf);
+    }
 
     // If frontend requests a static buffer, we reallocate the
     // values/coordinates to ensure that we meet their need.
@@ -1474,6 +1480,9 @@
     rewriter.replaceOp(op, {values, coordinates, nse});
     return success();
   }
+
+private:
+  const bool createDeallocs;
 };
 
 struct SparseNewOpConverter : public OpConversionPattern<NewOp> {
@@ -1627,11 +1636,11 @@ void mlir::populateSparseTensorCodegenPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     bool createSparseDeallocs, bool enableBufferInitialization) {
-  patterns.add<SparsePackOpConverter, SparseUnpackOpConverter,
-               /* ...the other codegen pattern converters... */
-               SparseSliceGetterOpConverter<ToSliceStrideOp,
-                                            StorageSpecifierKind::DimStride>>(
-      typeConverter, patterns.getContext());
-  patterns.add<SparseTensorDeallocConverter>(
+  patterns.add<SparsePackOpConverter,
+               /* ...the other codegen pattern converters... */
+               SparseSliceGetterOpConverter<ToSliceStrideOp,
+                                            StorageSpecifierKind::DimStride>>(
+      typeConverter, patterns.getContext());
+  patterns.add<SparseTensorDeallocConverter, SparseUnpackOpConverter>(
       typeConverter, patterns.getContext(), createSparseDeallocs);
   patterns.add<SparseTensorAllocConverter>(typeConverter, patterns.getContext(),
                                            enableBufferInitialization);
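For concreteness, the new PosMemRef materialization produces roughly the following IR for a non-batched COO tensor (batchedCount == 1) with nse == 6; this mirrors the updated CHECK lines in the codegen test below. Because the buffer now comes from memref.alloc instead of a bufferized constant, the memref.dealloc emitted by unpack is well defined:

  // Position buffer holding [0, nse], built with explicit stores rather than
  // a constant tensor plus bufferization.to_memref.
  %pos = memref.alloc() : memref<2xindex>
  %c0 = arith.constant 0 : index
  memref.store %c0, %pos[%c0] : memref<2xindex>
  %c1 = arith.constant 1 : index
  %c6 = arith.constant 6 : index
  memref.store %c6, %pos[%c1] : memref<2xindex>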
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -7,26 +7,29 @@
 
 // CHECK-LABEL: func.func @sparse_pack(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf64>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>) -> (memref<?xindex>, memref<?xi32>, memref<?xf64>,
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<[0, 6]> : tensor<2xindex>
-// CHECK: %[[VAL_3:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2xindex>
-// CHECK: %[[VAL_4:.*]] = memref.cast %[[VAL_3]] : memref<2xindex> to memref<?xindex>
-// CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
-// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
-// CHECK: %[[VAL_7:.*]] = memref.cast %[[VAL_6]] : memref<12xi32> to memref<?xi32>
-// CHECK: %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
-// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<6xf64> to memref<?xf64>
-// CHECK: %[[VAL_10:.*]] = sparse_tensor.storage_specifier.init :
-// CHECK: %[[VAL_11:.*]] = arith.constant 6 : index
-// CHECK: %[[VAL_12:.*]] = arith.constant 100 : index
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_10]] lvl_sz at 0 with %[[VAL_12]]
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<6x2xi32>)
+// CHECK-DAG: %[[VAL_2:.*]] = memref.alloc() : memref<2xindex>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG: memref.store %[[VAL_3]], %[[VAL_2]]{{\[}}%[[VAL_3]]] : memref<2xindex>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 6 : index
+// CHECK-DAG: memref.store %[[VAL_5]], %[[VAL_2]]{{\[}}%[[VAL_4]]] : memref<2xindex>
+// CHECK: %[[VAL_6:.*]] = memref.cast %[[VAL_2]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[VAL_7:.*]] = bufferization.to_memref %[[VAL_1]] : memref<6x2xi32>
+// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_7]] {{\[\[}}0, 1]] : memref<6x2xi32> into memref<12xi32>
+// CHECK: %[[VAL_9:.*]] = memref.cast %[[VAL_8]] : memref<12xi32> to memref<?xi32>
+// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_0]] : memref<6xf64>
+// CHECK: %[[VAL_11:.*]] = memref.cast %[[VAL_10]] : memref<6xf64> to memref<?xf64>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.storage_specifier.init
+// CHECK: %[[VAL_13:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.storage_specifier.set %[[VAL_12]] lvl_sz at 0 with %[[VAL_13]]
 // CHECK: %[[VAL_15:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_15]]
-// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] crd_mem_sz at 0 with %[[VAL_11]]
-// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] lvl_sz at 1 with %[[VAL_12]]
-// CHECK: %[[VAL_21:.*]] = sparse_tensor.storage_specifier.set %[[VAL_20]] crd_mem_sz at 1 with %[[VAL_11]]
-// CHECK: %[[VAL_22:.*]] = sparse_tensor.storage_specifier.set %[[VAL_21]] val_mem_sz with %[[VAL_11]]
-// CHECK: return %[[VAL_4]], %[[VAL_7]], %[[VAL_9]], %[[VAL_22]] : memref<?xindex>, memref<?xi32>, memref<?xf64>,
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.storage_specifier.set %[[VAL_14]] pos_mem_sz at 0 with %[[VAL_15]]
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.storage_specifier.set %[[VAL_16]] crd_mem_sz at 0 with %[[VAL_5]]
+// CHECK: %[[VAL_18:.*]] = sparse_tensor.storage_specifier.set %[[VAL_17]] lvl_sz at 1 with %[[VAL_13]]
+// CHECK: %[[VAL_19:.*]] = sparse_tensor.storage_specifier.set %[[VAL_18]] crd_mem_sz at 1 with %[[VAL_5]]
+// CHECK: %[[VAL_20:.*]] = sparse_tensor.storage_specifier.set %[[VAL_19]] val_mem_sz with %[[VAL_5]]
+// CHECK: return %[[VAL_6]], %[[VAL_9]], %[[VAL_11]], %[[VAL_20]]
 // CHECK: }
 func.func @sparse_pack(%values: tensor<6xf64>, %coordinates: tensor<6x2xi32>)
     -> tensor<100x100xf64, #COO> {
@@ -39,9 +42,10 @@
 // CHECK-SAME: %[[VAL_0:.*]]: memref<?xindex>,
 // CHECK-SAME: %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME: %[[VAL_2:.*]]: memref<?xf64>,
-// CHECK-SAME: %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK: %[[VAL_4:.*]] = arith.constant 6 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-SAME: %[[VAL_3:.*]]
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK-DAG: memref.dealloc %[[VAL_0]] : memref<?xindex>
 // CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
 // CHECK: %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
 // CHECK: %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir
@@ -98,7 +98,6 @@
       vector.print %v: f64
     }
 
-
     %d, %i, %n = sparse_tensor.unpack %s5 : tensor<10x10xf64, #SortedCOOI32>
                                          to tensor<3xf64>, tensor<3x2xi32>, i32
 
@@ -115,6 +114,8 @@
     // CHECK-NEXT: 3
     vector.print %n : i32
 
+    %d1, %i1, %n1 = sparse_tensor.unpack %s4 : tensor<10x10xf64, #SortedCOO>
+                                            to tensor<3xf64>, tensor<3x2xindex>, index
     return
   }
 }
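Since unpack is now also how a packed tensor's remaining storage gets released, the integration test unpacks %s4 as well, so the run frees every position buffer it packed. A sketch of the intended pairing, assuming the pack syntax shown in the codegen test above (value names and shapes are illustrative):

  // pack heap-allocates the sparse storage; unpack hands the values and
  // coordinates buffers back to the caller and frees the position buffer.
  %s = sparse_tensor.pack %values, %coordinates
     : tensor<3xf64>, tensor<3x2xindex> to tensor<10x10xf64, #SortedCOO>
  %d, %i, %n = sparse_tensor.unpack %s : tensor<10x10xf64, #SortedCOO>
                                      to tensor<3xf64>, tensor<3x2xindex>, index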