diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -205,11 +205,30 @@
 static void createAllocFields(OpBuilder &builder, Location loc, Type type,
                               ValueRange dynSizes, bool enableInit,
                               SmallVectorImpl<Value> &fields) {
   RankedTensorType rtp = type.cast<RankedTensorType>();
-  Value heuristic = constantIndex(builder, loc, 16);
+  // Build original sizes.
+  SmallVector<Value> sizes;
+  auto shape = rtp.getShape();
+  unsigned rank = shape.size();
+  for (unsigned r = 0, o = 0; r < rank; r++) {
+    if (ShapedType::isDynamic(shape[r]))
+      sizes.push_back(dynSizes[o++]);
+    else
+      sizes.push_back(constantIndex(builder, loc, shape[r]));
+  }
+  Value heuristic = constantIndex(builder, loc, 16);
+  Value valHeuristic = heuristic;
+  SparseTensorEncodingAttr enc = getSparseTensorEncoding(rtp);
+  if (enc.isAllDense()) {
+    Value linear = sizes[0];
+    for (unsigned r = 1; r < rank; r++) {
+      linear = builder.create<arith::MulIOp>(loc, linear, sizes[r]);
+    }
+    valHeuristic = linear;
+  }
   foreachFieldAndTypeInSparseTensor(
       rtp,
-      [&builder, &fields, rtp, loc, heuristic,
+      [&builder, &fields, rtp, loc, heuristic, valHeuristic,
       enableInit](Type fType, unsigned fIdx, SparseTensorFieldKind fKind,
                   unsigned /*dim*/, DimLevelType /*dlt*/) -> bool {
        assert(fields.size() == fIdx);
@@ -222,7 +241,10 @@
       case SparseTensorFieldKind::IdxMemRef:
       case SparseTensorFieldKind::ValMemRef:
         field = createAllocation(builder, loc, fType.cast<MemRefType>(),
-                                 heuristic, enableInit);
+                                 fKind == SparseTensorFieldKind::ValMemRef
+                                     ? valHeuristic
+                                     : heuristic,
+                                 enableInit);
         break;
       }
       assert(field);
@@ -233,16 +255,6 @@
 
   MutSparseTensorDescriptor desc(rtp, fields);
 
-  // Build original sizes.
-  SmallVector<Value> sizes;
-  auto shape = rtp.getShape();
-  unsigned rank = shape.size();
-  for (unsigned r = 0, o = 0; r < rank; r++) {
-    if (ShapedType::isDynamic(shape[r]))
-      sizes.push_back(dynSizes[o++]);
-    else
-      sizes.push_back(constantIndex(builder, loc, shape[r]));
-  }
   // Initialize the storage scheme to an empty tensor. Initialized memSizes
   // to all zeros, sets the dimSizes to known values and gives all pointer
   // fields an initial zero entry, so that it is easier to maintain the
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -345,8 +345,8 @@
 // CHECK: %[[A2:.*]] = arith.constant 10 : i64
 // CHECK: %[[A3:.*]] = arith.constant 30 : i64
 // CHECK: %[[A4:.*]] = arith.constant 0.000000e+00 : f64
-// CHECK: %[[A5:.*]] = memref.alloc() : memref<16xf64>
-// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<16xf64> to memref<?xf64>
+// CHECK: %[[A5:.*]] = memref.alloc() : memref<6000xf64>
+// CHECK: %[[A6:.*]] = memref.cast %[[A5]] : memref<6000xf64> to memref<?xf64>
 // CHECK: %[[A7:.*]] = sparse_tensor.storage_specifier.init : !sparse_tensor.storage_specifier<#{{.*}}>
 // CHECK: %[[A8:.*]] = sparse_tensor.storage_specifier.set %[[A7]] dim_sz at 0 with %[[A3]] : i64, !sparse_tensor.storage_specifier<#{{.*}}>
 // CHECK: %[[A9:.*]] = sparse_tensor.storage_specifier.set %[[A8]] dim_sz at 1 with %[[A2]] : i64, !sparse_tensor.storage_specifier<#{{.*}}>