diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -389,7 +389,12 @@
   return false;
 }
 
-/// Generates buffer for the output tensor.
+/// Generates buffer for the output tensor. Note that all sparse kernels
+/// assume that when all elements are written to (viz. x(i) = y(i)*z(i)),
+/// the output buffer is already initialized to all zeroes and only nonzeroes
+/// are computed and written out. For updates (viz. x(i) += y(i)*z(i)),
+/// only nonzeroes are updated without any a priori assumption on the original
+/// contents of the output buffer.
 static Value genOutputBuffer(CodeGen &codegen, PatternRewriter &rewriter,
                              linalg::GenericOp op, MemRefType denseTp,
                              ArrayRef<Value> args) {
@@ -404,7 +409,16 @@
   // By default, a new buffer is allocated which is initialized to the
   // tensor defined in the outs() clause. This is always correct but
   // introduces a dense initialization component that may negatively
-  // impact the running complexity of the sparse kernel.
+  // impact the running complexity of the sparse kernel. If the tensor
+  // materializes within this method, we need to preserve the zero
+  // initialization assumption of all sparse output buffers.
+  if (tensor.getDefiningOp<linalg::InitTensorOp>()) {
+    Type tp = denseTp.getElementType();
+    Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
+    Value zero = rewriter.create<ConstantOp>(loc, tp, rewriter.getZeroAttr(tp));
+    rewriter.create<linalg::FillOp>(loc, zero, alloc);
+    return alloc;
+  }
   Value init = rewriter.create<memref::BufferCastOp>(loc, denseTp, tensor);
   Value alloc = rewriter.create<memref::AllocOp>(loc, denseTp, args);
   rewriter.create<linalg::CopyOp>(loc, init, alloc);
diff --git a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_1d.mlir
@@ -43,6 +43,36 @@
   return %0 : tensor<32xf32>
 }
 
+// CHECK-LABEL:   func @add_d_init(
+// CHECK-SAME:                     %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
+// CHECK-SAME:                     %[[VAL_1:.*]]: f32) -> tensor<32xf32> {
+// CHECK:           %[[VAL_2:.*]] = constant 32 : index
+// CHECK:           %[[VAL_3:.*]] = constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_4:.*]] = constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
+// CHECK:           %[[VAL_7:.*]] = memref.alloc() : memref<32xf32>
+// CHECK:           linalg.fill(%[[VAL_3]], %[[VAL_7]]) : f32, memref<32xf32>
+// CHECK:           scf.for %[[VAL_8:.*]] = %[[VAL_4]] to %[[VAL_2]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_9:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_8]]] : memref<?xf32>
+// CHECK:             %[[VAL_10:.*]] = addf %[[VAL_9]], %[[VAL_1]] : f32
+// CHECK:             memref.store %[[VAL_10]], %[[VAL_7]]{{\[}}%[[VAL_8]]] : memref<32xf32>
+// CHECK:           }
+// CHECK:           %[[VAL_11:.*]] = memref.tensor_load %[[VAL_7]] : memref<32xf32>
+// CHECK:           return %[[VAL_11]] : tensor<32xf32>
+// CHECK:         }
+func @add_d_init(%arga: tensor<32xf32, #DV>, %argb: f32) -> tensor<32xf32> {
+  %u = linalg.init_tensor [32] : tensor<32xf32>
+  %0 = linalg.generic #trait1
+     ins(%arga: tensor<32xf32, #DV>)
+    outs(%u: tensor<32xf32>) {
+      ^bb(%a: f32, %x: f32):
+        %0 = addf %a, %argb : f32
+        linalg.yield %0 : f32
+  } -> tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
 // CHECK-LABEL:   func @mul_d(
 // CHECK-SAME:                %[[VAL_0:.*]]: tensor<32xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME:                %[[VAL_1:.*]]: f32,