diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -170,6 +170,55 @@ } } +static LogicalResult genForeachOnSparseConstant(ForeachOp op, + RewriterBase &rewriter, + SparseElementsAttr attr) { + auto loc = op.getLoc(); + Value input = op.getTensor(); + SmallVector reduc = op.getInitArgs(); + auto rtp = input.getType().cast(); + int64_t rank = rtp.getRank(); + + // Foreach on constant: emit one copy of the body per stored entry. + DenseElementsAttr indicesAttr = attr.getIndices(); + DenseElementsAttr valuesAttr = attr.getValues(); + + SmallVector args; + for (int i = 0, e = valuesAttr.size(); i < e; i++) { + for (int j = 0; j < rank; j++) { + auto coordAttr = indicesAttr.getValues()[i * rank + j]; + auto coord = + rewriter.create(loc, coordAttr.getInt()); + // Remaps coordinates; entry i's j-th index is at i * rank + j. + args.push_back(coord); + } + Value val; + if (rtp.getElementType().isa()) { + auto valAttr = valuesAttr.getValues()[i]; + val = rewriter.create(loc, rtp.getElementType(), + valAttr); + } else { + auto valAttr = valuesAttr.getValues()[i]; + // Remaps value (no explicit result type, unlike the branch above). + val = rewriter.create(loc, valAttr); + } + assert(val); + args.push_back(val); + // Remaps iteration args: the current reduction values. + args.append(reduc); + auto cloned = cast(rewriter.clone(*op.getOperation())); + Operation *yield = cloned.getBody()->getTerminator(); + rewriter.mergeBlockBefore(cloned.getBody(), op, args); + // Clean up; values yielded by this copy feed the next one. + args.clear(); + rewriter.eraseOp(cloned); + reduc = yield->getOperands(); + rewriter.eraseOp(yield); + } + rewriter.replaceOp(op, reduc); + return success(); +} + //===---------------------------------------------------------------------===// // The actual sparse tensor rewriting rules. //===---------------------------------------------------------------------===// @@ -752,36 +801,7 @@ // rule. 
if (auto constOp = input.getDefiningOp()) { if (auto attr = constOp.getValue().dyn_cast()) { - // Foreach on constant. - DenseElementsAttr indicesAttr = attr.getIndices(); - DenseElementsAttr valuesAttr = attr.getValues(); - - SmallVector args; - for (int i = 0, e = valuesAttr.size(); i < e; i++) { - auto valAttr = valuesAttr.getValues()[i]; - for (int j = 0; j < rank; j++) { - auto coordAttr = indicesAttr.getValues()[i * rank + j]; - auto coord = rewriter.create( - loc, coordAttr.getInt()); - // Remaps coordinates. - args.push_back(coord); - } - // Remaps value. - auto val = rewriter.create(loc, valAttr); - args.push_back(val); - // Remaps iteration args. - args.append(reduc); - auto cloned = cast(rewriter.clone(*op.getOperation())); - Operation *yield = cloned.getBody()->getTerminator(); - rewriter.mergeBlockBefore(cloned.getBody(), op, args); - // clean up - args.clear(); - rewriter.eraseOp(cloned); - reduc = yield->getOperands(); - rewriter.eraseOp(yield); - } - rewriter.replaceOp(op, reduc); - return success(); + return genForeachOnSparseConstant(op, rewriter, attr); } } diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir --- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir +++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir @@ -1,8 +1,15 @@ -// RUN: mlir-opt %s --sparse-compiler | \ -// RUN: mlir-cpu-runner \ -// RUN: -e entry -entry-point-result=void \ -// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ -// RUN: FileCheck %s +// DEFINE: %{option} = enable-runtime-library=true +// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \ +// DEFINE: mlir-cpu-runner \ +// DEFINE: -e entry -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \ +// DEFINE: FileCheck %s +// First run: enable-runtime-library=true, as defined above. 
+// RUN: %{command} +// +// Do the same run, but now with direct 
IR generation (enable-runtime-library=false). +// REDEFINE: %{option} = enable-runtime-library=false +// RUN: %{command} #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>