diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -696,11 +696,49 @@ auto loc = op.getLoc(); Value input = op.getTensor(); + SmallVector reduc = op.getInitArgs(); auto rtp = input.getType().cast(); int64_t rank = rtp.getRank(); - auto enc = getSparseTensorEncoding(rtp); - SmallVector reduc = op.getInitArgs(); + // Special-case: for each over a sparse constant uses its own rewriting + // rule. + if (auto constOp = input.getDefiningOp()) { + if (auto attr = constOp.getValue().dyn_cast()) { + // Foreach on constant. + DenseElementsAttr indicesAttr = attr.getIndices(); + DenseElementsAttr valuesAttr = attr.getValues(); + + SmallVector args; + for (int i = 0, e = valuesAttr.size(); i < e; i++) { + auto valAttr = valuesAttr.getValues()[i]; + for (int j = 0; j < rank; j++) { + auto coordAttr = indicesAttr.getValues()[i * rank + j]; + auto coord = rewriter.create( + loc, coordAttr.getInt()); + // Remaps coordinates. + args.push_back(coord); + } + // Remaps value. + auto val = rewriter.create(loc, valAttr); + args.push_back(val); + // Remaps iteration args. + args.append(reduc); + auto cloned = cast(rewriter.clone(*op.getOperation())); + Operation *yield = cloned.getBody()->getTerminator(); + rewriter.mergeBlockBefore(cloned.getBody(), op, args); + // clean up + args.clear(); + rewriter.eraseOp(cloned); + reduc = yield->getOperands(); + rewriter.eraseOp(yield); + } + rewriter.replaceOp(op, reduc); + return success(); + } + } + + // Otherwise, use loop emitter to generate loops. + auto enc = getSparseTensorEncoding(rtp); // 1. Generates loop for the sparse input. SparseTensorLoopEmitter loopEmitter(ValueRange{input}); diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_codegen_foreach.mlir @@ -26,6 +26,18 @@ }> module { + /// uses foreach operator to print coords and values. + func.func @foreach_print_const() { + // Initialize a tensor. + %0 = arith.constant sparse<[[0, 0], [1, 6]], [1.0, 5.0]> : tensor<8x7xf32> + sparse_tensor.foreach in %0 : tensor<8x7xf32> do { + ^bb0(%1: index, %2: index, %v: f32) : + vector.print %1: index + vector.print %2: index + vector.print %v: f32 + } + return + } /// uses foreach operator to print coords and values. func.func @foreach_print_1(%arg0: tensor<2x2xf64, #Row>) { @@ -111,6 +123,13 @@ // CHECK: 0 // CHECK-NEXT: 0 // CHECK-NEXT: 1 + // CHECK-NEXT: 1 + // CHECK-NEXT: 6 + // CHECK-NEXT: 5 + call @foreach_print_const() : () -> () + // CHECK-NEXT: 0 + // CHECK-NEXT: 0 + // CHECK-NEXT: 1 // CHECK-NEXT: 0 // CHECK-NEXT: 1 // CHECK-NEXT: 2