diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -176,11 +176,27 @@
   return call.getResult(0);
 }
 
+/// Generates the comparison `v != 0` where `v` is of numeric type `t`.
+/// For floating types, we use the "unordered" comparator (i.e., returns
+/// true if `v` is NaN).
+static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
+                          Type t, Value v) {
+  Value zero = rewriter.create<ConstantOp>(loc, rewriter.getZeroAttr(t));
+  if (t.isa<FloatType>())
+    return rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, v, zero);
+  if (t.isIntOrIndex())
+    return rewriter.create<CmpIOp>(loc, CmpIPredicate::ne, v, zero);
+  llvm_unreachable("Unknown element type");
+}
+
 /// Generates a call that adds one element to a coordinate scheme.
+/// In particular, this generates code like the following:
+///   val = a[i1,..,ik];
+///   if val != 0
+///     t->add(val, [i1,..,ik], [p1,..,pk]);
 static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
                           Value ptr, Value tensor, Value ind, Value perm,
                           ValueRange ivs) {
-  Location loc = op->getLoc();
   StringRef name;
   Type eltType = tensor.getType().cast<ShapedType>().getElementType();
   if (eltType.isF64())
@@ -197,8 +213,11 @@
     name = "addEltI8";
   else
     llvm_unreachable("Unknown element type");
+  Location loc = op->getLoc();
   Value val = rewriter.create<tensor::ExtractOp>(loc, tensor, ivs);
-  // TODO: add if here?
+  Value cond = genIsNonzero(rewriter, loc, eltType, val);
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ false);
+  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
   unsigned i = 0;
   for (auto iv : ivs) {
     Value idx = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(i++));
@@ -313,6 +332,9 @@
     // Note that the dense tensor traversal code is actually implemented
     // using MLIR IR to avoid having to expose too much low-level
     // memref traversal details to the runtime support library.
+    // Also note that the code below only generates the "new" ops and
+    // the loop nest itself; the entire body of the innermost loop
+    // is generated by genAddEltCall().
     Location loc = op->getLoc();
     ShapedType shape = resType.cast<ShapedType>();
     auto memTp =
diff --git a/mlir/lib/ExecutionEngine/SparseUtils.cpp b/mlir/lib/ExecutionEngine/SparseUtils.cpp
--- a/mlir/lib/ExecutionEngine/SparseUtils.cpp
+++ b/mlir/lib/ExecutionEngine/SparseUtils.cpp
@@ -552,8 +552,6 @@
              uint64_t pstride) {                                               \
     assert(istride == 1 && pstride == 1 && isize == psize);                   \
     uint64_t *indx = idata + ioff;                                            \
-    if (!value)                                                               \
-      return tensor;                                                          \
     uint64_t *perm = pdata + poff;                                            \
     std::vector<uint64_t> indices(isize);                                     \
     for (uint64_t r = 0; r < isize; r++)                                      \
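
A rough sketch of the IR the lowering now emits around each addElt call, assuming an f64 element; the value names, the surrounding loop nest, and the exact body of the scf.if are illustrative only and not taken verbatim from the patch:

    %v  = tensor.extract %src[%i, %j] : tensor<?x?xf64>
    %z  = constant 0.000000e+00 : f64
    %nz = cmpf une, %v, %z : f64        // unordered: also true when %v is NaN
    scf.if %nz {
      // store the permuted indices into the index memref and
      // call the addEltF64 entry point of the support library
    }

With the test performed in generated IR, the zero check in the SparseUtils.cpp macro (the removed `if (!value) return tensor;`) becomes redundant, so the runtime entry point is no longer called for zero values at all.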