diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -815,7 +815,7 @@
 
 def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
     [SingleBlockImplicitTerminator<"YieldOp">]>,
-     Arguments<(ins AnySparseTensor:$tensor)>{
+     Arguments<(ins AnyTensor:$tensor)>{
   let summary = "Iterates over non-zero elements in a sparse tensor";
   let description = [{
      Iterates over every non-zero element in the given sparse tensor and executes
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -479,14 +479,19 @@
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
 
-    Value vals = loopEmitter.getTensorValueBuffer(0);
-    Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
-    Value val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
-
     SmallVector<Value> coords;
     coords.reserve(rank);
     loopEmitter.getCoordinateArray(coords);
 
+    Value vals = loopEmitter.getTensorValueBuffer(0);
+    Value val;
+    if (enc) {
+      Value idx = loopEmitter.getLastLevelTensorPointerIndex(0);
+      val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, idx);
+    } else {
+      val = rewriter.create<memref::LoadOp>(op.getLoc(), vals, coords);
+    }
+
     for (int64_t i = 0; i < rank; i++)
       loopEmitter.exitCurrentLoop();
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir
@@ -78,6 +78,16 @@
     return
   }
 
+  func.func @foreach_print_dense(%arg0: tensor<2x2xf64>) {
+    sparse_tensor.foreach in %arg0 : tensor<2x2xf64> do {
+      ^bb0(%1: index, %2: index, %v: f64) :
+        vector.print %1: index
+        vector.print %2: index
+        vector.print %v: f64
+    }
+    return
+  }
+
   //
   // Main driver.
   //
@@ -109,6 +119,19 @@
     // CHECK-NEXT: 5
    // CHECK-NEXT: 1
    // CHECK-NEXT: 1
+    // CHECK-NEXT: 6
+    call @foreach_print_dense(%src) : (tensor<2x2xf64>) -> ()
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 2
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 0
+    // CHECK-NEXT: 5
+    // CHECK-NEXT: 1
+    // CHECK-NEXT: 1
     // CHECK-NEXT: 6
     call @foreach_print_1(%s1) : (tensor<2x2xf64, #Row>) -> ()
     // CHECK-NEXT: 0
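
For reference, the dense branch added in SparseTensorRewriting.cpp amounts to visiting every coordinate of the tensor and reading the value by those coordinates, instead of going through a compressed last-level pointer index as in the sparse case. The snippet below is a hand-written illustration of that semantics only; the function name @foreach_dense_as_loops and the use of scf.for/tensor.extract are assumptions made for exposition, not the IR actually emitted by the rewriting (which drives the sparse LoopEmitter over bufferized memrefs).

// Illustrative sketch only: same visiting order and per-element body as the
// @foreach_print_dense test above, written as an explicit loop nest.
func.func @foreach_dense_as_loops(%t: tensor<2x2xf64>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  scf.for %i = %c0 to %c2 step %c1 {
    scf.for %j = %c0 to %c2 step %c1 {
      // Dense path: the value is read by its coordinates (cf. the `else`
      // branch in the patch), not by a last-level pointer index.
      %v = tensor.extract %t[%i, %j] : tensor<2x2xf64>
      vector.print %i : index
      vector.print %j : index
      vector.print %v : f64
    }
  }
  return
}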