diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -575,6 +575,34 @@
   }
 }
 
+/// Returns a memref that fits the requested length (reallocates if the
+/// requested length is larger, or creates a subview if it is smaller).
+static Value reallocOrSubView(OpBuilder &builder, Location loc, int64_t len,
+                              Value buffer) {
+  MemRefType memTp = getMemRefType(buffer);
+  auto retTp = MemRefType::get(ArrayRef<int64_t>{len}, memTp.getElementType());
+
+  Value targetLen = constantIndex(builder, loc, len);
+  Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0);
+  Value reallocP = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
+                                                 targetLen, bufferLen);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(loc, retTp, reallocP, true);
+  // If targetLen > bufferLen, reallocate to get enough space to return.
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  Value reallocBuf = builder.create<memref::ReallocOp>(loc, retTp, buffer);
+  builder.create<scf::YieldOp>(loc, reallocBuf);
+  // Else, return a subview to fit the size.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  Value subViewBuf = builder.create<memref::SubViewOp>(
+      loc, retTp, buffer, /*offset=*/ArrayRef<int64_t>{0},
+      /*size=*/ArrayRef<int64_t>{len},
+      /*stride=*/ArrayRef<int64_t>{1});
+  builder.create<scf::YieldOp>(loc, subViewBuf);
+  // Reset the insertion point.
+  builder.setInsertionPointAfter(ifOp);
+  return ifOp.getResult(0);
+}
+
 //===----------------------------------------------------------------------===//
 // Codegen rules.
 //===----------------------------------------------------------------------===//
@@ -1174,16 +1202,13 @@
   // to ensure that we meet their need.
   TensorType dataTp = op.getData().getType();
   if (dataTp.hasStaticShape()) {
-    dataBuf = rewriter.create<memref::ReallocOp>(
-        loc, MemRefType::get(dataTp.getShape(), dataTp.getElementType()),
-        dataBuf);
+    dataBuf = reallocOrSubView(rewriter, loc, dataTp.getShape()[0], dataBuf);
   }
 
   TensorType indicesTp = op.getIndices().getType();
   if (indicesTp.hasStaticShape()) {
     auto len = indicesTp.getShape()[0] * indicesTp.getShape()[1];
-    flatBuf = rewriter.create<memref::ReallocOp>(
-        loc, MemRefType::get({len}, indicesTp.getElementType()), flatBuf);
+    flatBuf = reallocOrSubView(rewriter, loc, len, flatBuf);
   }
 
   Value idxBuf = rewriter.create<memref::ExpandShapeOp>(
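Note on the lowering above: whether the incoming runtime buffer is already large enough is only known at execution time, so reallocOrSubView materializes the decision as an scf.if rather than resolving it statically. As a rough scalar analogue of the IR it builds (a sketch only; Buffer and fitToLength are illustrative names, not part of the patch):

    #include <cstdint>
    #include <cstdlib>

    // Illustrative stand-in for a dynamically sized memref<?xf64>.
    struct Buffer {
      double *data;
      int64_t len;
    };

    // Mirrors the generated scf.if: reallocate when the requested length
    // exceeds the runtime length, otherwise take a zero-copy prefix view
    // (the memref.subview [0] [targetLen] [1] in the else branch).
    static Buffer fitToLength(Buffer buf, int64_t targetLen) {
      if (targetLen > buf.len) {
        // "then" region: memref.realloc to the exact requested size.
        double *grown = static_cast<double *>(
            std::realloc(buf.data, targetLen * sizeof(double)));
        return Buffer{grown, targetLen};
      }
      // "else" region: reuse the same storage, exposing targetLen elements.
      return Buffer{buf.data, targetLen};
    }

Either way the caller gets a buffer of exactly the requested static length, which is what lets the result type be the static memref the unpack result types demand.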
diff --git a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_pack.mlir
@@ -43,14 +43,33 @@
 // CHECK-SAME:      %[[VAL_1:.*]]: memref<?xi32>,
 // CHECK-SAME:      %[[VAL_2:.*]]: memref<?xf64>,
 // CHECK-SAME:      %[[VAL_3:.*]]: !sparse_tensor.storage_specifier
-// CHECK:           %[[VAL_4:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
-// CHECK:           %[[VAL_5:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
-// CHECK:           %[[VAL_6:.*]] = memref.expand_shape %[[VAL_5]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
-// CHECK:           %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<6xf64>
-// CHECK:           %[[VAL_8:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<6x2xi32>
-// CHECK:           %[[VAL_9:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
-// CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : i32 to index
-// CHECK:           return %[[VAL_7]], %[[VAL_8]], %[[VAL_10]] : tensor<6xf64>, tensor<6x2xi32>, index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 6 : index
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = memref.dim %[[VAL_2]], %[[VAL_5]] : memref<?xf64>
+// CHECK:           %[[VAL_7:.*]] = arith.cmpi ugt, %[[VAL_4]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_8:.*]] = scf.if %[[VAL_7]] -> (memref<6xf64>) {
+// CHECK:             %[[VAL_9:.*]] = memref.realloc %[[VAL_2]] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_9]] : memref<6xf64>
+// CHECK:           } else {
+// CHECK:             %[[VAL_10:.*]] = memref.subview %[[VAL_2]][0] [6] [1] : memref<?xf64> to memref<6xf64>
+// CHECK:             scf.yield %[[VAL_10]] : memref<6xf64>
+// CHECK:           }
+// CHECK:           %[[VAL_11:.*]] = arith.constant 12 : index
+// CHECK:           %[[VAL_12:.*]] = memref.dim %[[VAL_1]], %[[VAL_5]] : memref<?xi32>
+// CHECK:           %[[VAL_13:.*]] = arith.cmpi ugt, %[[VAL_11]], %[[VAL_12]] : index
+// CHECK:           %[[VAL_14:.*]] = scf.if %[[VAL_13]] -> (memref<12xi32>) {
+// CHECK:             %[[VAL_15:.*]] = memref.realloc %[[VAL_1]] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_15]] : memref<12xi32>
+// CHECK:           } else {
+// CHECK:             %[[VAL_16:.*]] = memref.subview %[[VAL_1]][0] [12] [1] : memref<?xi32> to memref<12xi32>
+// CHECK:             scf.yield %[[VAL_16]] : memref<12xi32>
+// CHECK:           }
+// CHECK:           %[[VAL_17:.*]] = memref.expand_shape %[[VAL_18:.*]] {{\[\[}}0, 1]] : memref<12xi32> into memref<6x2xi32>
+// CHECK:           %[[VAL_19:.*]] = bufferization.to_tensor %[[VAL_20:.*]] : memref<6xf64>
+// CHECK:           %[[VAL_21:.*]] = bufferization.to_tensor %[[VAL_17]] : memref<6x2xi32>
+// CHECK:           %[[VAL_22:.*]] = sparse_tensor.storage_specifier.get %[[VAL_3]] val_mem_sz
+// CHECK:           %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
+// CHECK:           return %[[VAL_19]], %[[VAL_21]], %[[VAL_23]] : tensor<6xf64>, tensor<6x2xi32>, index
 // CHECK:         }
 func.func @sparse_unpack(%sp: tensor<100x100xf64, #COO>) -> (tensor<6xf64>, tensor<6x2xi32>, index) {
   %d, %i, %nnz = sparse_tensor.unpack %sp : tensor<100x100xf64, #COO>
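For reference, the static sizes the updated checks expect follow directly from the unpack result types: tensor<6xf64> needs a values buffer of 6 elements, and tensor<6x2xi32> is produced by expand_shape from a flat indices buffer of 6 * 2 = 12 elements, the `len` computed in the codegen hunk above. A tiny sketch of that arithmetic (hypothetical, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      // Static shapes taken from the unpack results in the test.
      const int64_t dataShape[1] = {6};       // tensor<6xf64>
      const int64_t indicesShape[2] = {6, 2}; // tensor<6x2xi32>
      // Values buffer length: dataTp.getShape()[0] -> memref<6xf64>.
      assert(dataShape[0] == 6);
      // Flat indices length: shape[0] * shape[1] -> memref<12xi32>,
      // later expand_shape'd back into memref<6x2xi32>.
      assert(indicesShape[0] * indicesShape[1] == 12);
      return 0;
    }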