diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h
@@ -183,6 +183,26 @@
 /// Scans to top of generated loop.
 Operation *getTop(Operation *op);
 
+/// Iterate over a sparse constant, generates constantOp for value and indices.
+/// E.g.,
+/// sparse<[ [0], [28], [31] ],
+///          [ (-5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] >
+/// =>
+/// %c1 = arith.constant 0
+/// %v1 = complex.constant (-5.13, 2.0)
+/// callback({%c1}, %v1)
+
+/// %c2 = arith.constant 28
+/// %v2 = complex.constant (3.0, 4.0)
+/// callback({%c2}, %v2)
+
+/// %c3 = arith.constant 31
+/// %v3 = complex.constant (5.0, 6.0)
+/// callback({%c3}, %v3)
+void foreachInSparseConstant(
+    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+    function_ref<void(ArrayRef<Value>, Value)> callback);
+
 //===----------------------------------------------------------------------===//
 // Inlined constant generators.
 //
@@ -197,9 +217,9 @@
 //===----------------------------------------------------------------------===//
 
 /// Generates a 0-valued constant of the given type. In addition to
-/// the scalar types (`ComplexType`, ``FloatType`, `IndexType`, `IntegerType`),
-/// this also works for `RankedTensorType` and `VectorType` (for which it
-/// generates a constant `DenseElementsAttr` of zeros).
+/// the scalar types (`ComplexType`, `FloatType`, `IndexType`,
+/// `IntegerType`), this also works for `RankedTensorType` and `VectorType`
+/// (for which it generates a constant `DenseElementsAttr` of zeros).
 inline Value constantZero(OpBuilder &builder, Location loc, Type tp) {
   if (auto ctp = tp.dyn_cast<ComplexType>()) {
     auto zeroe = builder.getZeroAttr(ctp.getElementType());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -1024,3 +1024,36 @@
     ;
   return op;
 }
+
+void sparse_tensor::foreachInSparseConstant(
+    Location loc, RewriterBase &rewriter, SparseElementsAttr attr,
+    function_ref<void(ArrayRef<Value>, Value)> callback) {
+  int64_t rank = attr.getType().getRank();
+  // Foreach on constant.
+  DenseElementsAttr indicesAttr = attr.getIndices();
+  DenseElementsAttr valuesAttr = attr.getValues();
+
+  SmallVector<Value> coords;
+  for (int i = 0, e = valuesAttr.size(); i < e; i++) {
+    coords.clear();
+    for (int j = 0; j < rank; j++) {
+      auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
+      auto coord =
+          rewriter.create<arith::ConstantIndexOp>(loc, coordAttr.getInt());
+      // Remaps coordinates.
+      coords.push_back(coord);
+    }
+    Value val;
+    if (attr.getElementType().isa<ComplexType>()) {
+      auto valAttr = valuesAttr.getValues<ArrayAttr>()[i];
+      val = rewriter.create<complex::ConstantOp>(loc, attr.getElementType(),
+                                                 valAttr);
+    } else {
+      auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
+      // Remaps value.
+      val = rewriter.create<arith::ConstantOp>(loc, valAttr);
+    }
+    assert(val);
+    callback(coords, val);
+  }
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -170,6 +170,35 @@
   }
 }
 
+static LogicalResult genForeachOnSparseConstant(ForeachOp op,
+                                                RewriterBase &rewriter,
+                                                SparseElementsAttr attr) {
+  auto loc = op.getLoc();
+  SmallVector<Value> reduc = op.getInitArgs();
+
+  // Foreach on constant.
+  foreachInSparseConstant(
+      loc, rewriter, attr,
+      [&reduc, &rewriter, op](ArrayRef<Value> coords, Value v) mutable {
+        SmallVector<Value> args;
+        args.append(coords.begin(), coords.end());
+        args.push_back(v);
+        args.append(reduc);
+        // Clones the foreach op to get a copy of the loop body.
+        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
+        assert(args.size() == cloned.getBody()->getNumArguments());
+        Operation *yield = cloned.getBody()->getTerminator();
+        rewriter.mergeBlockBefore(cloned.getBody(), op, args);
+        // clean up
+        rewriter.eraseOp(cloned);
+        reduc = yield->getOperands();
+        rewriter.eraseOp(yield);
+      });
+
+  rewriter.replaceOp(op, reduc);
+  return success();
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
@@ -752,36 +781,7 @@
   // rule.
   if (auto constOp = input.getDefiningOp<arith::ConstantOp>()) {
     if (auto attr = constOp.getValue().dyn_cast<SparseElementsAttr>()) {
-      // Foreach on constant.
-      DenseElementsAttr indicesAttr = attr.getIndices();
-      DenseElementsAttr valuesAttr = attr.getValues();
-
-      SmallVector<Value> args;
-      for (int i = 0, e = valuesAttr.size(); i < e; i++) {
-        auto valAttr = valuesAttr.getValues<TypedAttr>()[i];
-        for (int j = 0; j < rank; j++) {
-          auto coordAttr = indicesAttr.getValues<IntegerAttr>()[i * rank + j];
-          auto coord = rewriter.create<arith::ConstantIndexOp>(
-              loc, coordAttr.getInt());
-          // Remaps coordinates.
-          args.push_back(coord);
-        }
-        // Remaps value.
-        auto val = rewriter.create<arith::ConstantOp>(loc, valAttr);
-        args.push_back(val);
-        // Remaps iteration args.
-        args.append(reduc);
-        auto cloned = cast<ForeachOp>(rewriter.clone(*op.getOperation()));
-        Operation *yield = cloned.getBody()->getTerminator();
-        rewriter.mergeBlockBefore(cloned.getBody(), op, args);
-        // clean up
-        args.clear();
-        rewriter.eraseOp(cloned);
-        reduc = yield->getOperands();
-        rewriter.eraseOp(yield);
-      }
-      rewriter.replaceOp(op, reduc);
-      return success();
+      return genForeachOnSparseConstant(op, rewriter, attr);
     }
   }
 
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_complex_ops.mlir
@@ -1,8 +1,15 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void \
-// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 