diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -78,6 +78,9 @@
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+  using SynTensorBoundSetter =
+      function_ref<Value(OpBuilder &builder, Location loc, Level lvl)>;
+
   // Map from [tid, dim] to a list of dependent [tid, dim] for affine expression
   // index on sparse tensors.
   // E.g., for affine index (d0 + d1), it depends on two [tid, dim] that defines
@@ -114,7 +117,8 @@
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
   void initializeLoopEmit(OpBuilder &builder, Location loc,
-                          OutputUpdater updater = nullptr);
+                          OutputUpdater updater = nullptr,
+                          SynTensorBoundSetter synSetter = nullptr);
 
   /// Generates code to compute an affine expression whose variables are
   /// `LoopId`s (i.e., `a.cast<AffineDimExpr>().getPosition()` is a valid
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -376,8 +376,15 @@
     loopIdToOrd[topSort[n]] = n;
 }
 
-void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
-                                     LoopEmitter::OutputUpdater updater) {
+void LoopEmitter::initializeLoopEmit(
+    OpBuilder &builder, Location loc, LoopEmitter::OutputUpdater updater,
+    LoopEmitter::SynTensorBoundSetter synSetter) {
+
+  // For every synthetic tensor, set the high bound by calling the callback.
+  if (synSetter)
+    for (unsigned i = 0, e = highs[getSynTensorId()].size(); i < e; i++)
+      highs[getSynTensorId()][i] = synSetter(builder, loc, i);
+
   // For every manifest tensor:
   // * get the values buffer.
   // * For every level:
@@ -534,27 +541,17 @@
 
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
-  bool hasSynTensor = false;
-  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      if (isSynTensor(tid)) {
-        hasSynTensor = true;
-      } else {
-        loopBoundDefLevel = std::make_pair(tid, lvl);
+      if (!isSynTensor(tid)) {
         prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
       }
     }
   }
-  if (hasSynTensor && loopBoundDefLevel.has_value()) {
-    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
-    highs[getSynTensorId()][getCurrentDepth()] =
-        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
-  }
 
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -832,6 +832,21 @@
   Location loc = op.getLoc();
   assert(op.getNumOperands() == op.getNumDpsInputs() + 1);
 
+  SmallVector<Range> loopRange =
+      llvm::cast<linalg::LinalgOp>(op.getOperation())
+          .createLoopRanges(builder, loc);
+
+  assert(loopRange.size() == env.merger().getStartingFilterLoopId());
+  SmallVector<Range> sortedRange;
+  for (unsigned i = 0, e = env.topSortSize(); i < e; i++) {
+    LoopId ldx = env.topSortAt(i);
+    // FIXME: Get rid of filter loops since we have a better algorithm to deal
+    // with affine index expressions.
+    if (ldx < env.merger().getStartingFilterLoopId()) {
+      sortedRange.push_back(loopRange[ldx]);
+    }
+  }
+
   env.emitter().initializeLoopEmit(
       builder, loc,
       /// Generates buffer for the output tensor.
@@ -865,6 +880,16 @@
                              ValueRange{init});
         }
         return init;
+      },
+      [&sortedRange, &env](OpBuilder &b, Location loc, Level l) {
+        assert(l < env.topSortSize());
+        // FIXME: Remove filter loops since we have a better algorithm to
+        // deal with affine index expressions.
+        if (l >= env.merger().getStartingFilterLoopId())
+          return Value();
+
+        return mlir::getValueOrCreateConstantIndexOp(b, loc,
+                                                     sortedRange[l].size);
       });
 }
 
@@ -1587,7 +1612,9 @@
         // iterate based on the level of output tensor. E.g., this
        // could be a synthetic tensor (for invariants and sparse
        // output tensor).
-        if (env.isReduc() && env.merger().getSynTensorID() == tid) {
+        auto itType = env.op().getIteratorTypesArray()[ldx];
+        if (linalg::isReductionIterator(itType) &&
+            env.merger().getSynTensorID() == tid) {
          // Coiterating with an invariant, and this is a reduction loop
          // e.g., out = prod(in[i][j] op invariant);
          // In this case, we can not infer the loop bound from output
@@ -1662,7 +1689,14 @@
     tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
   }
 
-  assert(numloopCond > 0);
+  if (numloopCond == 0) {
+    // Corner case where the loop bound is defined by an *unused* operand; in
+    // this case, we just generate a dense "fake" loop by iterating over the
+    // synthetic tensor.
+    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
+                                          env.emitter().getCurrentDepth()));
+    numloopCond++;
+  }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
diff --git a/mlir/test/Dialect/SparseTensor/sampled_dense_dense_matmul_with_spy.mlir b/mlir/test/Dialect/SparseTensor/sampled_dense_dense_matmul_with_spy.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sampled_dense_dense_matmul_with_spy.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -sparsification --canonicalize --cse | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+#trait_sampled_dense_dense = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,j)>,  // S
+    affine_map<(i,j,k) -> (i,k)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S_out
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S_out(i,j) = spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
+}
+
+// CHECK-LABEL:   func.func @sparse_sampled_dd(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_6:.*]] = arith.constant false
+// CHECK-DAG:       %[[VAL_7:.*]] = arith.constant true
+// CHECK-DAG:       %[[VAL_8:.*]] = bufferization.alloc_tensor() : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:           %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_13:.*]] = %[[VAL_8]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>) {
+// CHECK:             %[[VAL_14:.*]], %[[VAL_15:.*]], %[[VAL_16:.*]], %[[VAL_17:.*]] = sparse_tensor.expand %[[VAL_8]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK:             %[[VAL_18:.*]] = scf.for %[[VAL_19:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_20:.*]] = %[[VAL_17]]) -> (index) {
+// CHECK:               %[[VAL_21:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_12]], %[[VAL_19]]] : tensor<8x8xf64>
+// CHECK:               %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_12]], %[[VAL_5]] : index
+// CHECK:               %[[VAL_24:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_23]]] : memref<?xindex>
+// CHECK:               %[[VAL_25:.*]] = scf.for %[[VAL_26:.*]] = %[[VAL_22]] to %[[VAL_24]] step %[[VAL_5]] iter_args(%[[VAL_27:.*]] = %[[VAL_20]]) -> (index) {
+// CHECK:                 %[[VAL_28:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_26]]] : memref<?xindex>
+// CHECK:                 %[[VAL_29:.*]] = tensor.extract %[[VAL_2]]{{\[}}%[[VAL_19]], %[[VAL_28]]] : tensor<8x8xf64>
+// CHECK:                 %[[VAL_30:.*]] = arith.mulf %[[VAL_21]], %[[VAL_29]] : f64
+// CHECK:                 %[[VAL_31:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_28]]] : memref<?xi1>
+// CHECK:                 %[[VAL_32:.*]] = arith.cmpi eq, %[[VAL_31]], %[[VAL_6]] : i1
+// CHECK:                 %[[VAL_33:.*]] = scf.if %[[VAL_32]] -> (index) {
+// CHECK:                   memref.store %[[VAL_7]], %[[VAL_15]]{{\[}}%[[VAL_28]]] : memref<?xi1>
+// CHECK:                   memref.store %[[VAL_28]], %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:                   %[[VAL_34:.*]] = arith.addi %[[VAL_27]], %[[VAL_5]] : index
+// CHECK:                   scf.yield %[[VAL_34]] : index
+// CHECK:                 } else {
+// CHECK:                   scf.yield %[[VAL_27]] : index
+// CHECK:                 }
+// CHECK:                 memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_28]]] : memref<?xf64>
+// CHECK:                 scf.yield %[[VAL_35:.*]] : index
+// CHECK:               } {"Emitted from" = "linalg.generic"}
+// CHECK:               scf.yield %[[VAL_36:.*]] : index
+// CHECK:             } {"Emitted from" = "linalg.generic"}
+// CHECK:             %[[VAL_37:.*]] = sparse_tensor.compress %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_38:.*]] into %[[VAL_13]]{{\[}}%[[VAL_12]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:             scf.yield %[[VAL_37]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           } {"Emitted from" = "linalg.generic"}
+// CHECK:           %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_40:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:           return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:         }
+func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #SM>,
+                             %argA: tensor<8x8xf64>,
+                             %argB: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+  %f0 = arith.constant 0.0 : f64
+  %f1 = arith.constant 1.0 : f64
+  %init = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %result = linalg.generic #trait_sampled_dense_dense
+    ins(%argS, %argA, %argB: tensor<8x8xf64, #SM>, tensor<8x8xf64>, tensor<8x8xf64>)
+    outs(%init: tensor<8x8xf64, #SM>) {
+      ^bb(%s: f64, %a: f64, %b: f64, %x: f64):
+        // We only care whether %s is present or not, but not the actual value in it.
+        %exist = sparse_tensor.unary %s : f64 to f64
+          present={
+            ^bb0(%p: f64):
+              sparse_tensor.yield %f1 : f64
+          }
+          absent={}
+        %mul = arith.mulf %a, %b : f64
+        %2 = sparse_tensor.reduce %mul, %exist, %f0 : f64 {
+          ^bb0(%m: f64, %e: f64):
+            %sel = arith.mulf %m, %e : f64  // should fold
+            sparse_tensor.yield %sel : f64
+        }
+        %add = arith.addf %x, %2 : f64
+        linalg.yield %2 : f64
+  } -> tensor<8x8xf64, #SM>
+  return %result : tensor<8x8xf64, #SM>
+}
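
Note (not part of the patch): the sketch below is a simplified, standalone illustration of the optional-callback pattern that initializeLoopEmit now follows for SynTensorBoundSetter. All names in it are made up for the example, and std::function stands in for llvm::function_ref so it compiles on its own; when the caller omits the setter nothing changes, and when one is supplied the emitter queries it once per level of the synthetic tensor, much like the lambda Sparsification.cpp builds from the sorted linalg loop ranges.

// Standalone sketch of an optional per-level bound callback (hypothetical
// names; std::function used instead of llvm::function_ref so it is
// self-contained).
#include <cstdio>
#include <functional>
#include <vector>

using Bound = int; // stands in for mlir::Value in the real emitter

struct SketchEmitter {
  std::vector<Bound> synHighs; // per-level upper bounds of the synthetic tensor

  // Mirrors the shape of initializeLoopEmit: the setter defaults to null and
  // is only consulted when the caller provides one.
  void initialize(std::function<Bound(unsigned)> synSetter = nullptr) {
    if (synSetter)
      for (unsigned l = 0, e = synHighs.size(); l < e; l++)
        synHighs[l] = synSetter(l);
  }
};

int main() {
  SketchEmitter emitter;
  emitter.synHighs.resize(3, 0);
  // The caller supplies the bound per loop level.
  emitter.initialize([](unsigned lvl) { return Bound(8); });
  for (Bound b : emitter.synHighs)
    std::printf("level bound = %d\n", b);
  return 0;
}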