diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -198,8 +198,10 @@
 ///
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes, bool enableInit,
-                              SmallVectorImpl<Value> &fields) {
+                              SmallVectorImpl<Value> &fields, Value sizeHint) {
   RankedTensorType rtp = type.cast<RankedTensorType>();
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+
   // Build original sizes.
   SmallVector<Value> sizes;
   auto shape = rtp.getShape();
@@ -211,19 +213,34 @@
       sizes.push_back(constantIndex(builder, loc, shape[r]));
   }
 
-  Value heuristic = constantIndex(builder, loc, 16);
-  Value valHeuristic = heuristic;
-  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+  // Set up some heuristic sizes. We try to set the initial
+  // size based on available information. Otherwise we just
+  // initialize a few elements to start the reallocation chain.
+  // TODO: refine this
+  Value ptrHeuristic, idxHeuristic, valHeuristic;
   if (enc.isAllDense()) {
     Value linear = sizes[0];
     for (unsigned r = 1; r < rank; r++) {
       linear = builder.create<arith::MulIOp>(loc, linear, sizes[r]);
     }
     valHeuristic = linear;
+  } else if (sizeHint) {
+    if (getCOOStart(enc) == 0) {
+      ptrHeuristic = constantIndex(builder, loc, 2);
+      idxHeuristic = builder.create<arith::MulIOp>(
+          loc, constantIndex(builder, loc, rank), sizeHint); // AOS
+    } else {
+      ptrHeuristic = idxHeuristic = constantIndex(builder, loc, 16);
+    }
+    valHeuristic = sizeHint;
+  } else {
+    ptrHeuristic = idxHeuristic = valHeuristic =
+        constantIndex(builder, loc, 16);
   }
+
   foreachFieldAndTypeInSparseTensor(
       rtp,
-      [&builder, &fields, rtp, loc, heuristic, valHeuristic,
+      [&builder, &fields, rtp, loc, ptrHeuristic, idxHeuristic, valHeuristic,
        enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
                    unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
         assert(fields.size() == fIdx);
@@ -235,11 +252,12 @@
         case SparseTensorFieldKind::PtrMemRef:
         case SparseTensorFieldKind::IdxMemRef:
         case SparseTensorFieldKind::ValMemRef:
-          field = createAllocation(builder, loc, fType.cast<MemRefType>(),
-                                   fKind == SparseTensorFieldKind::ValMemRef
-                                       ? valHeuristic
-                                       : heuristic,
-                                   enableInit);
+          field = createAllocation(
+              builder, loc, fType.cast<MemRefType>(),
+              (fKind == SparseTensorFieldKind::PtrMemRef)   ? ptrHeuristic
+              : (fKind == SparseTensorFieldKind::IdxMemRef) ? idxHeuristic
+                                                            : valHeuristic,
+              enableInit);
           break;
         }
         assert(field);
@@ -691,9 +709,10 @@
     // Construct allocation for each field.
     Location loc = op.getLoc();
+    Value sizeHint = op.getSizeHint();
     SmallVector<Value> fields;
     createAllocFields(rewriter, loc, resType, adaptor.getOperands(),
-                      enableBufferInitialization, fields);
+                      enableBufferInitialization, fields, sizeHint);
 
     // Replace operation with resulting memrefs.
     rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
     return success();
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -363,6 +363,19 @@
   return %1 : tensor<10x20x30xf64, #Dense3D>
 }
 
+// CHECK-LABEL: func.func @sparse_alloc_coo_with_size_hint(
+// CHECK-SAME: %[[HINT:.*]]: index)
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[M2:.*]] = arith.muli %[[HINT]], %[[C2]] : index
+// CHECK: %[[A1:.*]] = memref.alloc() : memref<2xindex>
+// CHECK: %[[A2:.*]] = memref.alloc(%[[M2]]) : memref<?xindex>
+// CHECK: %[[A3:.*]] = memref.alloc(%[[HINT]]) : memref<?xf64>
+func.func @sparse_alloc_coo_with_size_hint(%arg0: index) -> tensor<10x20xf64, #Coo> {
+  %0 = bufferization.alloc_tensor() size_hint=%arg0 : tensor<10x20xf64, #Coo>
+  %1 = sparse_tensor.load %0 : tensor<10x20xf64, #Coo>
+  return %1 : tensor<10x20xf64, #Coo>
+}
+
 // CHECK-LABEL: func.func @sparse_expansion1()
 // CHECK: %[[A:.*]] = memref.alloc() : memref<8xf64>
 // CHECK: %[[B:.*]] = memref.alloc() : memref<8xi1>
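
// Not part of the patch: a minimal sketch of the non-COO fallback path,
// assuming a hypothetical #CSR encoding and function name. Since
// getCOOStart(enc) != 0 for CSR, the pointer and index buffers keep the
// constant-16 starting size and only the value buffer is sized by the hint.
#CSR = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>

func.func @sparse_alloc_csr_with_size_hint(%hint: index) -> tensor<10x20xf64, #CSR> {
  // Expected lowering: pointer/index buffers allocated with a constant
  // size of 16; values with memref.alloc(%hint) : memref<?xf64>.
  %0 = bufferization.alloc_tensor() size_hint=%hint : tensor<10x20xf64, #CSR>
  %1 = sparse_tensor.load %0 : tensor<10x20xf64, #CSR>
  return %1 : tensor<10x20xf64, #CSR>
}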