diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -815,16 +815,16 @@
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
-    Arguments<(ins AnySparseTensor:$tensor)>{
-  let summary = "Iterates over non-zero elements in a sparse tensor";
+    Arguments<(ins AnyTensor:$tensor)>{
+  let summary = "Iterates over elements in a tensor";
   let description = [{
-     Iterates over every non-zero element in the given sparse tensor and executes
-     the block.
+     Iterates over every stored value (typically, but not always, non-zero for sparse
+     tensors) in the given tensor and executes the block.
 
-     For a input sparse tensor with rank n, the block must take n + 1 arguments. The
+     For an input tensor with rank n, the block must take n + 1 arguments. The
      first n arguments must be Index type, together indicating the current coordinates
      of the element being visited. The last argument must have the same type as the
-     sparse tensor's element type, representing the actual value loaded from the input
+     tensor's element type, representing the actual value loaded from the input
      tensor at the given coordinates.
 
      Example:
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -479,14 +479,21 @@
   for (int64_t i = 0; i < rank; i++)
     loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
 
-  Value vals = loopEmitter.getTensorValueBuffer(0);
-  Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
-  Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
   SmallVector<Value> coords;
   coords.reserve(rank);
   loopEmitter.getCoordinateArray(coords);
 
+  Value vals = loopEmitter.getTensorValueBuffer(0);
+  Value val;
+  if (enc) {
+    // Sparse input: load the value through the pointer index of the last level.
+    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
+    val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
+  } else {
+    // Dense input: load the value directly at the current coordinates.
+    val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, coords);
+  }
+
   for (int64_t i = 0; i < rank; i++)
     loopEmitter.exitCurrentLoop();
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -78,6 +78,16 @@
     return
   }
 
+  func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+    }
+    return
+  }
+
   //
   // Main driver.
   //
@@ -109,6 +119,19 @@
     // CHECK-NEXT: 5
     // CHECK-NEXT: 1
     // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 6
     call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
     // CHECK-NEXT: 0
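
For reference, a rough sketch of what the new dense path conceptually produces after rewriting. This is a hand-written illustration under assumptions, not the verbatim rewriter output: the function name @foreach_dense_sketch, the bufferization.to_memref materialization, and the scf.for lowering are illustrative choices, and the actual loop emitter may generate different IR.

  // The dense branch above loads by coordinates instead of by the
  // last-level pointer index, so iterating a 2x2 dense tensor amounts
  // to a plain loop nest over all coordinates:
  func.func @foreach_dense_sketch(%arg0: tensor<2x2xf64>) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c2 = arith.constant 2 : index
    %vals = bufferization.to_memref %arg0 : memref<2x2xf64>
    scf.for %i = %c0 to %c2 step %c1 {
      scf.for %j = %c0 to %c2 step %c1 {
        %v = memref.load %vals[%i, %j] : memref<2x2xf64>
        vector.print %i : index
        vector.print %j : index
        vector.print %v : f64
      }
    }
    return
  }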