diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -434,6 +434,8 @@
     return hasOutput && tid == tensors.size() - 1;
   }

+  bool isSparseOutput(size_t tid) { return isOutputTensor(tid) && isSparseOut; }
+
   /// Setups [lo, hi] for iterating tensor[dim], it assumes that tensor[0
   /// ...dims-1] has already been setup.
   void prepareLoopOverTensorAtDim(OpBuilder &builder, Location loc, size_t tid,
@@ -462,6 +464,7 @@
   // Whether the loop emitter needs to treat the last tensor as the output
   // tensor.
   bool hasOutput;
+  bool isSparseOut;
   /// Input and (optional) output tensors.
   std::vector<Value> tensors;
   /// The dim type array for each tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -97,10 +97,11 @@
 SparseTensorLoopEmitter::SparseTensorLoopEmitter(ValueRange tensors,
                                                  bool hasOutput,
                                                  bool isSparseOut)
-    : hasOutput(hasOutput), tensors(tensors.begin(), tensors.end()),
-      dimTypes(tensors.size()), pidxs(tensors.size()), coord(tensors.size()),
-      highs(tensors.size()), ptrBuffer(tensors.size()),
-      idxBuffer(tensors.size()), valBuffer(tensors.size()), loopStack() {
+    : hasOutput(hasOutput), isSparseOut(isSparseOut),
+      tensors(tensors.begin(), tensors.end()), dimTypes(tensors.size()),
+      pidxs(tensors.size()), coord(tensors.size()), highs(tensors.size()),
+      ptrBuffer(tensors.size()), idxBuffer(tensors.size()),
+      valBuffer(tensors.size()), loopStack() {
   for (size_t tid = 0, e = tensors.size(); tid < e; tid++) {
     auto t = tensors[tid];
     // a scalar or 0-dimension tensors
@@ -246,7 +247,7 @@
     coord[tid][dim] = iv;
     // generate pidx for dense dim (pidx = i * sz + j)
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc)
+    if (enc && !isSparseOutput(tid))
       pidxs[tid][dim] = genAddress(builder, loc, tid, dim, iv);
   }

@@ -353,7 +354,7 @@
       pidxs[tid][dim] = min;
       // generate pidx for dense dim (pidx = i * sz + j)
       auto enc = getSparseTensorEncoding(tensors[tid].getType());
-      if (enc)
+      if (enc && !isSparseOutput(tid))
        pidxs[tid][dim] = genAddress(builder, loc, tid, dim, min);
     }
     // NOTE: we can also prepares for next dim here in advance
@@ -419,7 +420,7 @@
   for (auto [tid, dim] : llvm::zip(tids, dims)) {
     assert(isDenseDLT(dimTypes[tid][dim]));
     auto enc = getSparseTensorEncoding(tensors[tid].getType());
-    if (enc) {
+    if (enc && !isSparseOutput(tid)) {
       bool validPidx = dim == 0 || pidxs[tid][dim - 1];
       if (!validPidx) {
         // We might not find the pidx for the sparse output tensor as it is
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1130,13 +1130,13 @@
     assert(all.test(b));
     assert(merger.index(b) == idx);
     if (isUndefDLT(merger.getDimLevelType(b))) {
-      // This could be a synthetic tensor (for invariants and sparse output
-      // tensor).
-      // In both cases, we mean to generate loops over output tensor.
-      // e.g.,
-      // out[i][j] = invariant;
-      if (merger.getSynTensorID() == tid)
-        tid = merger.getOutTensorID();
+      // An undefined dlt in the lattices means we intend to iterate based
+      // on the dim of the output tensor. E.g., this could be a synthetic
+      // tensor (for invariants and sparse output tensor):
+      //   out[i][j] = invariant;
+      // or a broadcast:
+      //   out[i][j] = in[i] (j is undef for input)
+      tid = merger.getOutTensorID();
     }
     auto dim = codegen.loopIdxToDim[tid][idx];
     if (dim != INVALID_ID) {
diff --git a/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_broadcast.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt %s --sparsification --canonicalize --cse | FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+#SparseTensor = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel"]
+}
+
+// CHECK-LABEL: @main(
+// CHECK-SAME: %[[TMP_arg0:.*]]: tensor<4x5xi32,
+// CHECK-DAG: %[[TMP_c3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[TMP_c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[TMP_c1:.*]] = arith.constant 1 : index
+// CHECK: %[[TMP_0:.*]] = bufferization.alloc_tensor()
+// CHECK: %[[TMP_1:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_2:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 0 : index}
+// CHECK: %[[TMP_3:.*]] = sparse_tensor.pointers %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_4:.*]] = sparse_tensor.indices %[[TMP_arg0]] {dimension = 1 : index}
+// CHECK: %[[TMP_5:.*]] = sparse_tensor.values %[[TMP_arg0]]
+// CHECK: %[[TMP_6:.*]] = memref.load %[[TMP_1]][%[[TMP_c0]]] : memref<?xindex>
+// CHECK: %[[TMP_7:.*]] = memref.load %[[TMP_1]][%[[TMP_c1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg1:.*]] = %[[TMP_6]] to %[[TMP_7]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_9:.*]] = memref.load %[[TMP_2]][%[[TMP_arg1]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg2:.*]] = %[[TMP_c0]] to %[[TMP_c3]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_10:.*]] = memref.load %[[TMP_3]][%[[TMP_arg1]]] : memref<?xindex>
+// CHECK: %[[TMP_11:.*]] = arith.addi %[[TMP_arg1]], %[[TMP_c1]] : index
+// CHECK: %[[TMP_12:.*]] = memref.load %[[TMP_3]][%[[TMP_11]]] : memref<?xindex>
+// CHECK: scf.for %[[TMP_arg3:.*]] = %[[TMP_10]] to %[[TMP_12]] step %[[TMP_c1]] {
+// CHECK: %[[TMP_13:.*]] = memref.load %[[TMP_4]][%[[TMP_arg3]]] : memref<?xindex>
+// CHECK: %[[TMP_14:.*]] = memref.load %[[TMP_5]][%[[TMP_arg3]]] : memref<?xi32>
+// CHECK: %[[TMP_15:.*]] = sparse_tensor.insert %[[TMP_14]] into %[[TMP_0]][%[[TMP_9]], %[[TMP_arg2]], %[[TMP_13]]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: %[[TMP_8:.*]] = sparse_tensor.load %[[TMP_0]] hasInserts
+// CHECK: return %[[TMP_8]]
+module @func_sparse {
+  func.func public @main(%arg0: tensor<4x5xi32, #DCSR>) -> tensor<4x3x5xi32, #SparseTensor> {
+    %0 = bufferization.alloc_tensor() : tensor<4x3x5xi32, #SparseTensor>
+    %1 = linalg.generic #trait
+        ins(%arg0 : tensor<4x5xi32, #DCSR>) outs(%0 : tensor<4x3x5xi32, #SparseTensor>) {
+      ^bb0(%in: i32, %out: i32):
+        linalg.yield %in : i32
+    } -> tensor<4x3x5xi32, #SparseTensor>
+    return %1 : tensor<4x3x5xi32, #SparseTensor>
+  }
+}
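
Note (not part of the patch): a minimal sketch of how a caller inside the SparseTensor transforms might construct the emitter so the new isSparseOut guard takes effect. Only SparseTensorLoopEmitter, its constructor arguments, and prepareLoopOverTensorAtDim come from the patch above; the function emitLoops and the input/output Values are hypothetical placeholders.

// Hypothetical usage sketch; assumes the lib-internal CodegenUtils.h above.
#include "CodegenUtils.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// `input` and `output` are placeholder Values; `output` is assumed to carry
// a sparse encoding (e.g., the 3-D all-compressed type in the new test).
static void emitLoops(OpBuilder &builder, Location loc, Value input,
                      Value output) {
  // The last tensor in the range is treated as the output (hasOutput), and
  // isSparseOut marks it as sparse. With the guards added in this patch,
  // genAddress (pidx = i * sz + j) is then skipped for the output tensor,
  // since insertions go through sparse_tensor.insert rather than through a
  // precomputed pidx into a dense buffer.
  SparseTensorLoopEmitter emitter(ValueRange{input, output},
                                  /*hasOutput=*/true,
                                  /*isSparseOut=*/true);
  // ... then emit loops per dimension, e.g. via
  // emitter.prepareLoopOverTensorAtDim(builder, loc, tid, dim).
}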