diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -834,16 +834,16 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach", [SingleBlockImplicitTerminator<"YieldOp">]>, - Arguments<(ins AnySparseTensor:$tensor)>{ - let summary = "Iterates over non-zero elements in a sparse tensor"; + Arguments<(ins AnyTensor:$tensor)>{ + let summary = "Iterates over elements in a tensor"; let description = [{ - Iterates over every non-zero element in the given sparse tensor and executes - the block. + Iterates over the stored elements of a tensor (every element of a dense tensor; + typically, but not always, the non-zeros of a sparse tensor) and executes the block. - For a input sparse tensor with rank n, the block must take n + 1 arguments. The + For an input tensor with rank n, the block must take n + 1 arguments. The first n arguments must be Index type, together indicating the current coordinates of the element being visited. The last argument must have the same type as the - sparse tensor's element type, representing the actual value loaded from the input + tensor's element type, representing the actual value loaded from the input tensor at the given coordinates. 
Example: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -480,14 +480,17 @@ for (int64_t i = 0; i < rank; i++) loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i); - Value vals = loopEmitter.getTensorValueBuffer(0); - Value idx = loopEmitter.getLastLevelTensorPointerIndex(0); - Value val = rewriter.create(op.getLoc(), vals, idx); - SmallVector coords; coords.reserve(rank); loopEmitter.getCoordinateArray(coords); + Value vals = loopEmitter.getTensorValueBuffer(0); + Value pidx = loopEmitter.getLastLevelTensorPointerIndex(0); + // Loads the value from sparse tensor using pointer index; + // loads the value from dense tensor using coordinate array. + Value val = enc ? rewriter.create(loc, vals, pidx) + : rewriter.create(loc, vals, coords); + for (int64_t i = 0; i < rank; i++) loopEmitter.exitCurrentLoop(); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir @@ -78,6 +78,16 @@ return } + func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) { + sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do { + ^bb0(%1: index, %2: index, %v: f64) : + vector.print %1: index + vector.print %2: index + vector.print %v: f64 + } + return + } + // // Main driver. 
// @@ -109,6 +119,19 @@ // CHECK-NEXT: 5 // CHECK-NEXT: 1 // CHECK-NEXT: 1 + // CHECK-NEXT: 6 + call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> () + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 + // CHECK-NEXT: 2 + // CHECK-NEXT: 1 + // CHECK-NEXT: 0 + // CHECK-NEXT: 5 + // CHECK-NEXT: 1 + // CHECK-NEXT: 1 // CHECK-NEXT: 6 call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> () // CHECK-NEXT: 0