diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1147,13 +1147,20 @@
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
      Arguments<(ins AnyTensor:$tensor,
-                    Variadic<AnyType>:$initArgs)>,
+                    Variadic<AnyType>:$initArgs,
+                    OptionalAttr<AffineMapAttr>:$order)>,
      Results<(outs Variadic<AnyType>:$results)> {
   let summary = "Iterates over elements in a tensor";
   let description = [{
     Iterates over stored elements in a tensor (which are typically, but not
     always, non-zero for sparse tensors) and executes the block.
 
+    `tensor`: the input tensor to iterate over.
+    `initArgs`: the initial loop arguments to carry and update during each iteration.
+    `order`: an optional permutation affine map that specifies the order in which
+    the dimensions are visited (e.g., row first or column first). This is only
+    applicable when the input tensor is a non-annotated dense tensor.
+
     For an input tensor with rank n, the block must take n + 1 (and additional loop
     carried variables as described below) arguments. The first n arguments must be
     Index type, together indicating the current coordinates of the element being visited.
@@ -1208,15 +1215,33 @@
       // [%row, %col] -> [0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]
     }
 
+    // foreach on a row-major dense tensor, but visit the columns first
+    sparse_tensor.foreach in %0 {order=affine_map<(i,j)->(j,i)>}: tensor<2x3xf64> do {
+      ^bb0(%row: index, %col: index, %arg3: f64):
+        // [%row, %col] -> [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1]
+    }
+
     ```
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$tensor,
+    OpBuilder<(ins "Value":$tensor, "ValueRange":$iterArgs, "AffineMapAttr":$order,
       "function_ref<void(OpBuilder &, Location, ValueRange)>")>,
+    OpBuilder<(ins "Value":$tensor, "AffineMapAttr":$order,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder),
+    [{
+      build($_builder, $_state, tensor, ValueRange(), order, bodyBuilder);
+    }]>,
     OpBuilder<(ins "Value":$tensor,
-      "ValueRange":$iterArgs,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>")>
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder),
+    [{
+      build($_builder, $_state, tensor, ValueRange(), nullptr, bodyBuilder);
+    }]>,
+    OpBuilder<(ins "Value":$tensor, "ValueRange":$iterArgs,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$bodyBuilder),
+    [{
+      build($_builder, $_state, tensor, iterArgs, nullptr, bodyBuilder);
+    }]>
   ];
 
   let regions = (region SizedRegion<1>:$region);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -982,17 +982,10 @@
 void ForeachOp::build(
     OpBuilder &builder, OperationState &result, Value tensor,
+    ValueRange initArgs, AffineMapAttr order,
     function_ref<void(OpBuilder &, Location, ValueRange)>
         bodyBuilder) {
-  build(builder, result, tensor, std::nullopt, bodyBuilder);
-}
-
-void ForeachOp::build(
-    OpBuilder &builder, OperationState &result, Value tensor,
-    ValueRange initArgs,
-    function_ref<void(OpBuilder &, Location, ValueRange)>
-        bodyBuilder) {
-  build(builder, result, initArgs.getTypes(), tensor, initArgs);
+  build(builder, result, initArgs.getTypes(), tensor, initArgs, order);
   // Builds foreach body.
   if (!bodyBuilder)
     return;
@@ -1023,6 +1016,10 @@
   const Dimension dimRank = t.getDimRank();
   const auto args = getBody()->getArguments();
 
+  if (getOrder().has_value() &&
+      (t.getEncoding() || !getOrder()->isPermutation()))
+    return emitError("Only support permuted order on non-encoded dense tensor");
+
   if (static_cast<size_t>(dimRank) + 1 + getInitArgs().size() != args.size())
     return emitError("Unmatched number of arguments in the block");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -216,7 +216,7 @@
 ///   callback({%c3}, %v3)
 void foreachInSparseConstant(
     Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
-    function_ref<void(ArrayRef<Value>, Value)> callback);
+    AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback);
 
 /// Converts the vector indices and store it into the memory pointed by
 /// `ind`, apply (optional) `offset` on `offsetDim`.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -500,17 +500,46 @@
 void sparse_tensor::foreachInSparseConstant(
     Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
-    function_ref<void(ArrayRef<Value>, Value)> callback) {
-  int64_t rank = attr.getType().getRank();
+    AffineMap order, function_ref<void(ArrayRef<Value>, Value)> callback) {
+  Dimension dimRank = getSparseTensorType(attr).getDimRank();
   // Foreach on constant.
   DenseElementsAttr indicesAttr = attr.getIndices();
   DenseElementsAttr valuesAttr = attr.getValues();
 
+  SmallVector<std::pair<SmallVector<IntegerAttr>, Attribute>> cooV;
+  for (size_t i = 0, nse = valuesAttr.size(); i < nse; i++) {
+    cooV.emplace_back();
+    for (Dimension j = 0; j < dimRank; j++) {
+      auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * dimRank + j];
+      cooV.back().first.push_back(coordAttr);
+    }
+    auto valAttr = valuesAttr.getValues<Attribute>()[i];
+    cooV.back().second = valAttr;
+  }
+
+  // Sorts the sparse element attribute based on coordinates.
+  std::sort(cooV.begin(), cooV.end(),
+            [order](const std::pair<SmallVector<IntegerAttr>, Attribute> &lhs,
+                    const std::pair<SmallVector<IntegerAttr>, Attribute> &rhs) {
+              const SmallVectorImpl<IntegerAttr> &lc = lhs.first;
+              const SmallVectorImpl<IntegerAttr> &rc = rhs.first;
+              for (size_t i = 0, e = lc.size(); i < e; i++) {
+                auto l =
+                    order
+                        ? order.getResult(i).cast<AffineDimExpr>().getPosition()
+                        : i;
+                if (lc[l].getInt() == rc[l].getInt())
+                  continue;
+                return lc[l].getInt() < rc[l].getInt();
+              }
+              llvm_unreachable("no equal coordinate in sparse element attr");
+            });
+
   SmallVector<Value> coords;
-  for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+  for (size_t i = 0, nse = valuesAttr.size(); i < nse; i++) {
     coords.clear();
-    for (int j = 0; j < rank; j++) {
-      auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+    for (Dimension j = 0; j < dimRank; j++) {
+      auto coordAttr = cooV[i].first[j];
       auto coord =
           rewriter.create<arith::ConstantIndexOp>(loc, coordAttr.getInt());
       // Remaps coordinates.
@@ -518,11 +547,11 @@
     }
     Value val;
     if (attr.getElementType().isa<ComplexType>()) {
-      auto valAttr = valuesAttr.getValues<ArrayAttr>()[i];
+      auto valAttr = cooV[i].second.cast<ArrayAttr>();
       val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
                                                  valAttr);
     } else {
-      auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+      auto valAttr = cooV[i].second.cast<TypedAttr>();
       // Remaps value.
       val = rewriter.create<arith::ConstantOp>(loc, valAttr);
     }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -158,7 +158,7 @@
   // Foreach on constant.
   foreachInSparseConstant(
-      loc, rewriter, attr,
+      loc, rewriter, attr, op.getOrder().value_or(AffineMap()),
       [&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
         SmallVector<Value> args;
         args.append(coords.begin(), coords.end());
@@ -669,19 +669,16 @@
   }
 
   const auto encDst = dstTp.getEncoding();
-  // We don't need a temporary COO tensor if the destination has an identity
-  // ordering. Otherwise, we use the destination ordering for the temporary
-  // COO tensor.
-  // TODO: enhance foreachOp to take ordering to remove the need of a
-  // temporary COO tensor here.
-  const RankedTensorType bufferTp = dstTp.isIdentity()
-                                        ? dstTp.getRankedTensorType()
-                                        : getUnorderedCOOFromTypeWithOrdering(
-                                              dstTp, dstTp.getDimToLvlMap());
+  // We don't need a temporary COO tensor for dense => sparse conversion.
+  const RankedTensorType bufferTp = dstTp.getRankedTensorType();
   auto buffer =
       rewriter.create<AllocTensorOp>(loc, bufferTp, dynSizes).getResult();
+  AffineMapAttr foreachOrder = nullptr;
+  if (encDst.getDimOrdering())
+    foreachOrder = AffineMapAttr::get(encDst.getDimOrdering());
+
   auto foreachOp = rewriter.create<ForeachOp>(
-      loc, src, buffer,
+      loc, src, buffer, foreachOrder,
       [&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
           ValueRange reduc) {
         Value input = reduc.front();
@@ -709,14 +706,8 @@
       });
   rewriter.setInsertionPointAfter(op);
   src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
-  if (bufferTp != dstTp) {
-    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp.getRankedTensorType(),
-                                           src);
-    rewriter.create<DeallocTensorOp>(loc, src);
-  } else {
-    rewriter.replaceOp(op, src);
-  }
+  rewriter.replaceOp(op, src);
   return success();
 }
@@ -928,16 +919,28 @@
   for (Dimension d = 0; d < dimRank; d++) {
     // TODO: provide utility function for loop sequences that only contains
     // one for loop?
-    loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(d));
+    Dimension ld =
+        op.getOrder()
+            ? op.getOrder()->getResult(d).cast<AffineDimExpr>().getPosition()
+            : d;
+    loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(ld));
     // Note that reduc will be taken care of by loop emitter and get updated
     // in place.
-    loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, d, reduc);
+
+    loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, ld, reduc);
   }
 
   SmallVector<Value> coords;
   coords.reserve(dimRank);
   loopEmitter.getCoordinateArray(coords);
 
+  if (op.getOrder()) {
+    SmallVector<Value> tmp = coords; // keep a copy
+    for (Dimension d = 0; d < dimRank; d++) {
+      auto l = op.getOrder()->getDimPosition(d);
+      coords[l] = tmp[d];
+    }
+  }
   Value vals = loopEmitter.getValBuffer()[0];
   Value pidx = loopEmitter.getPidxs()[0].back();
   // Loads the value from sparse tensor using pointer index;
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -183,30 +183,17 @@
   return %1 : tensor<8x7xf32, #CSR>
 }
 
-// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ], dimOrdering = affine_map<(d0, d1) -> (d1, d0)> }>> {
-// CHECK-RWT: %[[F0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
-// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[F0]] init(%[[T0]])
-// CHECK-RWT: ^bb0(%[[L0I0:.*]]: index, %[[L0I1:.*]]: index, %[[L0V:.*]]: f32, %[[L0T:.*]]: tensor
-// CHECK-RWT: %[[L0T2:.*]] = sparse_tensor.insert %[[L0V]] into %[[L0T]]{{\[}}%[[L0I1]], %[[L0I0]]]
-// CHECK-RWT: sparse_tensor.yield %[[L0T2]]
-// CHECK-RWT: }
-// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[T1]] hasInserts
-// CHECK-RWT: %[[NNZ:.*]] = sparse_tensor.number_of_entries %[[COO]]
-// CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[COO]]
-// CHECK-RWT: %[[I:.*]] = sparse_tensor.indices_buffer %[[COO]]
-// CHECK-RWT: sparse_tensor.sort_coo hybrid_quick_sort %[[NNZ]], %[[I]] jointly %[[V]] {nx = 2 : index, ny = 0 : index}
-// CHECK-RWT: %[[T3:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T4:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[T3]])
-// CHECK-RWT: ^bb0(%[[L1I0:.*]]: index, %[[L1I1:.*]]: index, %[[L1V:.*]]: f32, %[[L1T:.*]]: tensor
-// CHECK-RWT: %[[L1T2:.*]] = sparse_tensor.insert %[[L1V]] into %[[L1T]]{{\[}}%[[L1I1]], %[[L1I0]]]
-// CHECK-RWT: sparse_tensor.yield %[[L1T2]]
-// CHECK-RWT: }
-// CHECK-RWT: %[[T5:.*]] = sparse_tensor.load %[[T4]] hasInserts
-// CHECK-RWT: %[[T6:.*]] = sparse_tensor.convert %[[T5]]
-// CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
-// CHECK-RWT: return %[[T6]]
-// CHECK-RWT: }
+// CHECK-RWT-LABEL: func.func @sparse_constant_csc() -> tensor<8x7xf32,
+// CHECK-RWT: %[[VAL_0:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
+// CHECK-RWT: %[[VAL_1:.*]] = bufferization.alloc_tensor() :
+// CHECK-RWT: %[[VAL_2:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_1]]) {order = #map} : tensor<8x7xf32>,
+// CHECK-RWT: ^bb0(%[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: f32, %[[VAL_6:.*]]: tensor
+// CHECK-RWT: %[[VAL_7:.*]] = sparse_tensor.insert %[[VAL_5]] into %[[VAL_6]]{{\[}}%[[VAL_4]], %[[VAL_3]]] :
+// CHECK-RWT: sparse_tensor.yield %[[VAL_7]] :
+// CHECK-RWT: }
+// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.load %[[VAL_9:.*]] hasInserts :
+// CHECK-RWT: return %[[VAL_8]] :
+// CHECK-RWT: }
 func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{
   // Initialize a tensor.
   %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_ptr.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_ptr.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_ptr.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conversion_ptr.mlir
@@ -90,7 +90,7 @@
   %c1 = arith.constant 1 : index
   %t1 = arith.constant sparse<
     [ [0,0], [0,1], [0,63], [1,0], [1,1], [31,0], [31,63] ],
-     [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 ]> : tensor<32x64xf64>
+     [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0 ]> : tensor<32x64xf64>
   %t2 = tensor.cast %t1 : tensor<32x64xf64> to tensor<?x?xf64>
 
   // Dense to sparse.
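
Editor's note, not part of the patch: the sketch below shows how client code might drive the new order-taking ForeachOp builder added in SparseTensorOps.td above. The helper name emitColumnMajorForeach and the assumption that an OpBuilder, a Location, and a 2-D non-annotated dense tensor Value are already in scope are hypothetical; only the builder signature and the permutation semantics come from this patch.

  // Build a column-major sparse_tensor.foreach over a 2-D dense tensor.
  static void emitColumnMajorForeach(OpBuilder &builder, Location loc,
                                     Value tensor) {
    // (i, j) -> (j, i): visit the column dimension before the row dimension.
    // The verifier added above accepts such a permutation map only on
    // non-annotated dense tensors.
    AffineMap colMajor =
        AffineMap::getPermutationMap({1, 0}, builder.getContext());
    builder.create<sparse_tensor::ForeachOp>(
        loc, tensor, AffineMapAttr::get(colMajor),
        [](OpBuilder &b, Location l, ValueRange args) {
          // args holds (%row, %col, %value); a real client would use them
          // here before terminating the body with a yield.
          b.create<sparse_tensor::YieldOp>(l);
        });
  }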