diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -78,6 +78,12 @@
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+
+  /// Optional callback function to set the bound for the synthetic tensor,
+  /// which essentially is the dense loop bound.
+  using SynTensorBoundSetter =
+      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
   // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
   // index on sparse tensors.
   // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that define
@@ -114,7 +120,8 @@
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
-                          OutputUpdater updater = nullptr);
+                          OutputUpdater updater = nullptr,
+                          SynTensorBoundSetter synSetter = nullptr);
 
   /// Generates code to compute an affine expression whose variables are
   /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -376,8 +376,15 @@
     loopIdToOrd[topSort[n]] = n;
 }
 
-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
-                                     LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+    LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For every synthetic tensor, set the high bound by calling the callback.
+  if (synSetter)
+    for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+      highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
   // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
@@ -534,27 +541,15 @@
 
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
-  bool hasSynTensor = false;
-  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
-    } else {
-      if (isSynTensor(tid)) {
-        hasSynTensor = true;
-      } else {
-        loopBoundDefLevel = std::make_pair(tid, lvl);
-        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
-      }
+    } else if (!isSynTensor(tid)) {
+      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
     }
   }
 
-  if (hasSynTensor && loopBoundDefLevel.has_value()) {
-    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
-    highs[getSynTensorId()][getCurrentDepth()] =
-        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
-  }
 
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
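Reviewer note: a minimal sketch of how a caller can supply the new `SynTensorBoundSetter` callback. The constant bound and the surrounding names are hypothetical, not part of this patch; the real caller added to Sparsification.cpp below derives the bound from the linalg op's loop ranges instead.

    // Hypothetical caller: bound every synthetic-tensor (dense) loop level
    // by a constant extent. Assumes `loopEmitter`, `builder`, and `loc` are
    // in scope; the extent 8 is made up for illustration.
    loopEmitter.initializeLoopEmit(
        builder, loc, /*updater=*/nullptr,
        [](OpBuilder &b, Location loc, Level lvl) -> Value {
          return b.create<arith::ConstantIndexOp>(loc, 8);
        });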
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -832,6 +832,21 @@
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
+  SmallVector<Range> loopRange =
+      llvm::cast<linalg::LinalgOp>(op.getOperation())
+          .createLoopRanges(builder, loc);
+
+  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+  SmallVector<Range> sortedRange;
+  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+    LoopId ldx = env.topSortAt(i);
+    // FIXME: Get rid of filter loops, since we have a better algorithm to
+    // deal with affine index expressions.
+    if (ldx < env.merger().getStartingFilterLoopId()) {
+      sortedRange.push_back(loopRange[ldx]);
+    }
+  }
+
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@
                              ValueRange{init});
         }
         return init;
+      },
+      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.topSortSize());
+        // FIXME: Remove filter loops, since we have a better algorithm to
+        // deal with affine index expressions.
+        if (l >= env.merger().getStartingFilterLoopId())
+          return Value();
+
+        return mlir::getValueOrCreateConstantIndexOp(b, loc,
+                                                     sortedRange[l].size);
       });
 }
 
@@ -1594,7 +1619,9 @@
           // iterate based on the level of output tensor.  E.g., this
           // could be a synthetic tensor (for invariants and sparse
           // output tensor).
-          if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+          auto itType = env.op().getIteratorTypesArray()[ldx];
+          if (linalg::isReductionIterator(itType) &&
+              env.merger().getSynTensorID() == tid) {
             // Coiterating with an invariant, and this is a reduction loop
             // e.g., out = prod(in[i][j] op invariant);
             // In this case, we cannot infer the loop bound from the output
@@ -1669,7 +1696,14 @@
       tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
   }
 
-  assert(numloopCond > 0);
+  if (numloopCond == 0) {
+    // Corner case where the loop bound is defined by an *unused* operand; in
+    // this case, we just generate a dense "fake" loop by iterating over the
+    // synthetic tensor.
+    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+                                          env.emitter().getCurrentDepth()));
+    numloopCond++;
+  }
   // If we need just one loop condition and the condition is not imposed on
   // a non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
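Reviewer note: the loop-range permutation above, illustrated as a standalone sketch with plain types (not library code). `topSort` maps loop depth to the original linalg loop index, so the synthetic tensor's bound at depth d is the size of original loop topSort[d]:

    #include <cstdint>
    #include <vector>

    // Permute per-loop sizes into the sparsifier's topological loop order.
    std::vector<int64_t> sortLoopBounds(const std::vector<int64_t> &loopSizes,
                                        const std::vector<unsigned> &topSort) {
      std::vector<int64_t> sorted;
      sorted.reserve(topSort.size());
      for (unsigned ldx : topSort)
        sorted.push_back(loopSizes[ldx]);
      return sorted;
    }

    // For the test below (i = 2, j = 4, k = 8), scheduled as (i, k, j), i.e.
    // topSort = {0, 2, 1}: sortLoopBounds({2, 4, 8}, {0, 2, 1}) == {2, 8, 4},
    // matching the bounds of the three generated scf.for loops.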
diff --git a/mlir/test/Dialect/SparseTensor/unused-tensor.mlir b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unused-tensor.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+//
+// A contrived example where the sparse tensor B is only
+// used in the linalg op to determine the number of iterations
+// for the k-loop. This is included to make sure the sparse
+// compiler still generates the correct loop nest for this case.
+//
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S_out
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "C(i,j) = SUM_k A(i,j)"
+}
+
+// CHECK-LABEL: func.func @b_unused(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<2x4xf64>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x4xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<2x4xf64>) -> tensor<2x4xf64> {
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_8:.*]] = bufferization.to_memref %[[VAL_0]] : memref<2x4xf64>
+// CHECK-DAG:     %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_2]] : memref<2x4xf64>
+// CHECK:         scf.for %[[VAL_10:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_7]] {
+// CHECK:           scf.for %[[VAL_11:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_7]] {
+// CHECK:             scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_5]] step %[[VAL_7]] {
+// CHECK:               %[[VAL_13:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:               %[[VAL_14:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:               %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f64
+// CHECK:               memref.store %[[VAL_15]], %[[VAL_9]]{{\[}}%[[VAL_10]], %[[VAL_12]]] : memref<2x4xf64>
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+// CHECK:         %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_9]] : memref<2x4xf64>
+// CHECK:         return %[[VAL_16]] : tensor<2x4xf64>
+// CHECK:       }
+func.func @b_unused(%argA: tensor<2x4xf64>,
+                    %argB: tensor<8x4xf64, #SM>,
+                    %argC: tensor<2x4xf64>) -> tensor<2x4xf64> {
+  %result = linalg.generic #trait
+     ins(%argA, %argB: tensor<2x4xf64>, tensor<8x4xf64, #SM>)
+    outs(%argC: tensor<2x4xf64>) {
+      ^bb(%a: f64, %b: f64, %c: f64):
+        %0 = arith.addf %c, %a : f64
+        linalg.yield %0 : f64
+  } -> tensor<2x4xf64>
+  return %result : tensor<2x4xf64>
+}
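Reviewer note: beyond the lit test, the new path can also be exercised programmatically. A sketch of a standalone driver, assuming a built MLIR tree; the helper name `sparsify` is made up for illustration:

    #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Pass/PassManager.h"

    // Run the sparsification pass over a module containing a kernel like the
    // one in the test above.
    mlir::LogicalResult sparsify(mlir::ModuleOp module) {
      mlir::PassManager pm(module.getContext());
      pm.addPass(mlir::createSparsificationPass());
      return pm.run(module);
    }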