diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -857,21 +857,44 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach", [SingleBlockImplicitTerminator<"YieldOp">]>, - Arguments<(ins AnyTensor:$tensor)>{ + Arguments<(ins AnyTensor:$tensor, + Variadic:$initArgs)>, + Results<(outs Variadic:$results)> { let summary = "Iterates over elements in a tensor"; let description = [{ Iterates over stored elements in a tensor (which are typically, but not always, non-zero for sparse tensors) and executes the block. - For an input tensor with rank n, the block must take n + 1 arguments. The - first n arguments must be Index type, together indicating the current coordinates - of the element being visited. The last argument must have the same type as the + For an input tensor with rank n, the block must take n + 1 arguments, plus one + additional argument per loop-carried variable (described below). The first n arguments + must be Index type, together indicating the current coordinates of the element being + visited. The (n + 1)-th argument must have the same type as the tensor's element type, representing the actual value loaded from the input tensor at the given coordinates. - Note that foreach generated loop iterates over the stored elements in the storage - order. However, no matter what storage order is used, the indices passed to the block - always obey the original dimension order. + `sparse_tensor.foreach` can also operate on loop-carried variables and returns + the final values after loop termination. The initial values of the variables are + passed as additional SSA operands to `sparse_tensor.foreach` following the input + tensor; inside the region they appear as block arguments following the n + 1 + block arguments mentioned above (n coordinates and 1 value). 
+ + The region must terminate with a "sparse_tensor.yield" that passes the current + values of all loop-carried variables to the next iteration, or to the + results after the last iteration. The number and static types of loop-carried + variables may not change with iterations. + + For example: + ```mlir + %c0 = arith.constant 0 : i32 + %ret = sparse_tensor.foreach in %0 init(%c0): tensor<?x?xi32, #DCSR>, i32 -> i32 do { + ^bb0(%arg1: index, %arg2: index, %arg3: i32, %iter: i32): + %sum = arith.addi %iter, %arg3 : i32 + sparse_tensor.yield %sum : i32 + } + ``` + + It is important to note that the loop generated by foreach iterates over the stored elements + in the storage order. However, no matter what storage order is used, the indices passed + to the block always obey the original dimension order. For example: ```mlir @@ -879,10 +902,10 @@ dimLevelType = [ "compressed", "compressed" ], dimOrdering = affine_map<(i,j) -> (j,i)> }> - + // foreach on a column-major sparse tensor sparse_tensor.foreach in %0 : tensor<2x3xf64, #COL_MAJOR> do { - ^bb0(%row: index, %col: index, %arg3: f64): + ^bb0(%row: index, %col: index, %arg3: f64): // [%row, %col] -> [0, 0], [1, 0], [2, 0], [0, 1], [1, 1], [2, 1] } @@ -892,30 +915,25 @@ // foreach on a row-major sparse tensor sparse_tensor.foreach in %0 : tensor<2x3xf64, #ROW_MAJOR> do { - ^bb0(%row: index, %col: index, %arg3: f64): + ^bb0(%row: index, %col: index, %arg3: f64): // [%row, %col] -> [0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1] } ``` - - Example: - - ```mlir - sparse_tensor.foreach in %0 : tensor do { - ^bb0(%arg1: index, %arg2: index, %arg3: f64): - do something... 
- } - ``` }]; let builders = [ - OpBuilder<( - ins "Value":$tensor, - "function_ref")> + OpBuilder<(ins "Value":$tensor, + "function_ref")>, + OpBuilder<(ins "Value":$tensor, + "ValueRange":$iterArgs, + "function_ref")> ]; - let regions = (region AnyRegion:$region); - let assemblyFormat = "`in` $tensor attr-dict `:` type($tensor) `do` $region"; + let regions = (region SizedRegion<1>:$region); + let assemblyFormat = "`in` $tensor (`init``(`$initArgs^`)`)? attr-dict" + " `:` type($tensor) (`,` type($initArgs)^)?" + " (`->` type($results)^)? `do` $region"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -581,11 +581,20 @@ void ForeachOp::build( OpBuilder &builder, OperationState &result, Value tensor, - function_ref bodyBuilder) { - build(builder, result, tensor); + function_ref + bodyBuilder) { + build(builder, result, tensor, llvm::None, bodyBuilder); +} + +void ForeachOp::build( + OpBuilder &builder, OperationState &result, Value tensor, + ValueRange initArgs, + function_ref + bodyBuilder) { + build(builder, result, initArgs.getTypes(), tensor, initArgs); + // Builds foreach body. 
if (!bodyBuilder) return; - auto rtp = tensor.getType().cast(); int64_t rank = rtp.getRank(); @@ -602,23 +611,38 @@ auto ®ion = *result.regions.front(); Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); - bodyBuilder(builder, result.location, bodyBlock->getArguments()); + bodyBuilder(builder, result.location, + bodyBlock->getArguments().slice(0, rank), + bodyBlock->getArguments()[rank], + bodyBlock->getArguments().drop_front(rank + 1)); } LogicalResult ForeachOp::verify() { auto t = getTensor().getType().cast(); auto args = getBody()->getArguments(); - if (static_cast(t.getRank()) + 1 != args.size()) + if (static_cast(t.getRank()) + 1 + getInitArgs().size() != + args.size()) return emitError("Unmatched number of arguments in the block"); + if (getNumResults() != getInitArgs().size()) + return emitError("Mismatch in number of init arguments and results"); + + if (getResultTypes() != getInitArgs().getTypes()) + return emitError("Mismatch in types of init arguments and results"); + + auto yield = cast(getBody()->getTerminator()); + if (yield.getNumOperands() != getNumResults() || + yield.getOperands().getTypes() != getResultTypes()) + return emitError("Mismatch in types of yield values and results"); + for (int64_t i = 0, e = t.getRank(); i < e; i++) if (args[i].getType() != IndexType::get(getContext())) emitError( llvm::formatv("Expecting Index type for argument at index {0}", i)); auto elemTp = t.getElementType(); - auto valueTp = args.back().getType(); + auto valueTp = args[t.getRank()].getType(); if (elemTp != valueTp) emitError(llvm::formatv("Unmatched element type between input tensor and " "block argument, expected:{0}, got: {1}", diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ 
-357,7 +357,9 @@ auto cooBuffer = rewriter.create(loc, cooTp, dstDynSizes).getResult(); rewriter.create( - loc, srcTensor, [&](OpBuilder &builder, Location loc, ValueRange args) { + loc, srcTensor, llvm::None, + [&](OpBuilder &builder, Location loc, ValueRange args, Value v, + ValueRange reduc) { // `v` is the stored value; `reduc` is empty (no init values). SmallVector srcIndices; SmallVector dstIndices; for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { @@ -366,7 +368,7 @@ } translateIndicesArray(builder, loc, op.getReassociationIndices(), srcIndices, srcSizes, dstSizes, dstIndices); - builder.create(loc, args.back(), cooBuffer, dstIndices); + builder.create(loc, v, cooBuffer, dstIndices); builder.create(loc); }); @@ -446,7 +448,9 @@ // Build a for op for each input tensor to append new values into the // output tensor. rewriter.create( - loc, input, [&](OpBuilder &builder, Location loc, ValueRange args) { + loc, input, llvm::None, + [&](OpBuilder &builder, Location loc, ValueRange args, Value v, + ValueRange reduc) { SmallVector indices; for (int64_t i = 0; i < rank; i++) { uint64_t dim = @@ -457,7 +461,7 @@ idx = builder.create(loc, idx, offset); indices.push_back(idx); } - builder.create(loc, args.back(), cooBuffer, indices); + builder.create(loc, v, cooBuffer, indices); builder.create(loc); }); // Accumulates the offset. 
Note that only static-shaped inputs are allowed @@ -558,12 +562,13 @@ sizesForTensor(rewriter, sizes, loc, srcTp, src); Value dst = allocDenseTensor(rewriter, loc, dstTp, sizes); - rewriter.create( - loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) { - builder.create(loc, args.back(), dst, - args.drop_back()); - builder.create(loc); - }); + rewriter.create(loc, src, llvm::None, + [&](OpBuilder &builder, Location loc, + ValueRange args, Value v, ValueRange reduc) { + builder.create(loc, v, dst, + args); // `args` holds only the n coordinates. + builder.create(loc); + }); rewriter.replaceOpWithNewOp(op, dstTp, dst); return success(); @@ -598,13 +603,15 @@ tmpCoo = rewriter.create(loc, srcTp, dynSrcSizes).getResult(); rewriter.create( - loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) { + loc, src, llvm::None, + [&](OpBuilder &builder, Location loc, ValueRange args, Value v, + ValueRange reduc) { SmallVector indices; for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { uint64_t dim = toStoredDim(encSrc, i); indices.push_back(args[dim]); } - builder.create(loc, args.back(), tmpCoo, indices); + builder.create(loc, v, tmpCoo, indices); builder.create(loc); }); src = tmpCoo; @@ -646,16 +653,18 @@ getDynamicSizes(dstTp, srcSizes, dynDstSizes); Value dst = rewriter.create(loc, dstTp, dynDstSizes).getResult(); - rewriter.create( - loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) { - SmallVector indices; - for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) { - uint64_t dim = toStoredDim(encDst, i); - indices.push_back(args[dim]); - } - builder.create(loc, args.back(), dst, indices); - builder.create(loc); - }); + rewriter.create(loc, src, llvm::None, + [&](OpBuilder &builder, Location loc, + ValueRange args, Value v, ValueRange reduc) { + SmallVector indices; + for (int64_t i = 0, e = srcTp.getRank(); i < e; + i++) { + uint64_t dim = toStoredDim(encDst, i); + indices.push_back(args[dim]); + } + builder.create(loc, v, dst, indices); + builder.create(loc); + }); // 
Release the temporary COO if it is created. if (tmpCoo) @@ -866,12 +875,14 @@ ModuleOp module = op->getParentOfType(); // For each element in the source tensor, output the element. rewriter.create( - loc, src, [&](OpBuilder &builder, Location loc, ValueRange args) { + loc, src, llvm::None, + [&](OpBuilder &builder, Location loc, ValueRange args, Value v, + ValueRange reduc) { for (uint64_t i = 0; i < rank; i++) { rewriter.create(loc, args[i], indices, constantIndex(builder, loc, i)); } - rewriter.create(loc, args.back(), value); + rewriter.create(loc, v, value); SmallVector operands{writer, rankValue, indices, value}; FlatSymbolRefAttr fn = getFunc(module, outNextFuncName, {}, operands, EmitCInterface::On); diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir --- a/mlir/test/Dialect/SparseTensor/invalid.mlir +++ b/mlir/test/Dialect/SparseTensor/invalid.mlir @@ -551,6 +551,51 @@ // ----- +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>) -> () { + // expected-error@+1 {{Unmatched element type between input tensor and block argument}} + sparse_tensor.foreach in %arg0 : tensor<2x4xf64, #DCSR> do { + ^bb0(%1: index, %2: index, %v: f32) : // f32 does not match the f64 element type + } + return +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () { + // expected-error@+1 {{Mismatch in number of init arguments and results}} + sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 do { + ^bb0(%1: index, %2: index, %v: f32, %r1 : i32) : // one init value, but the op has no results + } + return +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () { + // expected-error@+1 {{Mismatch in types of init arguments and results}} + %1 = sparse_tensor.foreach in 
%arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> i32 do { + ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) : // the i32 result does not match the f32 init value + } + return +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () { + // expected-error@+1 {{Mismatch in types of yield values and results}} + %1 = sparse_tensor.foreach in %arg0 init(%arg1) : tensor<2x4xf64, #DCSR>, f32 -> f32 do { + ^bb0(%1: index, %2: index, %v: f32, %r0 : f32) : + sparse_tensor.yield %1 : index // yields an index, but the result type is f32 + } + return +} + +// ----- + // TODO: a test case with empty xs doesn't work due to some parser issues. func.func @sparse_sort_x_type( %arg0: index, %arg1: memref) { diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir --- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir +++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir @@ -411,6 +411,26 @@ return } +// ----- + +#DCSR = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}> + +// CHECK-LABEL: func @sparse_tensor_foreach( +// CHECK-SAME: %[[A0:.*]]: tensor<2x4xf64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[A1:.*]]: f32 +// CHECK-NEXT: %[[RET:.*]] = sparse_tensor.foreach in %[[A0]] init(%[[A1]]) +// CHECK-NEXT: ^bb0(%[[TMP_1:.*]]: index, %[[TMP_2:.*]]: index, %[[TMP_v:.*]]: f64, %[[TMP_r:.*]]: f32) +// CHECK: sparse_tensor.yield %[[TMP_r]] : f32 +// CHECK: } +func.func @sparse_tensor_foreach(%arg0: tensor<2x4xf64, #DCSR>, %arg1: f32) -> () { + %ret = sparse_tensor.foreach in %arg0 init(%arg1): tensor<2x4xf64, #DCSR>, f32 -> f32 + do { + ^bb0(%1: index, %2: index, %v: f64, %r: f32) : + sparse_tensor.yield %r : f32 // forwards the loop-carried value unchanged + } + return +} + // ---- // CHECK-LABEL: func @sparse_sort_1d0v(